/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.partition;

import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.InputChannelTestUtils;
import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
import org.apache.flink.runtime.taskmanager.TaskActions;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Tests that a {@link SingleInputGate} consumes fairly from its input channels: while reading,
 * the number of queued buffers of any two channels must never differ by more than one.
 *
 * <p>Every gate used here is a {@link FairnessVerifyingInputGate}. On each read, it also checks
 * that the gate's internal queue of channels-with-data holds no duplicate channel. It also checks
 * that the queue holds no more entries than the gate has input channels.
 */
public class InputGateFairnessTest {

    /**
     * Local channels whose source subpartitions are completely filled and finished before reading.
     *
     * <p>The test reads {@code numChannels * (buffersPerChannel + 1)} elements. The extra element
     * per channel accounts for whatever {@code finish()} enqueues, presumably the end-of-partition
     * event. Fairness is checked after every read. Afterwards the gate must return {@code null}.
     */
    @Test
    public void testFairConsumptionLocalChannelsPreFilled() throws Exception {
        final int numChannels = 37;
        final int buffersPerChannel = 27;

        ResultPartition resultPartition = Mockito.mock(ResultPartition.class);
        Buffer mockBuffer = InputChannelTestUtils.createMockBuffer(42);

        // ----- create the source subpartitions and fill them up front -----
        PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels];
        for (int i = 0; i < numChannels; i++) {
            PipelinedSubpartition partition = new PipelinedSubpartition(0, resultPartition);
            for (int p = 0; p < buffersPerChannel; p++) {
                partition.add(mockBuffer);
            }
            partition.finish();
            sources[i] = partition;
        }

        // ----- create the reading side -----
        ResultPartitionManager resultPartitionManager =
            InputChannelTestUtils.createResultPartitionManager(sources);
        FairnessVerifyingInputGate gate = createFairnessVerifyingGate(numChannels);
        setupLocalChannels(gate, numChannels, resultPartitionManager);

