package com.xdja.base.common.dao;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.orm.hibernate3.LocalSessionFactoryBean;
import org.springframework.stereotype.Repository;
import org.springframework.util.Assert;

import com.xdja.base.system.ServerRuntimeException;
import com.xdja.base.util.page.Pagination;

/**
 * 
 * @ClassName：BaseDao
 * @Description： 基类dao
 * @author: mayanpei
 * @date: 2014-5-9 9:20:18
 * 
 */
@Repository
public class BaseDao implements InitializingBean {

	protected Logger logger = LoggerFactory.getLogger(getClass());

	@Autowired
	private LocalSessionFactoryBean sessionFactory;

	@Autowired
	protected JdbcTemplate jdbcTemplate;

	@Autowired
	protected NamedParameterJdbcTemplate namedJdbcTemplate;

	private DatabaseName databaseName;

	// 数据库类型
	private enum DatabaseName {
		MYSQL, ORACLE
	}

	public void afterPropertiesSet() {
		Connection conn = null;
		try {
			// 获取数据库类型
			conn = sessionFactory.getDataSource().getConnection();
			databaseName = DatabaseName.valueOf(conn.getMetaData().getDatabaseProductName().toUpperCase());
			logger.debug("当前系统使用的数据库为：{}", databaseName);
		} catch (Throwable e) {
			logger.error("判断系统使用的数据库类型出错", e);
		} finally {
			if (conn != null) {
				try {
					conn.close();
				} catch (SQLException e) {
					// 什么也不做
				}
			}
		}
		Assert.notNull(databaseName, "未知的数据库类型");
	}

	/**
	 * 
	 * Description：update data use namedJdbcTemplate
	 * 
	 * @author:mayanpei
	 * @date: 2013-10-23
	 * @param sql
	 * @param sqlParam
	 * @return
	 */
	public int update(String sql, SqlParameterSource sqlParam) {
		return namedJdbcTemplate.update(sql, sqlParam);
	}

	/**
	 * 
	 * @Title: queryForPage
	 * @Description: 分页查询
	 * @param sql
	 * @param pageSize
	 *            每页条数
	 * @param pageNo
	 *            当前页码
	 * @param paramSource
	 *            查询条件参数
	 * @return Pagination 分页信息，其中包含有数据列表:List<Map<String, Object>>
	 */
	public Pagination queryForPage(String sql, Integer pageSize, Integer pageNo, MapSqlParameterSource paramSource) {
		switch (databaseName) {
		case MYSQL:
			return queryForPageMysql(sql, pageSize, pageNo, paramSource);
		case ORACLE:
			return queryForPageOracle(sql, pageSize, pageNo, paramSource);
		default:
			throw new ServerRuntimeException("不支持的数据库类型：" + databaseName);
		}
	}

	private Pagination queryForPageMysql(String sql, Integer pageSize, Integer pageNo, MapSqlParameterSource paramSource) {
		int totalCount = queryForObject(getRowCountSql(sql), paramSource, Integer.class);
		Pagination pagination = new Pagination(pageNo, pageSize, totalCount);
		if (totalCount < 1) {
			pagination.setList(Collections.EMPTY_LIST);
			return pagination;
		}

		StringBuilder stringBuilder = new StringBuilder(sql);
		stringBuilder.append(" limit :offset, :rows");
		if (paramSource == null) {
			paramSource = new MapSqlParameterSource();
		}
		paramSource.addValue("offset", pagination.getPageSize() * (pagination.getPageNo() - 1));
		paramSource.addValue("rows", pagination.getPageSize());

		pagination.setList(queryForList(stringBuilder.toString(), paramSource));
		return pagination;
	}

	private Pagination queryForPageOracle(String sql, Integer pageSize, Integer pageNo,
			MapSqlParameterSource paramSource) {
		int totalCount = queryForObject(getRowCountSql(sql), paramSource, Integer.class);
		Pagination pagination = new Pagination(pageNo, pageSize, totalCount);
		if (totalCount < 1) {
			pagination.setList(Collections.EMPTY_LIST);
			return pagination;
		}

		StringBuilder stringBuilder = new StringBuilder(sql);
		stringBuilder.append("SELECT * FROM (SELECT pagedTable.*, ROWNUM AS myRownum FROM (").append(sql)
				.append(") pagedTable WHERE ROWNUM<= :rows ) WHERE myRownum>= :offset");
		if (paramSource == null) {
			paramSource = new MapSqlParameterSource();
		}
		paramSource
				.addValue("rows", pagination.getPageSize() * (pagination.getPageNo() - 1) + pagination.getPageSize());
		paramSource.addValue("offset", pagination.getPageSize() * (pagination.getPageNo() - 1) + 1);
		pagination.setList(queryForList(stringBuilder.toString(), paramSource));
		return pagination;
	}

