/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.partition;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.testutils.CheckedThread;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.InputChannelTestUtils;
import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
import org.apache.flink.runtime.taskmanager.TaskActions;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Concurrency tests for {@link SingleInputGate}.
 *
 * <p>In each test a {@link ProducerThread} pushes buffers into the gate's input channels in
 * randomly sized chunks on randomly chosen channels. Concurrently, a {@link ConsumerThread} calls
 * {@link SingleInputGate#getNextBufferOrEvent()} once per produced buffer and asserts that every
 * call returns a non-null result.
 */
public class InputGateConcurrentTest {

    /** Upper bound (inclusive) on the number of buffers the producer adds to one channel per step. */
    private static final int MAX_CHUNK = 4;

    /** Roughly how many buffers the producer adds between calls to {@link Thread#yield()}. */
    private static final int YIELD_AFTER = 10;

    @Test
    public void testConsumptionWithLocalChannels() throws Exception {
        final int numChannels = 11;
        final int buffersPerChannel = 1000;

        final ResultPartition resultPartition = Mockito.mock(ResultPartition.class);
        final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numChannels];
        final Source[] sources = new Source[numChannels];

        // The array slots are only filled in the loop below; the manager presumably keeps a
        // reference to the array rather than copying it (verify in InputChannelTestUtils).
        final ResultPartitionManager resultPartitionManager =
                InputChannelTestUtils.createResultPartitionManager(partitions);
        final SingleInputGate gate = createInputGate(numChannels);

        for (int i = 0; i < numChannels; i++) {
            final LocalInputChannel channel = createLocalChannel(gate, i, resultPartitionManager);
            gate.setInputChannel(new IntermediateResultPartitionID(), channel);

            partitions[i] = new PipelinedSubpartition(0, resultPartition);
            sources[i] = new PipelinedSubpartitionSource(partitions[i]);
        }

        produceAndConsume(sources, gate, numChannels * buffersPerChannel);
    }

    @Test
    public void testConsumptionWithRemoteChannels() throws Exception {
        final int numChannels = 11;
        final int buffersPerChannel = 1000;

        final ConnectionManager connManager = InputChannelTestUtils.createDummyConnectionManager();
        final Source[] sources = new Source[numChannels];
        final SingleInputGate gate = createInputGate(numChannels);

        for (int i = 0; i < numChannels; i++) {
            final RemoteInputChannel channel = createRemoteChannel(gate, i, connManager);
            gate.setInputChannel(new IntermediateResultPartitionID(), channel);
            sources[i] = new RemoteChannelSource(channel);
        }

        produceAndConsume(sources, gate, numChannels * buffersPerChannel);
    }

    @Test
    public void testConsumptionWithMixedChannels() throws Exception {
        final int numChannels = 61;
        final int numLocalChannels = 20;
        final int buffersPerChannel = 1000;

        // Exactly numLocalChannels entries are 'true' (local); the shuffle randomizes which
        // channel indexes they land on.
        final ArrayList<Boolean> localOrRemote = new ArrayList<>(numChannels);
        for (int i = 0; i < numChannels; i++) {
            localOrRemote.add(i < numLocalChannels);
        }
        Collections.shuffle(localOrRemote);

        final ConnectionManager connManager = InputChannelTestUtils.createDummyConnectionManager();
        final ResultPartition resultPartition = Mockito.mock(ResultPartition.class);
        final PipelinedSubpartition[] localPartitions = new PipelinedSubpartition[numLocalChannels];

        // As in the local-only test, localPartitions is only populated after this call.
        final ResultPartitionManager resultPartitionManager =
                InputChannelTestUtils.createResultPartitionManager(localPartitions);
        final Source[] sources = new Source[numChannels];
        final SingleInputGate gate = createInputGate(numChannels);

        int local = 0;
        for (int i = 0; i < numChannels; i++) {
            if (localOrRemote.get(i)) {
                final PipelinedSubpartition psp = new PipelinedSubpartition(0, resultPartition);
                localPartitions[local++] = psp;
                sources[i] = new PipelinedSubpartitionSource(psp);

                final LocalInputChannel channel = createLocalChannel(gate, i, resultPartitionManager);
                gate.setInputChannel(new IntermediateResultPartitionID(), channel);
            } else {
                final RemoteInputChannel channel = createRemoteChannel(gate, i, connManager);
                gate.setInputChannel(new IntermediateResultPartitionID(), channel);
                sources[i] = new RemoteChannelSource(channel);
            }
        }

        produceAndConsume(sources, gate, numChannels * buffersPerChannel);
    }

    // ------------------------------------------------------------------------
    //  Test setup helpers
    // ------------------------------------------------------------------------

    /** Creates a PIPELINED input gate (subpartition index 0) with {@code numChannels} channel slots. */
    private static SingleInputGate createInputGate(int numChannels) {
        return new SingleInputGate(
                "Test Task Name",
                new JobID(),
                new IntermediateDataSetID(),
                ResultPartitionType.PIPELINED,
                0,
                numChannels,
                Mockito.mock(TaskActions.class),
                new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
    }

    /** Creates a local channel at {@code channelIndex} of {@code gate} backed by {@code partitionManager}. */
    private static LocalInputChannel createLocalChannel(
            SingleInputGate gate, int channelIndex, ResultPartitionManager partitionManager) {
        return new LocalInputChannel(
                gate,
                channelIndex,
                new ResultPartitionID(),
                partitionManager,
                Mockito.mock(TaskEventDispatcher.class),
                new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
    }

    /**
     * Creates a remote channel at {@code channelIndex} of {@code gate} using a mocked connection id.
     * The two {@code 0} arguments are presumably the initial and maximum request backoff; confirm
     * against the RemoteInputChannel constructor.
     */
    private static RemoteInputChannel createRemoteChannel(
            SingleInputGate gate, int channelIndex, ConnectionManager connManager) {
        return new RemoteInputChannel(
                gate,
                channelIndex,
                new ResultPartitionID(),
                Mockito.mock(ConnectionID.class),
                connManager,
                0,
                0,
                new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
    }

    /**
     * Starts a producer that writes {@code numBuffers} buffers into {@code sources}, and a consumer
     * that reads {@code numBuffers} times from {@code gate}. It then waits for both threads.
     * {@link CheckedThread#sync()} is expected to surface any exception thrown in a thread's
     * {@code go()}; confirm in CheckedThread.
     */
    private static void produceAndConsume(Source[] sources, SingleInputGate gate, int numBuffers)
            throws Exception {
        final ProducerThread producer = new ProducerThread(sources, numBuffers, MAX_CHUNK, YIELD_AFTER);
        final ConsumerThread consumer = new ConsumerThread(gate, numBuffers);

        producer.start();
        consumer.start();

        producer.sync();
        consumer.sync();
    }

    // ------------------------------------------------------------------------
    //  Producer / consumer threads
    // ------------------------------------------------------------------------

    /** Calls {@link SingleInputGate#getNextBufferOrEvent()} a fixed number of times, asserting non-null. */
    private static class ConsumerThread extends CheckedThread {

        private final SingleInputGate gate;
        private final int numBuffers;

        ConsumerThread(SingleInputGate gate, int numBuffers) {
            this.gate = gate;
            this.numBuffers = numBuffers;
        }

        @Override
        public void go() throws Exception {
            for (int i = numBuffers; i > 0; --i) {
                Assert.assertNotNull(gate.getNextBufferOrEvent());
            }
        }
    }

    /**
     * Adds {@code numTotal} buffers to randomly chosen sources, in chunks of 1..{@code maxChunk}
     * buffers. It calls {@link Thread#yield()} about every {@code yieldAfter} buffers to encourage
     * interleaving with the consumer.
     */
    private static class ProducerThread extends CheckedThread {

        private final Random rnd = new Random();
        private final Source[] sources;
        private final int numTotal;
        private final int maxChunk;
        private final int yieldAfter;

        ProducerThread(Source[] sources, int numTotal, int maxChunk, int yieldAfter) {
            this.sources = sources;
            this.numTotal = numTotal;
            this.maxChunk = maxChunk;
            this.yieldAfter = yieldAfter;
        }

        @Override
        public void go() throws Exception {
            // The same (mock) buffer instance is handed out for every add.
            final Buffer buffer = InputChannelTestUtils.createMockBuffer(100);

            int remaining = numTotal;
            int nextYield = numTotal - yieldAfter;

            while (remaining > 0) {
                final int nextChannel = rnd.nextInt(sources.length);
                final int chunk = Math.min(remaining, rnd.nextInt(maxChunk) + 1);

                final Source next = sources[nextChannel];
                for (int k = chunk; k > 0; --k) {
                    next.addBuffer(buffer);
                }

                remaining -= chunk;
                if (remaining <= nextYield) {
                    nextYield -= yieldAfter;
                    Thread.yield();
                }
            }
        }
    }

    // ------------------------------------------------------------------------
    //  Buffer sources
    // ------------------------------------------------------------------------

    /** Feeds buffers directly into a {@link RemoteInputChannel} with increasing sequence numbers. */
    private static class RemoteChannelSource extends Source {

        final RemoteInputChannel channel;
        private int seq = 0;

        RemoteChannelSource(RemoteInputChannel channel) {
            this.channel = channel;
        }

        @Override
        void addBuffer(Buffer buffer) throws Exception {
            channel.onBuffer(buffer, seq++);
        }
    }

    /** Feeds buffers into a {@link PipelinedSubpartition} that a local channel reads from. */
    private static class PipelinedSubpartitionSource extends Source {

        final PipelinedSubpartition partition;

        PipelinedSubpartitionSource(PipelinedSubpartition partition) {
            this.partition = partition;
        }

        @Override
        void addBuffer(Buffer buffer) throws Exception {
            partition.add(buffer);
        }
    }

    /** Something the producer can push buffers into, i.e. one input channel's upstream side. */
    private abstract static class Source {

        abstract void addBuffer(Buffer buffer) throws Exception;
    }
}

