/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph;

import java.lang.reflect.Field;
import java.net.InetAddress;
import java.util.Iterator;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.runtime.blob.BlobWriter;
import org.apache.flink.runtime.blob.VoidBlobWriter;
import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.clusterframework.types.ResourceID;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionGraphBuilder;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy;
import org.apache.flink.runtime.executiongraph.restart.RestartStrategy;
import org.apache.flink.runtime.instance.SimpleSlot;
import org.apache.flink.runtime.instance.SimpleSlotContext;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway;
import org.apache.flink.runtime.jobmaster.LogicalSlot;
import org.apache.flink.runtime.jobmaster.SlotContext;
import org.apache.flink.runtime.jobmaster.SlotOwner;
import org.apache.flink.runtime.jobmaster.slotpool.SlotProvider;
import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
import org.apache.flink.runtime.testingUtils.TestingUtils;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.util.FlinkException;
import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;

/**
 * Tests how {@link ExecutionVertex#getPreferredLocations()} derives scheduling preferences.
 *
 * <p>The tests cover input-based locality (from the producers' assigned slots) and state-based
 * locality (from the slot a vertex ran in before a reset, when it has restore state).
 */
public class ExecutionVertexLocalityTest extends TestLogger {
    private final JobID jobId = new JobID();
    private final JobVertexID sourceVertexId = new JobVertexID();
    private final JobVertexID targetVertexId = new JobVertexID();

    /**
     * With a POINTWISE connection and equal parallelism, each target subtask should prefer exactly
     * the location of its corresponding source subtask.
     */
    @Test
    public void testLocalityInputBasedForward() throws Exception {
        final int parallelism = 10;
        TaskManagerLocation[] locations = new TaskManagerLocation[parallelism];
        ExecutionGraph graph = createTestGraph(parallelism, false);

        for (int i = 0; i < parallelism; ++i) {
            ExecutionVertex source = getTaskVertex(graph, sourceVertexId, i);
            locations[i] = createLocation(10000 + i);
            initializeLocation(source, locations[i]);
        }

        for (int i = 0; i < parallelism; ++i) {
            ExecutionVertex target = getTaskVertex(graph, targetVertexId, i);
            assertSinglePreferredLocation(target, locations[i]);
        }
    }

    /**
     * With a large ALL_TO_ALL connection, the target subtasks are expected to report no
     * input-based preferred locations at all.
     */
    @Test
    public void testNoLocalityInputLargeAllToAll() throws Exception {
        final int parallelism = 100;
        ExecutionGraph graph = createTestGraph(parallelism, true);

        for (int i = 0; i < parallelism; ++i) {
            ExecutionVertex source = getTaskVertex(graph, sourceVertexId, i);
            initializeLocation(source, createLocation(10000 + i));
        }

        for (int i = 0; i < parallelism; ++i) {
            ExecutionVertex target = getTaskVertex(graph, targetVertexId, i);
            Iterator<? extends CompletableFuture<?>> preferences =
                    target.getPreferredLocations().iterator();
            Assert.assertFalse(preferences.hasNext());
        }
    }

    /**
     * After the graph is reset for a new execution, targets that carry restore state should prefer
     * the location they previously ran in. This must win over their sources' new locations.
     */
    @Test
    public void testLocalityBasedOnState() throws Exception {
        final int parallelism = 10;
        TaskManagerLocation[] locations = new TaskManagerLocation[parallelism];
        ExecutionGraph graph = createTestGraph(parallelism, false);

        // First execution: place sources and targets on distinct locations, then cancel them all.
        for (int i = 0; i < parallelism; ++i) {
            ExecutionVertex source = getTaskVertex(graph, sourceVertexId, i);
            ExecutionVertex target = getTaskVertex(graph, targetVertexId, i);
            locations[i] = createLocation(20000 + i);
            initializeLocation(source, createLocation(10000 + i));
            initializeLocation(target, locations[i]);
            setState(source.getCurrentExecutionAttempt(), ExecutionState.CANCELED);
            setState(target.getCurrentExecutionAttempt(), ExecutionState.CANCELED);
        }

        for (ExecutionJobVertex ejv : graph.getVerticesTopologically()) {
            ejv.resetForNewExecution(System.currentTimeMillis(), graph.getGlobalModVersion());
        }

        // Second execution: sources move to fresh locations; targets get (mocked) restore state.
        for (int i = 0; i < parallelism; ++i) {
            ExecutionVertex source = getTaskVertex(graph, sourceVertexId, i);
            initializeLocation(source, createLocation(30000 + i));
            ExecutionVertex target = getTaskVertex(graph, targetVertexId, i);
            target.getCurrentExecutionAttempt().setInitialState(Mockito.mock(JobManagerTaskRestore.class));
        }

        for (int i = 0; i < parallelism; ++i) {
            ExecutionVertex target = getTaskVertex(graph, targetVertexId, i);
            assertSinglePreferredLocation(target, locations[i]);
        }
    }

