/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.io.network.api.writer;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
import org.apache.flink.runtime.io.network.buffer.BufferProvider;
import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.util.DeserializationUtils;
import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
import org.apache.flink.testutils.serialization.types.SerializationTestType;
import org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory;
import org.apache.flink.testutils.serialization.types.Util;
import org.apache.flink.types.IntValue;
import org.apache.flink.util.XORShiftRandom;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;

/**
 * Tests for {@link RecordWriter}: clearing buffers after an interrupted blocking buffer request,
 * broadcasting events, and emitting records to every channel either via
 * {@code broadcastEmit} or via a broadcasting {@link ChannelSelector}.
 */
public class RecordWriterTest {

    /** Scratch directory handed to the spilling record deserializer. */
    @Rule
    public TemporaryFolder tempFolder = new TemporaryFolder();

    /**
     * Interrupts a writer thread while it is blocked in
     * {@link BufferProvider#requestBufferBuilderBlocking()} and checks buffer bookkeeping afterwards.
     *
     * <p>The mocked provider returns a single 4-byte buffer on the first request. Every later
     * request blocks until the calling thread is interrupted. The test then checks that:
     * <ul>
     *   <li>exactly two buffer requests were made;</li>
     *   <li>the handed-out memory segment was recycled exactly once, even after calling
     *       {@code clearBuffers()} from both the worker and the test thread.</li>
     * </ul>
     */
    @Test
    public void testClearBuffersAfterInterruptDuringBlockingBufferRequest() throws Exception {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            // Counted down once per buffer request; reaches zero when the second request starts.
            final CountDownLatch sync = new CountDownLatch(2);
            final TrackingBufferRecycler recycler = new TrackingBufferRecycler();
            final MemorySegment memorySegment = MemorySegmentFactory.allocateUnpooledSegment(4);

            Answer<BufferBuilder> request = new Answer<BufferBuilder>() {
                @Override
                public BufferBuilder answer(InvocationOnMock invocation) throws Throwable {
                    sync.countDown();
                    if (sync.getCount() == 1L) {
                        // First request: hand out the single tracked segment.
                        return new BufferBuilder(memorySegment, recycler);
                    }
                    // Any later request parks here until the calling thread is interrupted,
                    // at which point Object#wait throws InterruptedException.
                    final Object lock = new Object();
                    synchronized (lock) {
                        while (true) {
                            lock.wait();
                        }
                    }
                }
            };

            BufferProvider bufferProvider = Mockito.mock(BufferProvider.class);
            Mockito.when(bufferProvider.requestBufferBuilderBlocking()).thenAnswer(request);

            ResultPartitionWriter partitionWriter = new RecyclingPartitionWriter(bufferProvider);
            final RecordWriter<IntValue> recordWriter = new RecordWriter<IntValue>(partitionWriter);

            Future<Void> result = executor.submit(new Callable<Void>() {
                @Override
                public Void call() throws Exception {
                    IntValue val = new IntValue(0);
                    try {
                        recordWriter.emit(val);
                        recordWriter.flushAll();
                        recordWriter.emit(val);
                    } catch (InterruptedException e) {
                        recordWriter.clearBuffers();
                        // Restore the interrupt status instead of swallowing it.
                        Thread.currentThread().interrupt();
                    }
                    return null;
                }
            });

            // Wait until the worker has issued its second (blocking) buffer request, then
            // interrupt it.
            sync.await();
            result.cancel(true);

            recordWriter.clearBuffers();

            Mockito.verify(bufferProvider, Mockito.times(2)).requestBufferBuilderBlocking();

            // The tracked segment must have been recycled exactly once. Presumably this happens
            // when RecyclingPartitionWriter closes the consumer, not through clearBuffers().
            List<MemorySegment> recycled = recycler.getRecycledMemorySegments();
            Assert.assertEquals(1, recycled.size());
            Assert.assertEquals(memorySegment, recycled.get(0));
        } finally {
            // shutdownNow() also interrupts a worker that might still be parked in wait(),
            // so a failing test does not leak a blocked thread.
            executor.shutdownNow();
        }
    }

    /**
     * Emits a record, clears the buffers, and then flushes.
     *
     * <p>There is no explicit assertion: the test passes as long as the final
     * {@code flushAll()} does not throw. Per the test name, this guards that
     * {@code clearBuffers()} also resets the record serializer.
     */
    @Test
    public void testSerializerClearedAfterClearBuffers() throws Exception {
        ResultPartitionWriter partitionWriter =
                Mockito.spy(new RecyclingPartitionWriter(new TestPooledBufferProvider(1, 16)));
        RecordWriter<IntValue> recordWriter = new RecordWriter<IntValue>(partitionWriter);

        recordWriter.emit(new IntValue(0));
        recordWriter.clearBuffers();
        recordWriter.flushAll();
    }

    /**
     * Broadcasting an event before any record has been written must not request any buffer from
     * the buffer provider. Each channel must receive exactly that event.
     */
    @Test
    public void testBroadcastEventNoRecords() throws Exception {
        final int numChannels = 4;
        final int bufferSize = 32;

        Queue<BufferConsumer>[] queues = createQueues(numChannels);
        TestPooledBufferProvider bufferProvider =
                new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize);
        ResultPartitionWriter partitionWriter = new CollectingPartitionWriter(queues, bufferProvider);
        RecordWriter<ByteArrayIO> writer =
                new RecordWriter<ByteArrayIO>(partitionWriter, new RoundRobin<ByteArrayIO>());
        CheckpointBarrier barrier = new CheckpointBarrier(
                2148402839L, 2166311875L, CheckpointOptions.forCheckpointWithDefaultLocation());

        writer.broadcastEvent(barrier);

        Assert.assertEquals(0, bufferProvider.getNumberOfCreatedBuffers());
        for (int i = 0; i < numChannels; ++i) {
            Assert.assertEquals(1, queues[i].size());
            BufferOrEvent boe = parseBuffer(queues[i].remove(), i);
            Assert.assertTrue(boe.isEvent());
            Assert.assertEquals(barrier, boe.getEvent());
            Assert.assertEquals(0, queues[i].size());
        }
    }

    /**
     * Writes three records of different sizes round-robin to channels 0, 1 and 2, then broadcasts
     * a barrier.
     *
     * <p>Checks the number of buffers per channel. Also checks that the barrier ends up behind
     * the data on every channel.
     */
    @Test
    public void testBroadcastEventMixedRecords() throws Exception {
        XORShiftRandom rand = new XORShiftRandom();
        final int numChannels = 4;
        final int bufferSize = 32;
        // Assumed size of the length prefix written in front of each record.
        final int lenBytes = 4;

        Queue<BufferConsumer>[] queues = createQueues(numChannels);
        TestPooledBufferProvider bufferProvider =
                new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize);
        ResultPartitionWriter partitionWriter = new CollectingPartitionWriter(queues, bufferProvider);
        RecordWriter<ByteArrayIO> writer =
                new RecordWriter<ByteArrayIO>(partitionWriter, new RoundRobin<ByteArrayIO>());
        CheckpointBarrier barrier = new CheckpointBarrier(
                2147484939L, 2147483846L, CheckpointOptions.forCheckpointWithDefaultLocation());

        // Channel 0: a record that fits into a single buffer.
        byte[] bytes = new byte[bufferSize / 2];
        rand.nextBytes(bytes);
        writer.emit(new ByteArrayIO(bytes));

        // Channel 1: a record larger than one buffer, spanning two buffers.
        bytes = new byte[bufferSize + 1];
        rand.nextBytes(bytes);
        writer.emit(new ByteArrayIO(bytes));

        // Channel 2: a record that, together with its length prefix, fills one buffer exactly.
        bytes = new byte[bufferSize - lenBytes];
        rand.nextBytes(bytes);
        writer.emit(new ByteArrayIO(bytes));

        writer.broadcastEvent(barrier);

        Assert.assertEquals(4, bufferProvider.getNumberOfCreatedBuffers());

        Assert.assertEquals(2, queues[0].size());
        Assert.assertTrue(parseBuffer(queues[0].remove(), 0).isBuffer());

        Assert.assertEquals(3, queues[1].size());
        Assert.assertTrue(parseBuffer(queues[1].remove(), 1).isBuffer());
        Assert.assertTrue(parseBuffer(queues[1].remove(), 1).isBuffer());

        Assert.assertEquals(2, queues[2].size());
        Assert.assertTrue(parseBuffer(queues[2].remove(), 2).isBuffer());

        Assert.assertEquals(1, queues[3].size());

        // After the data buffers, every channel holds exactly the broadcast barrier.
        for (int i = 0; i < numChannels; ++i) {
            BufferOrEvent boe = parseBuffer(queues[i].remove(), i);
            Assert.assertTrue(boe.isEvent());
            Assert.assertEquals(barrier, boe.getEvent());
        }
    }

    /**
     * After broadcasting an event to two channels, both per-channel buffer consumers must report
     * {@code isRecycled()} once each has been parsed.
     */
    @Test
    public void testBroadcastEventBufferReferenceCounting() throws Exception {
        Queue<BufferConsumer>[] queues = createQueues(2);
        ResultPartitionWriter partition =
                new CollectingPartitionWriter(queues, new TestPooledBufferProvider(Integer.MAX_VALUE));
        RecordWriter<IOReadableWritable> writer = new RecordWriter<IOReadableWritable>(partition);

        writer.broadcastEvent(EndOfPartitionEvent.INSTANCE);

        Assert.assertEquals(1, queues[0].size());
        Assert.assertEquals(1, queues[1].size());

        // Keep references to the consumers so they can be inspected after parsing.
        BufferConsumer bufferConsumer1 = queues[0].peek();
        BufferConsumer bufferConsumer2 = queues[1].peek();

        for (int i = 0; i < queues.length; ++i) {
            Assert.assertTrue(parseBuffer(queues[i].remove(), i).isEvent());
        }

        Assert.assertTrue(bufferConsumer1.isRecycled());
        Assert.assertTrue(bufferConsumer2.isRecycled());
    }

    /**
     * The buffers built for a broadcast event on different channels must have independent reader
     * indices.
     */
    @Test
    public void testBroadcastEventBufferIndependence() throws Exception {
        Queue<BufferConsumer>[] queues = createQueues(2);
        ResultPartitionWriter partition =
                new CollectingPartitionWriter(queues, new TestPooledBufferProvider(Integer.MAX_VALUE));
        RecordWriter<IOReadableWritable> writer = new RecordWriter<IOReadableWritable>(partition);

        writer.broadcastEvent(EndOfPartitionEvent.INSTANCE);

        Assert.assertEquals(1, queues[0].size());
        Assert.assertEquals(1, queues[1].size());
        assertReaderIndicesIndependent(queues);
    }

    /**
     * The buffers built for a broadcast-emitted record on different channels must have
     * independent reader indices.
     */
    @Test
    public void testBroadcastEmitBufferIndependence() throws Exception {
        Queue<BufferConsumer>[] queues = createQueues(2);
        ResultPartitionWriter partition =
                new CollectingPartitionWriter(queues, new TestPooledBufferProvider(Integer.MAX_VALUE));
        RecordWriter<IntValue> writer = new RecordWriter<IntValue>(partition);

        writer.broadcastEmit(new IntValue(0));
        writer.flushAll();

        Assert.assertEquals(1, queues[0].size());
        Assert.assertEquals(1, queues[1].size());
        assertReaderIndicesIndependent(queues);
    }

    /** Emitting records through a broadcasting {@link ChannelSelector} reaches every channel. */
    @Test
    public void testEmitRecordWithBroadcastPartitioner() throws Exception {
        emitRecordWithBroadcastPartitionerOrBroadcastEmitRecord(false);
    }

    /** Emitting records via {@code broadcastEmit} reaches every channel. */
    @Test
    public void testBroadcastEmitRecord() throws Exception {
        emitRecordWithBroadcastPartitionerOrBroadcastEmitRecord(true);
    }

    /**
     * Emits random INT test records to all channels and deserializes each channel's buffers.
     *
     * <p>Checks that every channel received the expected number of buffers and all records.
     *
     * @param isBroadcastEmit {@code true} to use {@code broadcastEmit}; {@code false} to use
     *     {@code emit} with a {@link Broadcast} channel selector
     */
    private void emitRecordWithBroadcastPartitionerOrBroadcastEmitRecord(boolean isBroadcastEmit) throws Exception {
        final int numChannels = 4;
        final int bufferSize = 32;
        final int numValues = 8;
        // Assumed serialized size of one INT test record, excluding its length prefix.
        final int serializationLength = 4;
        // Assumed size of the per-record length prefix.
        final int lengthPrefixBytes = 4;

        Queue<BufferConsumer>[] queues = createQueues(numChannels);
        TestPooledBufferProvider bufferProvider =
                new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize);
        ResultPartitionWriter partitionWriter = new CollectingPartitionWriter(queues, bufferProvider);
        RecordWriter<SerializationTestType> writer = isBroadcastEmit
                ? new RecordWriter<SerializationTestType>(partitionWriter)
                : new RecordWriter<SerializationTestType>(partitionWriter, new Broadcast<SerializationTestType>());
        RecordDeserializer<SerializationTestType> deserializer =
                new SpillingAdaptiveSpanningRecordDeserializer<SerializationTestType>(
                        new String[] {tempFolder.getRoot().getAbsolutePath()});

        ArrayDeque<SerializationTestType> serializedRecords = new ArrayDeque<SerializationTestType>();
        Util.MockRecords records = Util.randomRecords(numValues, SerializationTestTypeFactory.INT);
        for (SerializationTestType record : records) {
            serializedRecords.add(record);
            if (isBroadcastEmit) {
                writer.broadcastEmit(record);
            } else {
                writer.emit(record);
            }
        }

        // With the sizes above: 32 / (4 + 4) = 4 records per buffer, 8 / 4 = 2 buffers per channel.
        final int recordsPerBuffer = bufferSize / (lengthPrefixBytes + serializationLength);
        final int requiredBuffers = numValues / recordsPerBuffer;

        for (int i = 0; i < numChannels; ++i) {
            Assert.assertEquals(requiredBuffers, queues[i].size());

            // Every channel must see every record, so each channel checks against a fresh copy.
            ArrayDeque<SerializationTestType> expectedRecords = serializedRecords.clone();
            int assertRecords = 0;
            for (int j = 0; j < requiredBuffers; ++j) {
                Buffer buffer = BufferBuilderTestUtils.buildSingleBuffer(queues[i].remove());
                deserializer.setNextBuffer(buffer);
                assertRecords += DeserializationUtils.deserializeRecords(expectedRecords, deserializer);
            }
            Assert.assertEquals(numValues, assertRecords);
        }
    }

    /**
     * Creates {@code numChannels} empty per-channel queues.
     *
     * @param numChannels number of channels
     * @return one empty {@link ArrayDeque} per channel
     */
    @SuppressWarnings("unchecked")
    private static Queue<BufferConsumer>[] createQueues(int numChannels) {
        Queue<BufferConsumer>[] queues = new Queue[numChannels];
        for (int i = 0; i < numChannels; ++i) {
            queues[i] = new ArrayDeque<BufferConsumer>();
        }
        return queues;
    }

    /**
     * Builds one buffer from the head of each of the first two queues.
     *
     * <p>Checks that moving the first buffer's reader index does not affect the second buffer's
     * reader index.
     *
     * @param queues at least two queues, each with a buffer consumer at its head
     */
    private static void assertReaderIndicesIndependent(Queue<BufferConsumer>[] queues) {
        Buffer buffer1 = BufferBuilderTestUtils.buildSingleBuffer(queues[0].remove());
        Buffer buffer2 = BufferBuilderTestUtils.buildSingleBuffer(queues[1].remove());

        Assert.assertEquals(0, buffer1.getReaderIndex());
        Assert.assertEquals(0, buffer2.getReaderIndex());

        buffer1.setReaderIndex(1);
        Assert.assertEquals("Buffer 2 shares the same reader index as buffer 1", 0, buffer2.getReaderIndex());
    }

    /**
     * Builds a single buffer from {@code bufferConsumer} and wraps it as a {@link BufferOrEvent}.
     *
     * <p>Data buffers are returned as-is. Event buffers are deserialized with this class's class
     * loader and then recycled.
     *
     * @param bufferConsumer consumer to build the buffer from
     * @param targetChannel channel index recorded in the result
     * @return the buffer or the deserialized event, tagged with {@code targetChannel}
     * @throws IOException if event deserialization fails
     */
    private static BufferOrEvent parseBuffer(BufferConsumer bufferConsumer, int targetChannel) throws IOException {
        Buffer buffer = BufferBuilderTestUtils.buildSingleBuffer(bufferConsumer);
        if (buffer.isBuffer()) {
            return new BufferOrEvent(buffer, targetChannel);
        }
        AbstractEvent event = EventSerializer.fromBuffer(buffer, RecordWriterTest.class.getClassLoader());
        buffer.recycleBuffer();
        return new BufferOrEvent(event, targetChannel);
    }

    /** {@link BufferRecycler} that records every memory segment it is asked to recycle. */
    private static class TrackingBufferRecycler implements BufferRecycler {
        private final List<MemorySegment> recycledMemorySegments = new ArrayList<MemorySegment>();

        private TrackingBufferRecycler() {
        }

        @Override
        public synchronized void recycle(MemorySegment memorySegment) {
            recycledMemorySegments.add(memorySegment);
        }

        /**
         * Returns the segments recycled so far, in order.
         *
         * @return a snapshot copy, so callers never read the internal list outside the lock
         */
        public synchronized List<MemorySegment> getRecycledMemorySegments() {
            return new ArrayList<MemorySegment>(recycledMemorySegments);
        }
    }

    /**
     * {@link ChannelSelector} that selects every channel for every record.
     *
     * <p>The returned array is cached and reused while the channel count stays the same.
     */
    private static class Broadcast<T extends IOReadableWritable> implements ChannelSelector<T> {
        private int[] returnChannel;

        private Broadcast() {
        }

        @Override
        public int[] selectChannels(T record, int numberOfOutputChannels) {
            if (returnChannel != null && returnChannel.length == numberOfOutputChannels) {
                return returnChannel;
            }
            returnChannel = new int[numberOfOutputChannels];
            for (int i = 0; i < numberOfOutputChannels; ++i) {
                returnChannel[i] = i;
            }
            return returnChannel;
        }
    }

    /**
     * {@link ChannelSelector} that cycles through channels 0, 1, 2, ... and wraps around.
     *
     * <p>The same single-element array is returned on every call.
     */
    private static class RoundRobin<T extends IOReadableWritable> implements ChannelSelector<T> {
        private final int[] nextChannel = new int[] {-1};

        private RoundRobin() {
        }

        @Override
        public int[] selectChannels(T record, int numberOfOutputChannels) {
            nextChannel[0] = (nextChannel[0] + 1) % numberOfOutputChannels;
            return nextChannel;
        }
    }

    /**
     * {@link IOReadableWritable} that writes a fixed byte array verbatim.
     *
     * <p>It reads back into the same array.
     */
    private static class ByteArrayIO implements IOReadableWritable {
        private final byte[] bytes;

        public ByteArrayIO(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        public void write(DataOutputView out) throws IOException {
            out.write(bytes);
        }

        @Override
        public void read(DataInputView in) throws IOException {
            in.readFully(bytes);
        }
    }

    /**
     * Single-subpartition {@link ResultPartitionWriter} that immediately closes every buffer
     * consumer it receives.
     */
    private static class RecyclingPartitionWriter implements ResultPartitionWriter {
        private final BufferProvider bufferProvider;
        private final ResultPartitionID partitionId = new ResultPartitionID();

        private RecyclingPartitionWriter(BufferProvider bufferProvider) {
            this.bufferProvider = bufferProvider;
        }

        @Override
        public BufferProvider getBufferProvider() {
            return bufferProvider;
        }

        @Override
        public ResultPartitionID getPartitionId() {
            return partitionId;
        }

        @Override
        public int getNumberOfSubpartitions() {
            return 1;
        }

        @Override
        public int getNumTargetKeyGroups() {
            return 1;
        }

        @Override
        public void addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel) throws IOException {
            bufferConsumer.close();
        }

        @Override
        public void flushAll() {
        }

        @Override
        public void flush(int subpartitionIndex) {
        }
    }

    /**
     * {@link ResultPartitionWriter} that appends each buffer consumer to the queue of its target
     * channel.
     *
     * <p>This lets tests inspect the consumers later.
     */
    private static class CollectingPartitionWriter implements ResultPartitionWriter {
        private final Queue<BufferConsumer>[] queues;
        private final BufferProvider bufferProvider;
        private final ResultPartitionID partitionId = new ResultPartitionID();

        private CollectingPartitionWriter(Queue<BufferConsumer>[] queues, BufferProvider bufferProvider) {
            this.queues = queues;
            this.bufferProvider = bufferProvider;
        }

        @Override
        public BufferProvider getBufferProvider() {
            return bufferProvider;
        }

        @Override
        public ResultPartitionID getPartitionId() {
            return partitionId;
        }

        @Override
        public int getNumberOfSubpartitions() {
            return queues.length;
        }

        @Override
        public int getNumTargetKeyGroups() {
            return 1;
        }

        @Override
        public void addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel) throws IOException {
            queues[targetChannel].add(bufferConsumer);
        }

        @Override
        public void flushAll() {
        }

        @Override
        public void flush(int subpartitionIndex) {
        }
    }
}

