本文基于MyBatis插件机制，自定义实现了一个打印完整SQL的插件。

主要功能

  • 将SQL中的参数占位符（?）替换为实际参数值
  • 统计SQL执行耗时
  • 格式化SQL（去除多余的空格和换行）

一些不足

  • 使用自定义TypeHandler的参数无法被正确解析（打印的是未经TypeHandler转换的原始参数值）

代码实现

@Intercepts(value = {@Signature(args = {Statement.class, ResultHandler.class}, method = "query", type = StatementHandler.class),
        @Signature(args = {Statement.class}, method = "update", type = StatementHandler.class),
        @Signature(args = {Statement.class}, method = "batch", type = StatementHandler.class)})
public class PrintSqlPlugin implements Interceptor {
    /** Logger used to print the executed SQL and to report failures while building it. */
    private static final Log logger = LogFactory.getLog(PrintSqlPlugin.class);
    /** Format used when printing {@code Date} parameter values; bound to the system default time zone. */
    public static final DateTimeFormatter DEFAULT_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS").withZone(ZoneId.systemDefault());

    /**
     * Executes the intercepted {@link StatementHandler} call, then attempts to log the executed SQL
     * (parameters substituted) together with the elapsed time.
     * <p>
     * Printing is attempted whether the statement succeeded or threw. It is strictly best-effort: any
     * failure while building or printing the SQL is logged and swallowed, so it can never replace the
     * statement's result or mask an exception thrown by the statement itself.
     *
     * @param invocation the intercepted {@code query}, {@code update} or {@code batch} call
     * @return the result of {@link Invocation#proceed()}
     * @throws Throwable whatever the intercepted call throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();

        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        try {
            return invocation.proceed();
        } finally {
            stopWatch.stop();
            // An exception thrown from this finally block would discard the return value of proceed()
            // or replace the exception it threw, so diagnostics must never escape from here.
            try {
                Configuration configuration = getConfiguration(statementHandler);
                BoundSql boundSql = statementHandler.getBoundSql();
                printSql(boundSql, configuration, stopWatch);
            } catch (Exception e) {
                logger.error("SQL打印插件打印SQL失败", e);
            }
        }
    }

    /**
     * Wraps {@link StatementHandler} instances in this plugin's interception proxy; every other MyBatis
     * component passed in is returned as-is.
     *
     * @param target the component MyBatis is about to use
     * @return a proxy for statement handlers, otherwise {@code target} itself
     */
    @Override
    public Object plugin(Object target) {
        if (!(target instanceof StatementHandler)) {
            return target;
        }
        return Plugin.wrap(target, this);
    }

    /**
     * Receives the properties configured for this plugin. The plugin has no configurable options,
     * so the supplied properties are ignored.
     *
     * @param properties plugin properties from the MyBatis configuration (unused)
     */
    @Override
    public void setProperties(Properties properties) {
        // Intentionally empty: nothing is read from the properties.
    }


    /**
     * Logs the fully rendered SQL together with the time measured by {@code stopWatch}.
     * Nothing is logged when the SQL could not be rendered.
     *
     * @param boundSql      the statement's bound SQL and parameters
     * @param configuration MyBatis configuration used to resolve parameter values
     * @param stopWatch     stopped watch holding the statement's execution time
     */
    private void printSql(BoundSql boundSql, Configuration configuration, StopWatch stopWatch) {
        String renderedSql = formatSql(boundSql, configuration);
        if (renderedSql == null) {
            return;
        }
        long elapsedMillis = stopWatch.getTotalTimeMillis();
        logger.info("执行 SQL: {}, 执行耗时: {} ms", renderedSql, elapsedMillis);
    }

    /**
     * Builds the printable form of the statement's SQL: whitespace is collapsed and, when the statement
     * has a parameter object and parameter mappings, each {@code ?} is replaced by its value as a SQL literal.
     *
     * @param boundSql      the statement's bound SQL and parameters
     * @param configuration MyBatis configuration used to resolve parameter values
     * @return the SQL to print, or {@code null} when the SQL is blank or resolving the parameters failed
     */
    private String formatSql(BoundSql boundSql, Configuration configuration) {
        String sql = boundSql.getSql();
        if (StringUtils.isBlank(sql)) {
            return null;
        }
        String beautifiedSql = this.beautifySql(sql);

        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (Objects.isNull(parameterObject) || CollectionUtils.isEmpty(parameterMappings)) {
            return beautifiedSql;
        }
        try {
            List<String> parameterValues = fetchParameterValues(boundSql, configuration);
            return replaceParamPlaceholder(beautifiedSql, parameterValues);
        } catch (Exception e) {
            // Log the exception itself: e.getCause() is null for most exceptions (e.g. a plain
            // NullPointerException), which would hide the actual failure.
            logger.error("SQL打印插件生成SQL失败, sql = {}, error= {}", boundSql.getSql(), e);
            return null;
        }
    }

