package com.xdja.pki.gmssl.operator.utils;

import org.bouncycastle.crypto.InvalidCipherTextException;
import org.bouncycastle.util.Arrays;

import java.security.SecureRandom;

/**
 * Static helpers that build and parse PKCS#1 v1.5-style padded blocks.
 *
 * <p>For a key of {@code keyBits} bits let {@code k = (keyBits + 7) / 8}. The blocks
 * produced and parsed here are {@code k} bytes long with the layout
 * {@code 0x00 || BT || PS || 0x00 || M}, where:
 * <ul>
 *   <li>{@code BT} is the block type: {@code 0x02} with non-zero random padding
 *       ({@link #encodeBlock}) or {@code 0x01} with {@code 0xFF} padding
 *       ({@link #encodePrivateBlock});</li>
 *   <li>{@code PS} is the padding string, always at least 8 bytes long;</li>
 *   <li>{@code M} is the message, therefore at most {@code k - 11} bytes long.</li>
 * </ul>
 *
 * @author feng zhen
 * Created 2020/6/2 14:00
 **/
public class GMSSLPKCS1Encodeing {
    /**
     * Overhead inside the engine block (which excludes the leading 0x00 byte):
     * one block-type byte + 8 minimum padding bytes + one 0x00 separator.
     */
    private static final int HEADER_LENGTH = 10;

    /** Shared generator for type-2 padding; SecureRandom is safe for concurrent use. */
    private static final SecureRandom RANDOM = new SecureRandom();

    /** Engine input size in bytes: {@code k - 1} when encrypting, {@code k} when decrypting. */
    private static int getEngineInputBlockSize(boolean forEncryption, int bitSize) {
        return forEncryption ? (bitSize + 7) / 8 - 1 : (bitSize + 7) / 8;
    }

    /** Engine output size in bytes: {@code k} when encrypting, {@code k - 1} when decrypting. */
    private static int getEngineOutputBlockSize(boolean forEncryption, int bitSize) {
        return forEncryption ? (bitSize + 7) / 8 : (bitSize + 7) / 8 - 1;
    }

    /** Usable input size: when encrypting, the engine block minus the padding header. */
    private static int getInputBlockSize(boolean forEncryption, int bitSize) {
        int baseBlockSize = getEngineInputBlockSize(forEncryption, bitSize);
        return forEncryption ? baseBlockSize - HEADER_LENGTH : baseBlockSize;
    }

    /** Usable output size: when decrypting, the engine block minus the padding header. */
    private static int getOutputBlockSize(boolean forEncryption, int bitSize) {
        int baseBlockSize = getEngineOutputBlockSize(forEncryption, bitSize);
        return forEncryption ? baseBlockSize : baseBlockSize - HEADER_LENGTH;
    }

    /**
     * Rejects message lengths that are negative or would leave fewer than 8 padding bytes.
     *
     * @throws IllegalArgumentException if {@code inLen} is out of range for {@code keyBits}
     */
    private static void checkInputLength(int inLen, int keyBits) {
        if (inLen < 0) {
            throw new IllegalArgumentException("negative input length");
        }
        if (inLen > getInputBlockSize(true, keyBits)) {
            throw new IllegalArgumentException("input data too large");
        }
    }

    /**
     * Encodes {@code in[inOff, inOff + inLen)} as a block-type-2 block
     * {@code 0x00 || 0x02 || PS || 0x00 || M}, where {@code PS} consists of random
     * non-zero bytes.
     *
     * @param in      buffer holding the message
     * @param inOff   offset of the message in {@code in}
     * @param inLen   message length; at most {@code (keyBits + 7) / 8 - 11}
     * @param keyBits key size in bits; the result is {@code (keyBits + 7) / 8} bytes long
     * @return a freshly allocated encoded block
     * @throws IllegalArgumentException  if {@code inLen} is negative or too large
     * @throws IndexOutOfBoundsException if {@code inOff}/{@code inLen} exceed {@code in}
     */
    public static byte[] encodeBlock(byte[] in, int inOff, int inLen, int keyBits) {
        checkInputLength(inLen, keyBits);
        byte[] block = new byte[getEngineInputBlockSize(true, keyBits) + 1];
        int separator = block.length - inLen - 1;

        // Start from random bytes, then overwrite the fixed header fields.
        RANDOM.nextBytes(block);
        block[0] = 0x00;
        block[1] = 0x02;
        // The first zero after the type byte marks the end of the padding, so every
        // pad byte in [2, separator) - including the last one - must be non-zero.
        for (int i = 2; i < separator; i++) {
            while (block[i] == 0) {
                block[i] = (byte) RANDOM.nextInt();
            }
        }
        block[separator] = 0x00;
        System.arraycopy(in, inOff, block, separator + 1, inLen);
        return block;
    }

