package com.xdja.prs.authentication.gateway;

import com.google.common.io.ByteArrayDataInput;
import com.google.common.io.ByteArrayDataOutput;
import com.google.common.io.ByteStreams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.SocketException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Project Name: prs-authentication<br/>
 * ClassName: com.xdja.prs.authentication.gateway.UdpServer<br/>
 * Description: UDP endpoint that receives gateway online/offline notifications, decodes them
 * and forwards them to a {@link UdpCallback}; it can also ask gateways to resend their current
 * online list.<br/>
 *
 * @author: 黄地
 * @date: 2015/09/06 16:29
 * note: one {@link DatagramSocket} is used both for receiving notifications and for sending
 * "get current online list" requests.
 */
public class UdpServer {
    private final Logger logger = LoggerFactory.getLogger(getClass());
    private final DatagramSocket datagramSocket;
    private final UdpCallback callback;
    /** Runs the blocking receive loop started by {@link #startOnlineInfoListenServer()}. */
    private final ExecutorService executorService = Executors.newSingleThreadExecutor();

    /**
     * Binds the UDP socket on which gateway notifications are received.
     *
     * @param gatewayReceivePort local port to bind
     * @param callback           receiver of the decoded notifications
     * @throws SocketException if the socket could not be opened or bound
     */
    public UdpServer(int gatewayReceivePort, UdpCallback callback) throws SocketException {
        this.datagramSocket = new DatagramSocket(gatewayReceivePort);
        this.callback = callback;
    }

    /**
     * Starts the gateway online/offline notification listener on the background executor.
     * The listener runs until {@link #shutdown()} closes the socket or an I/O error occurs.
     *
     * @author: 黄地
     * @date: 2015-08-24 15:02:28
     */
    public void startOnlineInfoListenServer() {
        logger.debug("启动上下线监听服务,端口:{}", datagramSocket.getLocalPort());
        // execute (not submit): an unexpected Throwable escaping run() reaches the thread's
        // uncaught-exception handler instead of being silently stored in an unread Future.
        executorService.execute(new Runnable() {
            @Override
            public void run() {
                receiveLoop();
            }
        });
    }

    /**
     * Blocking receive loop: reads datagrams and hands each payload to
     * {@link #processReceiveUdpPack}. A failure while processing one packet is logged and the
     * loop continues; an I/O error on the socket ends the loop.
     */
    private void receiveLoop() {
        byte[] recvBuf = new byte[1024];
        DatagramPacket recvPacket = new DatagramPacket(recvBuf, recvBuf.length);
        try {
            // A blocking DatagramSocket.receive() is not unblocked by interruption; shutdown()
            // closes the socket, which makes receive() throw and ends this loop.
            while (!Thread.currentThread().isInterrupted()) {
                // receive() sets the packet length to the size of the datagram just read;
                // restore the full buffer size so a later, larger datagram is not truncated.
                recvPacket.setLength(recvBuf.length);
                datagramSocket.receive(recvPacket);
                logger.debug("收到网关{}的通知信息", recvPacket.getAddress().getHostAddress());
                try {
                    byte[] pack = Arrays.copyOf(recvPacket.getData(), recvPacket.getLength());
                    processReceiveUdpPack(recvPacket.getAddress(), recvPacket.getPort(), pack);
                } catch (Exception e) {
                    // One malformed packet must not stop the listener.
                    logger.warn("处理收到的数据出错！" + e.getMessage(), e);
                }
            }
        } catch (IOException e) {
            if (datagramSocket.isClosed()) {
                // Expected after shutdown(): not an error.
                logger.debug("网关数据接收服务停止");
            } else {
                logger.error("网关数据接收服务停止:" + e.getMessage(), e);
            }
        }
    }