    /**
     * Collapses every run of whitespace (spaces, tabs, line breaks) into a single space and strips
     * the ends, so multi-line mapper SQL is printed on one line without a leading or trailing blank.
     *
     * @param sql raw SQL text
     * @return the single-line SQL
     */
    private String beautifySql(String sql) {
        // \s already matches ' ', '\t', '\n', '\r', '\f' and '\u000B'.
        return sql.replaceAll("\\s+", " ").trim();
    }

    /**
     * Renders the value of every parameter mapping of the statement, in the same order as the
     * mappings (and therefore as the {@code ?} placeholders).
     *
     * @param boundSql      the statement's bound SQL and parameters
     * @param configuration MyBatis configuration used to resolve the values
     * @return one entry per mapping; an entry is {@code null} when {@link #getParamValue} yields no value
     */
    private List<String> fetchParameterValues(BoundSql boundSql, Configuration configuration) {
        List<ParameterMapping> mappings = boundSql.getParameterMappings();
        return mappings.stream()
                .map(mapping -> getParamValue(mapping, boundSql, configuration))
                .collect(Collectors.toList());
    }

    /**
     * Resolves a single parameter's value using the same lookup order as MyBatis' parameter handler
     * and renders it as a SQL literal.
     *
     * @see org.apache.ibatis.scripting.defaults.DefaultParameterHandler
     *
     * @param parameterMapping mapping describing the parameter (property name, mode, type handler)
     * @param boundSql         bound SQL holding the parameter object and any additional parameters
     * @param configuration    MyBatis configuration used for type-handler lookup and property access
     * @return the SQL literal of the value, or {@code null} for {@code OUT} parameters
     */
    private String getParamValue(ParameterMapping parameterMapping, BoundSql boundSql, Configuration configuration) {
        // OUT parameters carry no input value.
        if (parameterMapping.getMode() == ParameterMode.OUT) {
            return null;
        }

        String property = parameterMapping.getProperty();
        Object parameterObject = boundSql.getParameterObject();
        Object rawValue;
        if (boundSql.hasAdditionalParameter(property)) {
            // Dynamic parameters (e.g. <foreach> items) get generated names and are stored as additional parameters.
            rawValue = boundSql.getAdditionalParameter(property);
        } else if (parameterObject == null) {
            rawValue = null;
        } else if (configuration.getTypeHandlerRegistry().hasTypeHandler(parameterObject.getClass())) {
            // The parameter object's type has a registered TypeHandler, so the object itself is the value.
            rawValue = parameterObject;
        } else {
            // Otherwise read the named property (Map key or bean property) through a MetaObject.
            rawValue = configuration.newMetaObject(parameterObject).getValue(property);
        }

        Object handledValue = valueAfterTypeHandler(parameterMapping.getTypeHandler(), rawValue);
        return toSqlLiteral(handledValue);
    }

    /**
     * Replaces the {@code ?} placeholders in {@code sql}, left to right, with the value at the same
     * position in {@code parameterValues}.
     * <p>
     * A placeholder is left as {@code ?} when its value is {@code null} (e.g. an {@code OUT} parameter,
     * which has no printable value) or when the SQL contains more {@code ?} characters than there are
     * values (e.g. a literal {@code ?} inside a quoted string), instead of failing the whole print.
     *
     * @param sql             SQL containing {@code ?} placeholders
     * @param parameterValues SQL literals, one per placeholder; entries may be {@code null}
     * @return the SQL with placeholders substituted
     */
    private String replaceParamPlaceholder(String sql, List<String> parameterValues) {
        Matcher matcher = Pattern.compile("\\?").matcher(sql);
        StringBuffer result = new StringBuffer(sql.length());
        int index = 0;
        while (matcher.find()) {
            String value = index < parameterValues.size() ? parameterValues.get(index) : null;
            index++;
            // quoteReplacement escapes '\' and '$', which appendReplacement would otherwise interpret.
            String replacement = value == null ? "?" : Matcher.quoteReplacement(value);
            matcher.appendReplacement(result, replacement);
        }
        matcher.appendTail(result);
        return result.toString();
    }