    /**
     * Decodes a block produced by {@link #encodeBlock} or {@link #encodePrivateBlock}
     * and returns the embedded message.
     *
     * <p>The block must start with {@code 0x00}, carry block type 1 or 2, and contain
     * at least 8 padding bytes before the {@code 0x00} separator; type-1 padding must
     * be all {@code 0xFF}. The checks on the block contents are evaluated together,
     * and every such failure is reported with the same message.
     *
     * @param block   the encoded block; only its first {@code (keyBits + 7) / 8} bytes are read
     * @param keyBits key size in bits
     * @return a freshly allocated copy of the message
     * @throws InvalidCipherTextException if the key size is too small, the block is too
     *                                    short, or the padding is malformed
     */
    public static byte[] decodeBlock(byte[] block, int keyBits) throws InvalidCipherTextException {
        int blockLen = (keyBits + 7) / 8;
        if (blockLen <= HEADER_LENGTH) {
            // Not even an empty message fits: 0x00 || BT || 8 pad bytes || 0x00.
            throw new InvalidCipherTextException("key size too small");
        }
        if (block.length < blockLen) {
            throw new InvalidCipherTextException("block too short");
        }

        // Drop the leading byte: data = BT || PS || 0x00 || M.
        byte[] data = new byte[blockLen - 1];
        System.arraycopy(block, 1, data, 0, data.length);
        byte type = data[0];

        // Message begins right after the separator (findStart yields -1 on failure -> 0).
        int start = findStart(type, data) + 1;

        // Evaluate every check with non-short-circuit operators and fail through a
        // single path, so malformed blocks are not distinguishable by error message.
        boolean malformed = (block[0] != 0)
                | (type != 1 & type != 2)
                | (start < HEADER_LENGTH);
        if (malformed) {
            Arrays.fill(data, (byte) 0);
            throw new InvalidCipherTextException("block incorrect");
        }

        byte[] result = new byte[data.length - start];
        System.arraycopy(data, start, result, 0, result.length);
        Arrays.fill(data, (byte) 0);
        return result;
    }

    /**
     * Encodes {@code in[inOff, inOff + inLen)} as a block-type-1 block
     * {@code 0x00 || 0x01 || 0xFF..0xFF || 0x00 || M}.
     *
     * @param in      buffer holding the message
     * @param inOff   offset of the message in {@code in}
     * @param inLen   message length; at most {@code (keyBits + 7) / 8 - 11}
     * @param keyBits key size in bits; the result is {@code (keyBits + 7) / 8} bytes long
     * @return a freshly allocated encoded block
     * @throws IllegalArgumentException  if {@code inLen} is negative or too large
     * @throws IndexOutOfBoundsException if {@code inOff}/{@code inLen} exceed {@code in}
     */
    public static byte[] encodePrivateBlock(byte[] in, int inOff, int inLen, int keyBits) {
        checkInputLength(inLen, keyBits);
        byte[] block = new byte[getEngineInputBlockSize(true, keyBits) + 1];
        int separator = block.length - inLen - 1;

        // Every byte is written below, so no random pre-fill is needed.
        block[0] = 0x00;
        block[1] = 0x01; // block type 1
        // Type-1 padding is a run of 0xFF bytes up to the 0x00 separator.
        for (int i = 2; i < separator; i++) {
            block[i] = (byte) 0xFF;
        }
        block[separator] = 0x00;
        System.arraycopy(in, inOff, block, separator + 1, inLen);
        return block;
    }

    /**
     * Locates the padding separator in {@code block} (which starts with the type byte).
     *
     * <p>The scan always visits every byte, using non-short-circuit operators, rather
     * than stopping at the first zero.
     *
     * @param type  block type byte; for type 1 every byte before the separator must be 0xFF
     * @param block type byte followed by padding, separator and message
     * @return index of the first 0x00 at or after index 1, or -1 if there is none or
     *         type-1 padding contains a byte other than 0xFF
     */
    private static int findStart(byte type, byte[] block) {
        int start = -1;
        boolean padErr = false;

        for (int i = 1; i != block.length; i++) {
            byte pad = block[i];

            if (pad == 0 & start < 0) {
                start = i;
            }
            padErr |= (type == 1 & start < 0 & pad != (byte) 0xff);
        }

        if (padErr) {
            return -1;
        }

        return start;
    }
}