	public String getRowCountSql(String sql) {
		StringBuilder stringBuilder = new StringBuilder("SELECT count(*) ");
		String upperSql = sql.toUpperCase();

		// 复杂查询判断
		int groupIndex = upperSql.indexOf(" GROUP BY ");
		boolean multiFrom = upperSql.indexOf("FROM") != upperSql.lastIndexOf("FROM");
		boolean multiOrder = upperSql.indexOf(" ORDER BY ") != upperSql.lastIndexOf(" ORDER BY ");

		if (groupIndex > 0 || multiFrom || multiOrder) {
			stringBuilder.append(" FROM (").append(sql).append(") result");
		} else {
			int fromIndex = upperSql.indexOf("FROM");
			String rowCountSql = sql.substring(fromIndex);

			int index = rowCountSql.toUpperCase().indexOf(" ORDER BY ");
			if (index > 0) {
				rowCountSql = rowCountSql.substring(0, index);
			}
			stringBuilder.append(rowCountSql);
		}
		return stringBuilder.toString();
	}

	/**
	 * 
	 * @Title: queryForPage
	 * @Description: 分页查询
	 * @param sql
	 * @param pageSize
	 *            每页条数
	 * @param pageNo
	 *            当前页码
	 * @param paramSource
	 *            查询条件参数
	 * @param rowMapper
	 *            RowMapper
	 * @return Pagination 分页信息，其中包含有数据列表:List<T>
	 */
	public Pagination queryForPage(String sql, Integer pageSize, Integer pageNo, MapSqlParameterSource paramSource,
			RowMapper<?> rowMapper) {
		switch (databaseName) {
		case MYSQL:
			return queryForPageMysql(sql, pageSize, pageNo, paramSource, rowMapper);
		case ORACLE:
			return queryForPageOracle(sql, pageSize, pageNo, paramSource, rowMapper);
		default:
			throw new ServerRuntimeException("不支持的数据库类型：" + databaseName);
		}
	}

	private Pagination queryForPageMysql(String sql, Integer pageSize, Integer pageNo,
			MapSqlParameterSource paramSource, RowMapper<?> rowMapper) {
		int totalCount = queryForObject(getRowCountSql(sql), paramSource, Integer.class);
		Pagination pagination = new Pagination(pageNo, pageSize, totalCount);
		if (totalCount < 1) {
			pagination.setList(Collections.EMPTY_LIST);
			return pagination;
		}

		StringBuilder stringBuilder = new StringBuilder(sql);
		stringBuilder.append(" limit :offset, :rows");
		if (paramSource == null) {
			paramSource = new MapSqlParameterSource();
		}
		paramSource.addValue("offset", pagination.getPageSize() * (pagination.getPageNo() - 1));
		paramSource.addValue("rows", pagination.getPageSize() * pagination.getPageNo());

		pagination.setList(query(stringBuilder.toString(), paramSource, rowMapper));
		return pagination;
	}

	private Pagination queryForPageOracle(String sql, Integer pageSize, Integer pageNo,
			MapSqlParameterSource paramSource, RowMapper<?> rowMapper) {
		int totalCount = queryForObject(getRowCountSql(sql), paramSource, Integer.class);
		Pagination pagination = new Pagination(pageNo, pageSize, totalCount);
		if (totalCount < 1) {
			pagination.setList(Collections.EMPTY_LIST);
			return pagination;
		}

		StringBuilder stringBuilder = new StringBuilder(sql);
		stringBuilder.append("SELECT * FROM (SELECT pagedTable.*, ROWNUM AS myRownum FROM (").append(sql)
				.append(") pagedTable WHERE ROWNUM<= :rows ) WHERE myRownum>= :offset");
		if (paramSource == null) {
			paramSource = new MapSqlParameterSource();
		}
		paramSource
				.addValue("rows", pagination.getPageSize() * (pagination.getPageNo() - 1) + pagination.getPageSize());
		paramSource.addValue("offset", pagination.getPageSize() * (pagination.getPageNo() - 1) + 1);
		pagination.setList(query(stringBuilder.toString(), paramSource, rowMapper));
		return pagination;
	}

