package com.xdja.pki.ldap.dao;

import com.xdja.pki.ldap.X509Utils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.jce.provider.X509CertPairParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ldap.NameNotFoundException;
import org.springframework.ldap.core.AttributesMapper;
import org.springframework.ldap.core.ContextMapper;
import org.springframework.ldap.core.DirContextAdapter;
import org.springframework.ldap.core.LdapTemplate;
import org.springframework.ldap.core.support.DefaultDirObjectFactory;
import org.springframework.ldap.core.support.LdapContextSource;
import org.springframework.ldap.query.LdapQueryBuilder;
import org.springframework.ldap.query.SearchScope;
import sun.security.provider.certpath.X509CertificatePair;

import javax.naming.Context;
import javax.naming.NamingEnumeration;
import javax.naming.directory.*;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
import java.io.ByteArrayInputStream;
import java.security.Security;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509CRL;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;

import static org.springframework.ldap.query.LdapQueryBuilder.query;

/**
 * Publishes PKI objects (CA/user certificates, CRL/ARL/delta-CRL lists and cross-certificate
 * pairs) into an LDAP directory through Spring LDAP's {@link LdapTemplate}.
 *
 * <p>Every entry lives below {@code containerName}: DNs passed in without that suffix get it
 * appended. Before use, DNs are also rewritten so that the {@code DNQUALIFIER} attribute type
 * becomes {@code DISPLAYNAME} and {@code E} becomes {@code MAIL} (upper-case forms only), and
 * missing parent entries are created on demand with an object class chosen from the RDN type.
 *
 * <p>NOTE(review): cross-certificate pairs are handled with the JDK-internal
 * {@code sun.security.provider.certpath.X509CertificatePair}, which may be inaccessible on JDKs
 * with strong encapsulation — confirm the target runtime.
 */
public class SpringLDAPConnect {