        // ----- read all buffers plus one trailing element per channel -----
        for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) {
            Assert.assertNotNull(gate.getNextBufferOrEvent());
            assertBalanced(sources);
        }

        Assert.assertNull(gate.getNextBufferOrEvent());
    }

    /**
     * Local channels that receive data while the gate is being consumed.
     *
     * <p>A single buffer is seeded into one source. Then {@code numChannels * buffersPerChannel}
     * elements are read. Every {@code 2 * numChannels} reads, three more buffers per source are
     * added in random order. Fairness is checked after every read.
     */
    @Test
    public void testFairConsumptionLocalChannels() throws Exception {
        final int numChannels = 37;
        final int buffersPerChannel = 27;

        ResultPartition resultPartition = Mockito.mock(ResultPartition.class);
        Buffer mockBuffer = InputChannelTestUtils.createMockBuffer(42);

        // ----- create empty source subpartitions -----
        PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels];
        for (int i = 0; i < numChannels; i++) {
            sources[i] = new PipelinedSubpartition(0, resultPartition);
        }

        // ----- create the reading side -----
        ResultPartitionManager resultPartitionManager =
            InputChannelTestUtils.createResultPartitionManager(sources);
        FairnessVerifyingInputGate gate = createFairnessVerifyingGate(numChannels);
        setupLocalChannels(gate, numChannels, resultPartitionManager);

        // seed one buffer so that the very first read has something to return
        sources[12].add(mockBuffer);

        for (int i = 0; i < numChannels * buffersPerChannel; i++) {
            Assert.assertNotNull(gate.getNextBufferOrEvent());
            assertBalanced(sources);

            if (i % (2 * numChannels) == 0) {
                // add three buffers to each source, interleaved in random order
                fillRandom(sources, 3, mockBuffer);
            }
        }
    }

    /**
     * Remote channels that are completely filled before reading.
     *
     * <p>Each channel gets {@code buffersPerChannel} buffers followed by an
     * {@link EndOfPartitionEvent}. All {@code numChannels * (buffersPerChannel + 1)} elements are
     * read with a fairness check after each read. Afterwards the gate must return {@code null}.
     */
    @Test
    public void testFairConsumptionRemoteChannelsPreFilled() throws Exception {
        final int numChannels = 37;
        final int buffersPerChannel = 27;

        Buffer mockBuffer = InputChannelTestUtils.createMockBuffer(42);
        FairnessVerifyingInputGate gate = createFairnessVerifyingGate(numChannels);
        ConnectionManager connManager = InputChannelTestUtils.createDummyConnectionManager();

        RemoteInputChannel[] channels = new RemoteInputChannel[numChannels];
        for (int i = 0; i < numChannels; i++) {
            RemoteInputChannel channel = createRemoteChannel(gate, i, connManager);
            channels[i] = channel;

            for (int p = 0; p < buffersPerChannel; p++) {
                channel.onBuffer(mockBuffer, p);
            }
            // the end-of-partition event takes the sequence number following the data buffers
            channel.onBuffer(
                EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel);

            gate.setInputChannel(new IntermediateResultPartitionID(), channel);
        }

        // ----- read all buffers plus the end-of-partition event of every channel -----
        for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) {
            Assert.assertNotNull(gate.getNextBufferOrEvent());
            assertBalanced(channels);
        }

        Assert.assertNull(gate.getNextBufferOrEvent());
    }

    /**
     * Remote channels that receive data while the gate is being consumed.
     *
     * <p>This mirrors {@link #testFairConsumptionLocalChannels()}. The per-channel sequence numbers
     * are tracked in {@code channelSequenceNums} so that each channel receives its buffers in
     * order.
     */
    @Test
    public void testFairConsumptionRemoteChannels() throws Exception {
        final int numChannels = 37;
        final int buffersPerChannel = 27;

        Buffer mockBuffer = InputChannelTestUtils.createMockBuffer(42);
        FairnessVerifyingInputGate gate = createFairnessVerifyingGate(numChannels);
        ConnectionManager connManager = InputChannelTestUtils.createDummyConnectionManager();

        RemoteInputChannel[] channels = new RemoteInputChannel[numChannels];
        int[] channelSequenceNums = new int[numChannels];

        for (int i = 0; i < numChannels; i++) {
            RemoteInputChannel channel = createRemoteChannel(gate, i, connManager);
            channels[i] = channel;
            gate.setInputChannel(new IntermediateResultPartitionID(), channel);
        }

        // seed one buffer so that the very first read has something to return
        channels[11].onBuffer(mockBuffer, 0);
        channelSequenceNums[11]++;

        for (int i = 0; i < numChannels * buffersPerChannel; i++) {
            Assert.assertNotNull(gate.getNextBufferOrEvent());
            assertBalanced(channels);

            if (i % (2 * numChannels) == 0) {
                // add three buffers to each channel, interleaved in random order
                fillRandom(channels, channelSequenceNums, 3, mockBuffer);
            }
        }
    }

    // ------------------------------------------------------------------------
    //  Utilities
    // ------------------------------------------------------------------------

    /** Creates a pipelined fairness-verifying gate for subpartition 0 with the given channel count. */
    private static FairnessVerifyingInputGate createFairnessVerifyingGate(int numChannels) {
        return new FairnessVerifyingInputGate(
            "Test Task Name",
            new JobID(),
            new IntermediateDataSetID(),
            0,
            numChannels,
            Mockito.mock(TaskActions.class),
            new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
    }

    /** Creates {@code numChannels} local channels backed by {@code manager} and registers them with the gate. */
    private static void setupLocalChannels(
            SingleInputGate gate, int numChannels, ResultPartitionManager manager) {
        for (int i = 0; i < numChannels; i++) {
            LocalInputChannel channel = new LocalInputChannel(
                gate,
                i,
                new ResultPartitionID(),
                manager,
                Mockito.mock(TaskEventDispatcher.class),
                new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
            gate.setInputChannel(new IntermediateResultPartitionID(), channel);
        }
    }

    /** Creates a remote channel with index {@code index}, a mocked connection ID and zero back-off values. */
    private static RemoteInputChannel createRemoteChannel(
            SingleInputGate gate, int index, ConnectionManager connManager) {
        return new RemoteInputChannel(
            gate,
            index,
            new ResultPartitionID(),
            Mockito.mock(ConnectionID.class),
            connManager,
            0,
            0,
            new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
    }

    /** Asserts that the buffer counts of all sources differ by at most one. */
    private static void assertBalanced(PipelinedSubpartition[] sources) {
        int min = Integer.MAX_VALUE;
        int max = 0;
        for (PipelinedSubpartition source : sources) {
            int size = source.getCurrentNumberOfBuffers();
            min = Math.min(min, size);
            max = Math.max(max, size);
        }
        assertSpreadAtMostOne(min, max);
    }

    /** Asserts that the queued-buffer counts of all remote channels differ by at most one. */
    private static void assertBalanced(RemoteInputChannel[] channels) {
        int min = Integer.MAX_VALUE;
        int max = 0;
        for (RemoteInputChannel channel : channels) {
            int size = channel.getNumberOfQueuedBuffers();
            min = Math.min(min, size);
            max = Math.max(max, size);
        }
        assertSpreadAtMostOne(min, max);
    }

    /** Fails unless {@code max} equals {@code min} or {@code min + 1}. */
    private static void assertSpreadAtMostOne(int min, int max) {
        Assert.assertTrue(
            "unfair consumption: min queue size " + min + ", max queue size " + max,
            max == min || max == min + 1);
    }

    /**
     * Adds {@code numPerPartition} copies of {@code buffer} to every partition. The additions are
     * interleaved across the partitions in random order.
     */
    private static void fillRandom(
            PipelinedSubpartition[] partitions, int numPerPartition, Buffer buffer)
            throws Exception {
        for (int index : shuffledIndices(partitions.length, numPerPartition)) {
            partitions[index].add(buffer);
        }
    }

    /**
     * Adds {@code numPerPartition} copies of {@code buffer} to every channel, interleaved in random
     * order. Each addition uses and then increments that channel's entry in
     * {@code sequenceNumbers}.
     */
    private static void fillRandom(
            RemoteInputChannel[] partitions,
            int[] sequenceNumbers,
            int numPerPartition,
            Buffer buffer) throws Exception {
        for (int index : shuffledIndices(partitions.length, numPerPartition)) {
            partitions[index].onBuffer(buffer, sequenceNumbers[index]++);
        }
    }

    /** Returns every index in {@code [0, numIndices)} repeated {@code numRepetitions} times, shuffled. */
    private static List<Integer> shuffledIndices(int numIndices, int numRepetitions) {
        List<Integer> indices = new ArrayList<>(numIndices * numRepetitions);
        for (int i = 0; i < numIndices; i++) {
            for (int k = 0; k < numRepetitions; k++) {
                indices.add(i);
            }
        }
        Collections.shuffle(indices);
        return indices;
    }

    // ------------------------------------------------------------------------

    /**
     * A {@link SingleInputGate} that sanity-checks its queue of channels-with-data before each
     * read.
     *
     * <p>The queue is obtained reflectively from the private {@code inputChannelsWithData} field
     * of {@link SingleInputGate}. Each check asserts two things: the queue holds no more entries
     * than the gate has input channels, and no channel appears twice.
     */
    private static class FairnessVerifyingInputGate extends SingleInputGate {

        /** The parent's internal queue of channels that have data available. */
        private final ArrayDeque<InputChannel> channelsWithData;

        /** Scratch set reused by {@link #ensureUnique(Collection)}; empty between calls. */
        private final HashSet<InputChannel> uniquenessChecker;

        @SuppressWarnings("unchecked") // the reflected field is declared in SingleInputGate
        public FairnessVerifyingInputGate(
                String owningTaskName,
                JobID jobId,
                IntermediateDataSetID consumedResultId,
                int consumedSubpartitionIndex,
                int numberOfInputChannels,
                TaskActions taskActions,
                TaskIOMetricGroup metrics) {

            super(owningTaskName, jobId, consumedResultId, ResultPartitionType.PIPELINED,
                consumedSubpartitionIndex, numberOfInputChannels, taskActions, metrics);

            try {
                Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData");
                f.setAccessible(true);
                this.channelsWithData = (ArrayDeque<InputChannel>) f.get(this);
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException(
                    "Could not access SingleInputGate#inputChannelsWithData", e);
            }

            this.uniquenessChecker = new HashSet<>();
        }

        /** Verifies the channels-with-data queue, then delegates to the regular read. */
        @Override
        public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException {
            // Hold the queue's monitor while inspecting it, so the check sees a consistent
            // snapshot. NOTE(review): presumably this matches how SingleInputGate guards the
            // queue; confirm.
            synchronized (channelsWithData) {
                Assert.assertTrue("too many input channels",
                    channelsWithData.size() <= getNumberOfInputChannels());
                ensureUnique(channelsWithData);
            }

            return super.getNextBufferOrEvent();
        }

        /** Fails if {@code channels} contains any channel more than once. */
        private void ensureUnique(Collection<InputChannel> channels) {
            HashSet<InputChannel> seen = this.uniquenessChecker;

            for (InputChannel channel : channels) {
                if (!seen.add(channel)) {
                    Assert.fail("Duplicate channel in input gate: " + channel);
                }
            }

            Assert.assertTrue("found duplicate input channels", seen.size() == channels.size());
            seen.clear();
        }
    }
}