    /**
     * Decodes one notification packet and dispatches it to the callback.
     * Protocol reference: 《udp上下线通知接口.doc》 (UDP online/offline notification interface).
     * <p>
     * Layout as parsed here: pkt_begin(0xc5) | type(1) | version(0x02) | pkt_len(2) | body | pkt_end(0x5c).
     * The whole packet, including pkt_end, is validated before any callback is invoked.
     *
     * @param address sender address (the gateway)
     * @param port    sender port (currently unused)
     * @param pack    received payload
     * @throws IllegalArgumentException if a header/trailer byte does not match or the SN is empty
     */
    private void processReceiveUdpPack(InetAddress address, int port, byte[] pack) {
        ByteArrayDataInput input = ByteStreams.newDataInput(pack);
        // NOTE(review): the original author intended to check that the sender is a known gateway IP;
        // that check is not implemented.
        Assert.isTrue(input.readUnsignedByte() == 0xc5, "pkt_begin != 0xc5");
        int type = input.readUnsignedByte();
        Assert.isTrue(input.readUnsignedByte() == 0x02, "version != 0x02");
        // pkt_len: read only to advance past it; not validated against pack.length.
        input.readUnsignedShort();

        if (type == 0x91 || type == 0x90) {
            // 0x91 = online, 0x90 = offline
            processOnlineOffline(address, type, input);
        } else if (type == 0x8e) {
            logger.debug("通知类型:网关启停");
            // Gateway started/stopped: clear all online entries for this gateway.
            checkPacketEnd(input);
            callback.allOffline(address.getHostAddress());
        } else {
            // Body layout unknown, so pkt_end cannot be located; just ignore the packet.
            logger.debug("未知的通知类型:{}", Integer.toHexString(type));
        }
    }

    /**
     * Parses the body of an online (0x91) / offline (0x90) notification, validates the trailer,
     * then notifies the callback.
     * Body: mark_type(1) | mark_len(1) | mark(mark_len) | sn_len(1) | sn(sn_len).
     *
     * @param address sender address (the gateway)
     * @param type    0x91 for online, 0x90 for offline
     * @param input   stream positioned right after the header
     */
    private void processOnlineOffline(InetAddress address, int type, ByteArrayDataInput input) {
        boolean online = type == 0x91;
        logger.debug(online ? "通知类型:上线" : "通知类型:下线");

        int markType = input.readUnsignedByte();
        int markLen = input.readUnsignedByte();
        String ip = null;
        Integer forwardPort = null;
        if (markType == 0x00) {
            // No connection identifier.
            logger.debug("连接标识类型:无");
        } else if (markType == 0x01) {
            // Virtual IP.
            logger.debug("连接标识类型:虚拟ip");
            ip = getVirtualIp(input, markLen);
            logger.debug("获取到虚拟IP:{}", ip);
        } else if (markType == 0x02) {
            // Local (forward) port.
            logger.debug("连接标识类型:本地端口");
            forwardPort = getForwardPort(input, markLen);
            logger.debug("获取到本地端口:{}", forwardPort);
        } else {
            // Unknown identifier type: skip its payload so sn_len/sn/pkt_end stay aligned.
            logger.debug("未知的连接标识类型:{}", Integer.toHexString(markType));
            input.skipBytes(markLen);
        }

        int snLen = input.readUnsignedByte();
        logger.debug("获取sn长度:{}", snLen);
        Assert.isTrue(snLen > 0, "sn_len == 0");
        String sn = getSnFromPack(input, snLen);
        logger.debug("获取到sn:{}", sn);

        // Validate the trailer before notifying, so a malformed packet triggers no callback.
        checkPacketEnd(input);
        if (online) {
            callback.online(address.getHostAddress(), ip, forwardPort, sn);
        } else {
            callback.offline(address.getHostAddress(), ip, forwardPort, sn);
        }
    }

    /**
     * Reads the next byte and requires it to be the packet trailer 0x5c.
     *
     * @throws IllegalArgumentException if the byte is not 0x5c
     */
    private void checkPacketEnd(ByteArrayDataInput input) {
        Assert.isTrue(input.readUnsignedByte() == 0x5c, "pkt_end != 0x5c");
    }