    /**
     * Hook for converting a raw parameter value the way its {@link TypeHandler} would before binding.
     * <p>
     * Not implemented yet: the value is returned unchanged, so parameters that rely on a custom
     * TypeHandler are printed with their raw value.
     *
     * @param handler  the type handler configured for the parameter (currently unused)
     * @param rawValue the resolved parameter value
     * @return {@code rawValue}, unchanged
     */
    private Object valueAfterTypeHandler(TypeHandler<?> handler, Object rawValue) {
        // TODO: apply the TypeHandler configured for the parameter; values are not converted yet.
        return rawValue;
    }


    /**
     * Renders a parameter value as a SQL literal for the printed statement.
     * <ul>
     *   <li>{@code null} becomes the keyword {@code null};</li>
     *   <li>strings and characters are single-quoted;</li>
     *   <li>enums are single-quoted by {@code name()}, which matches MyBatis' default {@code EnumTypeHandler};</li>
     *   <li>{@code Date} values are formatted via {@code formatDate} and single-quoted;</li>
     *   <li>{@code java.time} values ({@code LocalDate}, {@code LocalDateTime}, ...) are single-quoted;</li>
     *   <li>anything else (numbers, booleans, ...) uses {@code toString()} unquoted.</li>
     * </ul>
     *
     * @param value the parameter value, may be {@code null}
     * @return the SQL literal
     */
    private String toSqlLiteral(Object value) {
        if (value == null) {
            // Calling toString() on a null parameter would throw and abort printing of the whole statement.
            return "null";
        }
        if (value instanceof CharSequence || value instanceof Character) {
            return toSqlStrLiteral(value);
        }
        if (value instanceof Enum) {
            return toSqlStrLiteral(((Enum<?>) value).name());
        }
        if (value instanceof Date) {
            return toSqlStrLiteral(formatDate((Date) value));
        }
        if (value instanceof java.time.temporal.TemporalAccessor) {
            // Unquoted ISO text such as 2024-01-01T10:00 is not valid SQL.
            return toSqlStrLiteral(value);
        }
        return value.toString();
    }

    /**
     * Wraps a value's text in single quotes, doubling any embedded single quote so the literal stays
     * valid SQL (e.g. {@code O'Brien} becomes {@code 'O''Brien'}). A {@code null} value yields {@code ''}.
     *
     * @param obj the value to quote, may be {@code null}
     * @return the quoted SQL string literal
     */
    private String toSqlStrLiteral(Object obj) {
        String text = Optional.ofNullable(obj).map(String::valueOf).orElse(StringUtils.EMPTY);
        return "'" + text.replace("'", "''") + "'";
    }

    /**
     * Formats a {@code Date} with {@code DEFAULT_FORMAT} in the system default time zone.
     * <p>
     * Converts through {@code getTime()} rather than {@code toInstant()}: {@code java.sql.Date} and
     * {@code java.sql.Time} (both {@code java.util.Date} subclasses commonly used as JDBC parameters)
     * throw {@link UnsupportedOperationException} from {@code toInstant()}.
     *
     * @param date the date to format, not {@code null}
     * @return the formatted date-time text
     */
    private String formatDate(Date date) {
        return java.time.Instant.ofEpochMilli(date.getTime())
                .atZone(ZoneId.systemDefault())
                .format(DEFAULT_FORMAT);
    }

    /**
     * Reads the MyBatis {@link Configuration} from the statement handler's parameter handler via
     * reflection on its {@code configuration} field, because the {@link StatementHandler} interface
     * has no accessor for it.
     *
     * @param statementHandler the intercepted statement handler
     * @return the configuration held by the parameter handler
     * @throws IllegalAccessException if the field cannot be read
     * @throws IllegalStateException  if the parameter handler has no {@code configuration} field
     *                                (e.g. it is a JDK proxy created by another plugin)
     */
    private Configuration getConfiguration(StatementHandler statementHandler) throws IllegalAccessException {
        // Look the field up on the runtime class instead of casting to DefaultParameterHandler, so any
        // ParameterHandler implementation that declares a "configuration" field is supported.
        Object parameterHandler = statementHandler.getParameterHandler();
        Field configurationField = ReflectionUtils.findField(parameterHandler.getClass(), "configuration");
        if (configurationField == null) {
            throw new IllegalStateException(
                    "No 'configuration' field found on parameter handler " + parameterHandler.getClass().getName());
        }
        ReflectionUtils.makeAccessible(configurationField);
        return (Configuration) configurationField.get(parameterHandler);
    }

Leave a Reply

Your email address will not be published. Required fields are marked *

This site uses Akismet to reduce spam. Learn how your comment data is processed.