    static {
        if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) {
            Security.addProvider(new BouncyCastleProvider());
        }
    }

    /** Text JNDI puts in the message of LDAP result code 20 (attributeOrValueExists). */
    private static final String ATTRIBUTE_OR_VALUE_EXISTS = "LDAP: error code 20";

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Stays null when the context source could not be configured in the constructor. */
    private LdapTemplate ldapTemplate;

    /** Base DN under which every entry managed by this class is stored. */
    private final String containerName;


    /**
     * Configures an {@link LdapTemplate} for {@code ldap://ldapHost:ldapPort}.
     *
     * <p>Failures are logged and not rethrown; in that case every later operation fails because
     * the template is missing. NOTE(review): {@code afterPropertiesSet()} presumably only
     * configures the context source and does not open a connection, so the "connected" log line
     * does not prove the server is reachable — confirm.
     *
     * @param ldapHost      server host name
     * @param ldapPort      server port
     * @param loginDN       bind DN
     * @param password      bind password
     * @param containerName base DN below which all entries are created and looked up
     */
    public SpringLDAPConnect(String ldapHost, int ldapPort, String loginDN, String password, String containerName) {
        this.containerName = containerName;
        try {
            LdapContextSource contextSource = new LdapContextSource();
            contextSource.setUrl("ldap://" + ldapHost + ":" + ldapPort);
            contextSource.setUserDn(loginDN);
            contextSource.setPassword(password);
            contextSource.setDirObjectFactory(DefaultDirObjectFactory.class);
            contextSource.afterPropertiesSet();
            this.ldapTemplate = new LdapTemplate(contextSource);
            logger.info("连接ldap服务器成功");
        } catch (Exception e) {
            logger.error("连接ldap服务器失败", e);
        }
    }


    /**
     * Recursively unbinds the container entry and everything below it.
     * LDAP errors (e.g. the container does not exist) are logged at debug level and ignored.
     */
    public void deleteAll() {
        try {
            ldapTemplate.unbind(containerName, true);
            logger.info("-------已清空服务器 " + containerName + "节点下所有数据---------");
        } catch (org.springframework.ldap.NamingException ne) {
            // Most commonly the container entry does not exist: nothing to delete.
            logger.debug("deleteAll: ignoring LDAP error for {}", containerName, ne);
        }
    }

    /**
     * Reads a CRL stored on the entry {@code dn}.
     *
     * @param dn   entry DN; the container suffix is appended when missing
     * @param name attribute id holding the DER-encoded CRL, compared case-insensitively
     *             (presumably e.g. {@code certificateRevocationList;binary})
     * @return the first value of that attribute parsed as a CRL, or {@code null} when the entry
     *         does not exist
     * @throws Exception when the entry exists but has no value for {@code name}, or the value
     *                   cannot be parsed
     */
    public X509CRL searchCrlEntry(String dn, String name) throws Exception {
        String fullDn = withContainer(normalizeDn(dn));
        logger.info("search dn is {}", fullDn);
        Attributes attributes;
        try {
            // OBJECT scope: only the entry itself, so entries with children do not yield
            // several results.
            attributes = ldapTemplate.searchForObject(
                    query().base(fullDn).searchScope(SearchScope.OBJECT).filter("(objectClass=*)"),
                    (ContextMapper<Attributes>) ctx -> ((DirContextAdapter) ctx).getAttributes());
        } catch (NameNotFoundException e) {
            // The entry does not exist.
            return null;
        }

        CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509", BouncyCastleProvider.PROVIDER_NAME);
        NamingEnumeration<String> ids = attributes.getIDs();
        while (ids.hasMore()) {
            String id = ids.next();
            if (!id.equalsIgnoreCase(name)) {
                continue;
            }
            NamingEnumeration<?> values = attributes.get(id).getAll();
            if (values.hasMore()) {
                return (X509CRL) certificateFactory.generateCRL(new ByteArrayInputStream((byte[]) values.next()));
            }
        }
        logger.error("没有查到该节点：" + name);
        throw new Exception("没有查到该节点：" + name);
    }

    /**
     * Publishes a CA certificate as {@code cACertificate;binary} (object classes
     * {@code pkiCA}, {@code organizationalRole}).
     */
    public void addCACertEntry(String dn, String cn, byte[] certBinary) throws Exception {
        Attributes attrs = new BasicAttributes();
        BasicAttribute objectClass = new BasicAttribute("objectclass");
        objectClass.add("pkiCA");
        objectClass.add("organizationalRole");
        attrs.put(objectClass);
        attrs.put("cn", cn);
        attrs.put("cACertificate;binary", certBinary);
        logger.debug("开始插入CA证书");
        addEntry(dn, attrs);
    }

    /** Publishes an end-entity certificate as {@code userCertificate;binary} ({@code inetOrgPerson}). */
    public void addUserCertEntry(String dn, String cn, byte[] certBinary) throws Exception {
        Attributes attrs = new BasicAttributes();
        attrs.put("objectclass", "inetOrgPerson");
        attrs.put("userCertificate;binary", certBinary);
        attrs.put("cn", cn);
        attrs.put("sn", cn);
        logger.debug("开始插入用户证书");
        addEntry(dn, attrs);
    }

    /** Publishes an authority revocation list as {@code authorityRevocationList;binary}. */
    public void addARLEntry(String dn, String cn, byte[] crlBinary) throws Exception {
        logger.debug("开始插入arl");
        addEntry(dn, revocationListAttributes(cn, "authorityRevocationList;binary", crlBinary));
    }

    /** Publishes a delta revocation list as {@code deltaRevocationList;binary}. */
    public void addDRLEntry(String dn, String cn, byte[] crlBinary) throws Exception {
        logger.debug("开始插入drl");
        addEntry(dn, revocationListAttributes(cn, "deltaRevocationList;binary", crlBinary));
    }

    /** Publishes a certificate revocation list as {@code certificateRevocationList;binary}. */
    public void addCRLEntry(String dn, String cn, byte[] crlBinary) throws Exception {
        logger.debug("开始插入crl");
        addEntry(dn, revocationListAttributes(cn, "certificateRevocationList;binary", crlBinary));
    }

    /** Builds the attribute set of a {@code cRLDistributionPoint} entry carrying one list. */
    private static Attributes revocationListAttributes(String cn, String listAttributeId, byte[] listBinary) {
        Attributes attrs = new BasicAttributes();
        attrs.put("objectclass", "cRLDistributionPoint");
        attrs.put("cn", cn);
        attrs.put(listAttributeId, listBinary);
        return attrs;
    }


    /**
     * Stores an already encoded cross-certificate pair on {@code dn}.
     *
     * @throws Exception when the entry already carries a readable cross-certificate pair
     */
    public void addCrossCertEntry(String dn, String cn, byte[] pairBinary) throws Exception {
        String fullDn = withContainer(normalizeDn(dn));
        if (selectCrossCert(fullDn) != null) {
            logger.error("该节点已经有一个交叉证书，无法继续插入");
            throw new Exception("该节点已经有一个交叉证书，无法继续插入");
        }
        logger.debug("该节点没有交叉证书");
        addCrossCert(fullDn, cn, pairBinary);
    }

    /**
     * Sets the forward half of the cross-certificate pair on {@code dn}, keeping an existing
     * reverse half.
     *
     * @throws Exception when a forward half already exists, or the certificate cannot form a
     *                   valid pair with the stored reverse half
     */
    public void addForwardCert(String dn, String cn, X509Certificate forward) throws Exception {
        addCrossCertHalf(dn, cn, forward, true);
    }

    /**
     * Sets the reverse half (the method name says "reserve") of the cross-certificate pair on
     * {@code dn}, keeping an existing forward half.
     *
     * @throws Exception when a reverse half already exists, or the certificate cannot form a
     *                   valid pair with the stored forward half
     */
    public void addReserveCert(String dn, String cn, X509Certificate reserve) throws Exception {
        addCrossCertHalf(dn, cn, reserve, false);
    }

    /**
     * Shared implementation of {@link #addForwardCert} / {@link #addReserveCert}: merges
     * {@code cert} into the pair stored on {@code dn} (or starts a new half pair) and writes it back.
     */
    private void addCrossCertHalf(String dn, String cn, X509Certificate cert, boolean isForward) throws Exception {
        if (cert == null) {
            throw new IllegalArgumentException("certificate must not be null");
        }
        String side = isForward ? "forward" : "reserve";
        // The container suffix must be present, otherwise the lookup misses the stored entry.
        String fullDn = withContainer(normalizeDn(dn));

        X509Certificate forward = isForward ? cert : null;
        X509Certificate reverse = isForward ? null : cert;

        X509CertificatePair existing = selectCrossCert(fullDn);
        if (existing == null) {
            logger.debug("说明该节点没有交叉证书,可以插入交叉证书" + side);
        } else {
            X509Certificate sameSide = isForward ? existing.getForward() : existing.getReverse();
            if (sameSide != null) {
                logger.error("该节点已经存在交叉证书的" + side + ",不能继续插入");
                throw new Exception("该节点已经存在交叉证书的" + side + ",不能继续插入");
            }
            // Keep the other half that is already stored on the server.
            if (isForward) {
                reverse = existing.getReverse();
            } else {
                forward = existing.getForward();
            }
        }

        X509CertificatePair pair;
        try {
            // Validates names/signatures when both halves are present.
            pair = new X509CertificatePair(forward, reverse);
        } catch (CertificateException e) {
            logger.error("该证书无法与服务器上已有的证书构成交叉证书对, forward [{}], reverse [{}]",
                    describe(forward), describe(reverse), e);
            throw new Exception("can't use these two certs build X509CertificatePair", e);
        }
        logger.debug("开始插入交叉证书" + side);
        addCrossCert(fullDn, cn, pair.getEncoded());
    }

    /** Short subject/issuer description of a certificate for log messages. */
    private static String describe(X509Certificate cert) {
        return cert == null
                ? "null"
                : "subject=" + cert.getSubjectX500Principal() + ", issuer=" + cert.getIssuerX500Principal();
    }

    /** Writes {@code crossCertificatePair;binary} on an {@code organizationalRole}/{@code pkiCA} entry. */
    private void addCrossCert(String dn, String cn, byte[] pairBinary) throws Exception {
        Attributes attrs = new BasicAttributes();
        BasicAttribute objectClass = new BasicAttribute("objectclass", "organizationalRole");
        objectClass.add("pkiCA");
        attrs.put(objectClass);
        attrs.put("cn", cn);
        attrs.put("crossCertificatePair;binary", pairBinary);
        logger.debug("开始插入交叉证书对");
        addEntry(dn, attrs);
    }

    /**
     * Creates the entry {@code dn} (and any missing parents) with {@code attrs}, or, when it
     * already exists, modifies it.
     *
     * <p>An "attribute or value exists" error (LDAP result code 20) during modification is
     * logged and treated as success. Note that {@code attrs} may be modified (cn/sn removed).
     *
     * @throws Exception when creation or modification fails
     */
    private void addEntry(String dn, Attributes attrs) throws Exception {
        String fullDn = normalizeDn(dn);
        if (!hasContainerSuffix(fullDn)) {
            fullDn = fullDn + "," + containerName;
            logger.debug("证书dn修改后为 " + fullDn);
        }
        if (!checkExist(fullDn)) {
            addPoint(fullDn, attrs);
            return;
        }

        List<String> oldClasses = readObjectClasses(fullDn);
        try {
            List<ModificationItem> mods = buildModifications(oldClasses, attrs);
            ldapTemplate.modifyAttributes(fullDn, mods.toArray(new ModificationItem[0]));
            logger.info("更新 " + fullDn + " 节点成功");
        } catch (Exception e) {
            String message = e.getMessage();
            if (message != null && message.contains(ATTRIBUTE_OR_VALUE_EXISTS)) {
                logger.info("节点 " + fullDn + " 数据已存在 不再插入");
            } else {
                logger.error("修改" + fullDn + "节点失败", e);
                throw new Exception("修改" + fullDn + "节点失败", e);
            }
        }
    }

    /** Returns the objectclass values currently stored on the existing entry {@code dn}. */
    private List<String> readObjectClasses(String dn) {
        DirContextAdapter entry = (DirContextAdapter) ldapTemplate.lookup(dn);
        Object[] values = entry.getObjectAttributes("objectclass");
        List<String> classes = new ArrayList<>();
        if (values != null) {
            for (Object value : values) {
                classes.add((String) value);
            }
        }
        return classes;
    }

    /**
     * Translates {@code attrs} into modifications of an existing entry whose current object
     * classes are {@code oldClasses}. Only the first value of an incoming objectclass
     * attribute is inspected.
     */
    private List<ModificationItem> buildModifications(List<String> oldClasses, Attributes attrs)
            throws javax.naming.NamingException {
        List<String> newClasses = new ArrayList<>();
        for (NamingEnumeration<? extends Attribute> it = attrs.getAll(); it.hasMore(); ) {
            Attribute attr = it.next();
            if (attr.getID().contains("objectclass")) {
                String objectClass = (String) attr.get();
                if (!newClasses.contains(objectClass)) {
                    newClasses.add(objectClass);
                }
            }
        }

        // A CA entry and a server certificate may share a DN: keep the CA's cn/sn and mark the
        // entry with pkiUser instead of inetOrgPerson (handled in the loop below).
        if (oldClasses.contains("pkiCA") && newClasses.contains("inetOrgPerson")) {
            attrs.remove("cn");
            attrs.remove("sn");
        }

        List<ModificationItem> mods = new ArrayList<>();
        for (NamingEnumeration<? extends Attribute> it = attrs.getAll(); it.hasMore(); ) {
            Attribute attr = it.next();
            String id = attr.getID();
            if (id.contains("objectclass")) {
                String objectClass = (String) attr.get();
                if (oldClasses.contains("person")) {
                    // Existing person entries only gain auxiliary classes, never lose theirs.
                    if (objectClass.equalsIgnoreCase("pkiCA")) {
                        if (!oldClasses.contains("pkiCA")) {
                            mods.add(addObjectClass("pkiCA"));
                        }
                    } else if (!oldClasses.contains("pkiUser")) {
                        mods.add(addObjectClass("pkiUser"));
                    }
                    continue;
                }
                // CA entry receiving a server/user certificate: add pkiUser instead of inetOrgPerson.
                if (oldClasses.contains("pkiCA") && objectClass.equalsIgnoreCase("inetOrgPerson")) {
                    if (!oldClasses.contains("pkiUser")) {
                        mods.add(addObjectClass("pkiUser"));
                    }
                    continue;
                }
            }
            // Certificate values are appended (ADD) so previously published ones are kept;
            // every other attribute is overwritten (REPLACE).
            int operation = id.contains("userCertificate;binary") || id.contains("cACertificate;binary")
                    ? DirContext.ADD_ATTRIBUTE
                    : DirContext.REPLACE_ATTRIBUTE;
            mods.add(new ModificationItem(operation, attr));
        }
        return mods;
    }

    /** Modification that adds one objectclass value. */
    private static ModificationItem addObjectClass(String objectClass) {
        return new ModificationItem(DirContext.ADD_ATTRIBUTE, new BasicAttribute("objectclass", objectClass));
    }

    /**
     * Creates {@code dn} with {@code attributes}, first creating every missing entry between
     * the container (inclusive) and {@code dn}. Parent object classes come from
     * {@link #parentAttributes(Rdn)}.
     *
     * @throws Exception when {@code dn} is not below the container, a parent RDN type is
     *                   unsupported, or a bind fails
     */
    private void addPoint(String dn, Attributes attributes) throws Exception {
        String fullDn = withContainer(dn);
        LdapName target = new LdapName(fullDn);
        LdapName container = new LdapName(containerName);
        // RDN-wise comparison (case-insensitive, tolerant of spacing around commas).
        if (!target.startsWith(container)) {
            logger.error("传入的dn不是以" + containerName + "结尾的");
            throw new Exception("this dn is false " + fullDn + " " + containerName);
        }

        // LdapName index 0 is the right-most RDN, so getPrefix(depth) walks from the container
        // down to the target; depth == target.size() is the target itself.
        for (int depth = Math.max(container.size(), 1); depth <= target.size(); depth++) {
            String currentDn = target.getPrefix(depth).toString();
            logger.debug("当前得到的dn节点为 " + currentDn);
            if (checkExist(currentDn)) {
                continue;
            }
            if (depth == target.size()) {
                ldapTemplate.bind(currentDn, null, attributes);
                logger.info("添加 " + fullDn + "节点成功");
            } else {
                ldapTemplate.bind(currentDn, null, parentAttributes(target.getRdn(depth - 1)));
            }
        }
    }

    /**
     * Attributes for an automatically created parent entry, chosen from its RDN type
     * (case-insensitive). Only the structural class is sent; every structural class derives
     * from {@code top} (RFC 4512), so it is not listed explicitly.
     *
     * @throws Exception for RDN types without a mapping
     */
    private Attributes parentAttributes(Rdn rdn) throws Exception {
        String type = rdn.getType();
        Object value = rdn.getValue();
        Attributes attrs = new BasicAttributes();
        if (isAnyOf(type, "ou")) {
            attrs.put("objectclass", "organizationalUnit");
        } else if (isAnyOf(type, "o")) {
            attrs.put("objectclass", "organization");
        } else if (isAnyOf(type, "L", "st", "street")) {
            attrs.put("objectclass", "locality");
        } else if (isAnyOf(type, "c")) {
            attrs.put("objectclass", "country");
        } else if (isAnyOf(type, "dc")) {
            attrs.put("objectclass", "domain");
            attrs.put("dc", value);
        } else if (isAnyOf(type, "displayName", "givenname", "initials", "uid", "mail")) {
            // displayName stands in for DNQUALIFIER and mail for E (see normalizeDn).
            attrs.put("objectclass", "inetOrgPerson");
            attrs.put("sn", value);
            attrs.put("cn", value);
        } else if (isAnyOf(type, "sn")) {
            attrs.put("objectclass", "person");
            attrs.put("sn", value);
            attrs.put("cn", value);
        } else if (isAnyOf(type, "title", "telephonenumber")) {
            attrs.put("objectclass", "organizationalPerson");
            attrs.put("sn", value);
            attrs.put("cn", value);
        } else if (isAnyOf(type, "cn")) {
            attrs.put("objectclass", "person");
            attrs.put("sn", value);
        } else {
            logger.error("该rdn类型未定义  " + type);
            throw new Exception("unknown this type " + type);
        }
        return attrs;
    }

    /** Case-insensitive membership test for RDN attribute types. */
    private static boolean isAnyOf(String type, String... candidates) {
        for (String candidate : candidates) {
            if (candidate.equalsIgnoreCase(type)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true when a lookup of {@code dn} yields an object; any lookup failure
     * (most commonly NameNotFoundException) is treated as "does not exist".
     */
    private boolean checkExist(String dn) {
        try {
            Object entry = ldapTemplate.lookup(dn);
            logger.debug("lookup {} -> {}", dn, entry);
            return entry != null;
        } catch (Exception e) {
            logger.debug("lookup {} failed: {}", dn, e.toString());
            return false;
        }
    }


    /**
     * Reads the {@code crossCertificatePair;binary} value of the {@code pkiCA} entry at
     * {@code dn} (the entry itself only, not its subtree).
     *
     * @return the parsed pair, or {@code null} when the entry is missing, is not a pkiCA, has
     *         no pair, or the pair cannot be parsed/validated
     */
    private X509CertificatePair selectCrossCert(String dn) {
        logger.debug("开始查询");
        List<byte[]> encodedPairs;
        try {
            encodedPairs = ldapTemplate.search(
                    query().base(dn).searchScope(SearchScope.OBJECT).where("objectclass").is("pkiCA"),
                    (AttributesMapper<byte[]>) attrs -> {
                        logger.debug("cross certificate entry attributes: {}", attrs);
                        Attribute pairAttr = attrs.get("crosscertificatepair;binary");
                        return pairAttr == null ? null : (byte[]) pairAttr.get();
                    }
            );
        } catch (Exception e) {
            logger.debug("此节点下查不到");
            return null;
        }

        for (byte[] encoded : encodedPairs) {
            if (encoded == null) {
                // pkiCA entry without a cross-certificate pair.
                continue;
            }
            X509CertPairParser parser = new X509CertPairParser();
            parser.engineInit(new ByteArrayInputStream(encoded));
            try {
                org.bouncycastle.x509.X509CertificatePair parsed =
                        (org.bouncycastle.x509.X509CertificatePair) parser.engineRead();
                logger.debug("查到了");
                return new X509CertificatePair(parsed.getForward(), parsed.getReverse());
            } catch (Exception e) {
                logger.warn("cannot parse crossCertificatePair of {}", dn, e);
                return null;
            }
        }
        return null;
    }

    /**
     * Rewrites upper-case attribute types the directory schema is presumably unable to use as
     * naming attributes: {@code DNQUALIFIER=} becomes {@code DISPLAYNAME=} and {@code ,E=}
     * becomes {@code ,MAIL=}. Only exact attribute-type occurrences are rewritten, so e.g.
     * {@code ,EMAILADDRESS=} is left untouched. Idempotent.
     */
    private static String normalizeDn(String dn) {
        return dn.replace("DNQUALIFIER=", "DISPLAYNAME=").replace(",E=", ",MAIL=");
    }

    /** True when {@code dn} is the container itself or ends with {@code ",<container>"} (case-insensitive). */
    private boolean hasContainerSuffix(String dn) {
        return dn.equalsIgnoreCase(containerName) || endsWithIgnoreCase(dn, "," + containerName);
    }

    /** Appends the container DN unless {@code dn} already ends with it. */
    private String withContainer(String dn) {
        return hasContainerSuffix(dn) ? dn : dn + "," + containerName;
    }

    /** Locale-independent case-insensitive {@code endsWith}. */
    private static boolean endsWithIgnoreCase(String s, String suffix) {
        int offset = s.length() - suffix.length();
        return offset >= 0 && s.regionMatches(true, offset, suffix, 0, suffix.length());
    }

}