    /**
     * Reads the SN: each byte becomes two lowercase hex digits; then a single leading '0', if
     * present, is dropped.
     * NOTE(review): presumably the SN has an odd number of hex digits and the first byte is
     * zero-padded — confirm against the protocol document.
     *
     * @param input stream positioned at the SN
     * @param snLen number of SN bytes to read
     * @return the SN as a hex string ({@code ""} when {@code snLen} is 0)
     */
    private String getSnFromPack(ByteArrayDataInput input, int snLen) {
        StringBuilder sb = new StringBuilder(snLen * 2);
        for (int i = 0; i < snLen; i++) {
            String s = Integer.toHexString(input.readUnsignedByte());
            if (s.length() == 1) {
                sb.append('0');
            }
            sb.append(s);
        }
        // Length guard: charAt(0) on an empty builder would throw.
        if (sb.length() > 0 && sb.charAt(0) == '0') {
            sb.deleteCharAt(0);
        }
        return sb.toString();
    }

    /**
     * Reads the forward port: a 16-bit unsigned big-endian integer.
     *
     * @param input   stream positioned at the port
     * @param markLen declared identifier length (not used; exactly 2 bytes are read)
     * @return the port, 0..65535
     */
    private Integer getForwardPort(ByteArrayDataInput input, int markLen) {
        return input.readUnsignedShort();
    }

    /**
     * Reads the virtual IP: four unsigned bytes rendered as dotted decimal.
     *
     * @param input   stream positioned at the address
     * @param markLen declared identifier length (not used; exactly 4 bytes are read)
     * @return the address, e.g. {@code "10.0.0.1"}
     */
    private String getVirtualIp(ByteArrayDataInput input, int markLen) {
        return input.readUnsignedByte() + "." + input.readUnsignedByte() + "." + input.readUnsignedByte() + "." + input.readUnsignedByte();
    }


    /**
     * Sends a "get current online list" request (type 0x8f) to every configured gateway.
     * Best effort: a gateway whose port is missing or whose send fails is logged and skipped.
     *
     * @param gatewayIpList   gateway host names or IP addresses
     * @param gatewayPortList gateway ports; a single entry applies to all gateways, otherwise
     *                        entry {@code i} is the port of gateway {@code i}
     * @author: 黄地
     * @date: 2015-08-24 15:01:42
     */
    public void startGetCurrentOnlineInfo(List<String> gatewayIpList, List<Integer> gatewayPortList) {
        ByteArrayDataOutput out = ByteStreams.newDataOutput();
        out.writeByte(0xc5);  // pkt_begin
        out.writeByte(0x8f);  // type
        out.writeByte(0x02);  // version
        out.writeShort(0x06); // pkt_len: equals the total size of this 6-byte packet
        out.writeByte(0x5c);  // pkt_end
        byte[] bytes = out.toByteArray();
        for (int i = 0; i < gatewayIpList.size(); i++) {
            String gateway = gatewayIpList.get(i);
            Integer port = getGatewayPort(gatewayPortList, i);
            if (port == null) {
                logger.warn("网关{}未配置端口，跳过发送获取当前在线列表的请求", gateway);
                continue;
            }
            try {
                DatagramPacket pack = new DatagramPacket(bytes, bytes.length, InetAddress.getByName(gateway), port);
                logger.debug("发送获取当前在线列表的请求给{}:{}", gateway, port);
                datagramSocket.send(pack);
            } catch (IOException e) {
                logger.warn("给{}:{}发送获取当前在线通知信息失败！" + e.getMessage(), gateway, port, e);
            }
        }
    }

    /**
     * Resolves the port for the gateway at {@code index}.
     *
     * @return the single shared port when the list has exactly one entry, otherwise the entry at
     * {@code index}; {@code null} when no entry exists for that index
     */
    private Integer getGatewayPort(List<Integer> gatewayPortList, int index) {
        if (gatewayPortList.size() == 1) {
            return gatewayPortList.get(0);
        }
        return index < gatewayPortList.size() ? gatewayPortList.get(index) : null;
    }

    /**
     * Stops the listener and releases the UDP socket.
     *
     * @throws Exception declared for backward compatibility with existing callers
     */
    public void shutdown() throws Exception {
        executorService.shutdownNow();
        // Interruption alone does not unblock DatagramSocket.receive(); closing the socket does,
        // and also frees the bound port.
        datagramSocket.close();
    }
}