	public <T> List<T> query(String sql, SqlParameterSource sqlParam, RowMapper<T> rowMapper) {
		return namedJdbcTemplate.query(sql, sqlParam, rowMapper);
	}

	public <T> T queryForObject(String sql, SqlParameterSource sqlParam, Class<T> clazz) {
		return namedJdbcTemplate.queryForObject(sql, sqlParam, clazz);
	}

	/**
	 * 
	 * 执行sql语句执行数据批量删除操作
	 * 
	 * @param sql
	 *            要执行的删除语句
	 * @param sqlParams
	 *            条件
	 * @return int[] 影响的记录数
	 */
	public int[] deleteBySql(String sql, SqlParameterSource... sqlParams) {
		if (ArrayUtils.isEmpty(sqlParams)) {
			sqlParams = (SqlParameterSource[]) ArrayUtils.add(sqlParams, null);
		}
		return namedJdbcTemplate.batchUpdate(sql, sqlParams);
	}

	/**
	 * 
	 * @Title: executeSql
	 * @Description: 执行指定的sql语句
	 * @param sql
	 * @param sqlParam
	 *            void
	 * @return int 影响的记录数
	 */
	public int executeSql(String sql, SqlParameterSource sqlParam) {
		return this.namedJdbcTemplate.update(sql, sqlParam);
	}

	/**
	 * 
	 * @Title: queryForObject
	 * @Description: 执行指定的SQL语句查询数据
	 * @param sql
	 * @param sqlParam
	 * @param mapper
	 * @return T
	 */
	protected <T> T queryForObject(String sql, SqlParameterSource sqlParam, RowMapper<T> mapper) {
		return namedJdbcTemplate.queryForObject(sql, sqlParam, mapper);
	}

	/**
	 * 
	 * @Description：执行指定的SQL语句查询数据
	 * @author: 任瑞修
	 * @date: 2013-11-6 上午11:40:38
	 * @param sql
	 * @param sqlParam
	 * @return
	 */
	public Map<String, Object> queryForMap(String sql, SqlParameterSource sqlParam) {
		return namedJdbcTemplate.queryForMap(sql, sqlParam);
	}

	/**
	 * 
	 * @Title: queryForInt
	 * @Description: 执行指定的SQL语句查询结果，多用于count等
	 * @param sql
	 * @param sqlParam
	 * @return int
	 */
	public int queryForInt(String sql, SqlParameterSource sqlParam) {
		return namedJdbcTemplate.queryForObject(sql, sqlParam, Integer.class);
	}

	/**
	 * 
	 * @Description：执行指定的SQL语句查询结果
	 * @author: mayanpei
	 * @date: 2013-11-12 上午8:45:07
	 * @param sql
	 * @param sqlParam
	 * @return long
	 */
	public long queryForLong(String sql, SqlParameterSource sqlParam) {
		return namedJdbcTemplate.queryForObject(sql, sqlParam, Long.class);
	}

	/**
	 * 
	 * @Title: queryForList
	 * @Description: 执行指定的SQL语句查询数据
	 * @param sql
	 *            SQL
	 * @param sqlParam
	 *            参数
	 * @return List<Map<String,Object>>
	 */
	public List<Map<String, Object>> queryForList(String sql, SqlParameterSource sqlParam) {
		return namedJdbcTemplate.queryForList(sql, sqlParam);
	}

	/**
	 * 
	 * @Description：执行指定的sql批量添加或更新数据
	 * @author: mayanpei
	 * @date: 2013-11-6 下午5:17:50
	 * @param sql
	 *            SQL
	 * @param sqlParams
	 *            参数
	 * @return int[] 影响的记录数
	 */
	public int[] addOrUpdate(String sql, SqlParameterSource... sqlParams) {
		if (ArrayUtils.isEmpty(sqlParams)) {
			sqlParams = (SqlParameterSource[]) ArrayUtils.add(sqlParams, null);
		}
		return namedJdbcTemplate.batchUpdate(sql, sqlParams);
	}
}