    /**
     * Builds a two-vertex graph (source -> target) with a PIPELINED result partition.
     *
     * @param parallelism parallelism used for both the source and the target vertex
     * @param allToAll {@code true} for an ALL_TO_ALL connection, {@code false} for POINTWISE
     * @return the built execution graph
     */
    private ExecutionGraph createTestGraph(int parallelism, boolean allToAll) throws Exception {
        JobVertex source = new JobVertex("source", sourceVertexId);
        source.setParallelism(parallelism);
        source.setInvokableClass(NoOpInvokable.class);

        // Fixed: this vertex was previously also named "source" (copy-paste error).
        JobVertex target = new JobVertex("target", targetVertexId);
        target.setParallelism(parallelism);
        target.setInvokableClass(NoOpInvokable.class);

        DistributionPattern connectionPattern =
                allToAll ? DistributionPattern.ALL_TO_ALL : DistributionPattern.POINTWISE;
        target.connectNewDataSetAsInput(source, connectionPattern, ResultPartitionType.PIPELINED);

        JobGraph testJob = new JobGraph(jobId, "test job", new JobVertex[] {source, target});
        Time timeout = Time.seconds(10L);

        return ExecutionGraphBuilder.buildGraph(
                null,
                testJob,
                new Configuration(),
                TestingUtils.defaultExecutor(),
                TestingUtils.defaultExecutor(),
                Mockito.mock(SlotProvider.class),
                getClass().getClassLoader(),
                new StandaloneCheckpointRecoveryFactory(),
                timeout,
                new FixedDelayRestartStrategy(10, 0L),
                new UnregisteredMetricsGroup(),
                1,
                VoidBlobWriter.getInstance(),
                timeout,
                log);
    }

    /** Returns the subtask {@code subtaskIndex} of the job vertex {@code vertexId} in {@code graph}. */
    private static ExecutionVertex getTaskVertex(ExecutionGraph graph, JobVertexID vertexId, int subtaskIndex) {
        ExecutionJobVertex jobVertex = (ExecutionJobVertex) graph.getAllVertices().get(vertexId);
        return jobVertex.getTaskVertices()[subtaskIndex];
    }

    /** Creates a task manager location on the loopback address with a fresh resource ID. */
    private static TaskManagerLocation createLocation(int dataPort) {
        return new TaskManagerLocation(ResourceID.generate(), InetAddress.getLoopbackAddress(), dataPort);
    }

    /** Asserts that {@code vertex} reports exactly one preferred location, equal to {@code expected}. */
    private static void assertSinglePreferredLocation(ExecutionVertex vertex, TaskManagerLocation expected)
            throws Exception {
        Iterator<? extends CompletableFuture<?>> preferences = vertex.getPreferredLocations().iterator();
        Assert.assertTrue(preferences.hasNext());
        Assert.assertEquals(expected, preferences.next().get());
        Assert.assertFalse(preferences.hasNext());
    }

    /**
     * Assigns a slot at {@code location} to the vertex's current execution attempt.
     *
     * @throws FlinkException if the execution attempt rejects the slot
     */
    private void initializeLocation(ExecutionVertex vertex, TaskManagerLocation location) throws Exception {
        SlotContext slotContext = new SimpleSlotContext(
                new AllocationID(), location, 0, Mockito.mock(TaskManagerGateway.class));
        LogicalSlot slot = new SimpleSlot(slotContext, Mockito.mock(SlotOwner.class), 0);

        if (!vertex.getCurrentExecutionAttempt().tryAssignResource(slot)) {
            throw new FlinkException("Could not assign resource.");
        }
    }

    /**
     * Forces the private {@code state} field of {@code execution} via reflection. This bypasses
     * the regular state transitions so tests can reach states such as CANCELED directly.
     */
    private void setState(Execution execution, ExecutionState state) throws Exception {
        Field stateField = Execution.class.getDeclaredField("state");
        stateField.setAccessible(true);
        stateField.set(execution, state);
    }
}

