Bootstrap

使用 MyBatis-Plus 拦截器为查询语句统一添加额外的查询条件

通过 MyBatis 拦截器，为 MyBatis-Plus 条件构造器等方式生成的所有查询 SQL 统一追加一个条件。

由于数据库一开始没有设计项目 ID（project_id）字段，是后来才补上的。此时业务代码已经写了很多，逐个修改查询既不现实，又容易出现遗漏，于是编写了下面这个拦截器。

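使用该拦截器有以下前提：
1. 实体类（DAO 实体）中必须直接声明指定的字段。这里指定的是项目 ID 字段 `projectId`，从父类继承的字段不会被识别。
2. Mapper 必须继承 `BaseMapper<实体类>`，否则拦截器不会生效。
3. 查询需要通过 Mapper 接口上的方法执行，例如 BaseMapper 自带的查询方法配合 `LambdaQueryWrapper` 等条件构造器使用。XML 中手写的 SQL 不会被处理；如有需要，可以自行修改代码，让自己写的 SQL 也能进入拦截器。

满足以上条件时，查询 SQL 会自动进入该拦截器。这里只处理查询语句：拦截器入口处已经做了判断，非查询语句会直接放行。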

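项目 ID 是从请求头中获取并存放在 ThreadLocal 中的，代码里通过 `GengraLcontextHolder.getProjectId()` 读取。如果你的做法不同，就需要自行获取这个参数，并替换 `handleMybatisPlusSql` 方法中调用 `GengraLcontextHolder.getProjectId()` 的那一行。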

代码的核心流程如下：
1. 拦截所有 SQL，只继续处理查询语句；
2. 通过反射获取 Mapper 对应的实体类，判断其中是否有对应字段（projectId）；
3. 判断能否获取到要额外添加的查询参数（项目 ID）；
4. 以上条件都满足时，把该条件追加到对应 SQL 的 WHERE 子句中。

package com.bs.common.config;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;

import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.springframework.stereotype.Component;

import java.io.StringReader;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.util.List;

/**
 * MyBatis interceptor that appends a project filter ({@code <table>.project_id = '<id>'}) to
 * SELECT statements issued through MyBatis-Plus mappers.
 *
 * <p>A statement is rewritten only when all of the following hold:
 * <ul>
 *   <li>it is a SELECT;</li>
 *   <li>its mapped-statement resource contains {@code ".java"}, i.e. it was registered from a
 *       mapper interface rather than (presumably) an XML mapper file;</li>
 *   <li>the mapper interface extends {@link BaseMapper} and its entity class directly declares a
 *       {@code projectId} field (fields inherited from a superclass are not detected);</li>
 *   <li>the entity class name does not contain any of the excluded name fragments;</li>
 *   <li>{@code GengraLcontextHolder.getProjectId()} returns a non-null value without throwing.</li>
 * </ul>
 * In every other case the statement is executed unchanged.
 *
 * @author 汪彪
 * @date 2023年08月02日
 */
@Slf4j
@AllArgsConstructor
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Component
public class MybatisProjectInterceptor implements Interceptor {

    /** Name of the entity field whose presence enables the project filter. */
    private static final String PROJECT_ID = "projectId";

    /** Database column the project filter is applied to. */
    public static final String PROJECT_ID_DATA_PARAM = "project_id";

    /**
     * Fragments of fully-qualified entity class names that are exempt from the project filter.
     * Matched with {@link String#contains}, so "EngWarningPush" also matches any longer name
     * containing it.
     */
    private static final String[] EXCLUDED_ENTITY_NAME_FRAGMENTS = {
            "SysSubUser", "EngWarningPushRecord", "EngWarningPush"
    };

    /**
     * Rewrites the SQL of qualifying SELECT statements before they are prepared, then proceeds.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of {@link Invocation#proceed()}
     * @throws Throwable whatever the downstream chain or the SQL rewrite throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // Assumes the target exposes a "delegate" handler (presumably MyBatis' RoutingStatementHandler).
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");

        // Only SELECT statements are filtered; everything else passes through untouched.
        if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
            return invocation.proceed();
        }

        // Statements registered from a Java mapper interface carry a ".java" resource name;
        // only those are candidates for MyBatis-Plus entity resolution.
        String resource = mappedStatement.getResource();
        if (resource != null && resource.contains(".java")) {
            return handleMybatisPlusSql(resource, invocation, metaObject);
        }
        return invocation.proceed();
    }

    /**
     * Adds the project filter to a statement that originates from a Java mapper interface,
     * when the mapper is a {@link BaseMapper} whose entity declares {@code projectId}.
     *
     * @param resource   mapped-statement resource; must contain {@code ".java"}
     * @param invocation the intercepted call to proceed with
     * @param metaObject meta object over the intercepted statement handler, used to read/write the SQL
     * @return the result of {@link Invocation#proceed()}
     * @throws Exception if the mapper class cannot be loaded, the SQL cannot be parsed,
     *                   or the downstream chain fails
     * @author 汪彪
     * @date 2023/8/7 14:19
     **/
    private Object handleMybatisPlusSql(String resource, Invocation invocation, MetaObject metaObject) throws Exception {
        // The resource path up to ".java" is the mapper's binary name with '/' separators.
        String mapperClassName = resource.substring(0, resource.indexOf(".java")).replace('/', '.');
        Class<?> mapperClass = Class.forName(mapperClassName);
        if (!BaseMapper.class.isAssignableFrom(mapperClass)) {
            // Not a MyBatis-Plus mapper: leave the SQL alone.
            return invocation.proceed();
        }

        Class<?> entityClass = resolveEntityClass(mapperClass);
        if (entityClass == null || !declaresProjectId(entityClass) || isExcluded(entityClass)) {
            return invocation.proceed();
        }

        Integer projectIdValue;
        try {
            projectIdValue = GengraLcontextHolder.getProjectId();
        } catch (Exception e) {
            // Deliberate best effort: no project context available, so the query runs unfiltered.
            return invocation.proceed();
        }
        if (projectIdValue == null) {
            // No project id bound to the current context; previously this caused an NPE below.
            return invocation.proceed();
        }

        String originalSql = (String) metaObject.getValue("delegate.boundSql.sql");
        String sql = handleSql(originalSql, PROJECT_ID_DATA_PARAM, projectIdValue.toString());
        metaObject.setValue("delegate.boundSql.sql", sql);
        return invocation.proceed();
    }

    /**
     * Finds the entity type {@code T} of the {@code BaseMapper<T>} (or BaseMapper sub-interface)
     * that {@code mapperClass} directly extends.
     *
     * @param mapperClass a mapper interface assignable to {@link BaseMapper}
     * @return the entity class, or {@code null} if no directly-extended BaseMapper-typed interface
     *         has a concrete class as its first type argument
     */
    private static Class<?> resolveEntityClass(Class<?> mapperClass) {
        for (Type genericInterface : mapperClass.getGenericInterfaces()) {
            if (!(genericInterface instanceof ParameterizedType)) {
                continue;
            }
            ParameterizedType parameterizedType = (ParameterizedType) genericInterface;
            Type rawType = parameterizedType.getRawType();
            // Skip unrelated generic interfaces (e.g. a mixin declared before BaseMapper).
            if (!(rawType instanceof Class) || !BaseMapper.class.isAssignableFrom((Class<?>) rawType)) {
                continue;
            }
            Type[] typeArguments = parameterizedType.getActualTypeArguments();
            if (typeArguments.length > 0 && typeArguments[0] instanceof Class) {
                return (Class<?>) typeArguments[0];
            }
        }
        return null;
    }

    /**
     * Tells whether the entity class itself declares the {@code projectId} field
     * (superclass fields are intentionally not considered).
     */
    private static boolean declaresProjectId(Class<?> entityClass) {
        try {
            entityClass.getDeclaredField(PROJECT_ID);
            return true;
        } catch (NoSuchFieldException | SecurityException ignored) {
            return false;
        }
    }

    /** Tells whether the entity class name contains one of {@link #EXCLUDED_ENTITY_NAME_FRAGMENTS}. */
    private static boolean isExcluded(Class<?> entityClass) {
        String entityName = entityClass.getName();
        for (String fragment : EXCLUDED_ENTITY_NAME_FRAGMENTS) {
            if (entityName.contains(fragment)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Parses a SELECT statement and ANDs {@code <table>.<key> = '<value>'} into the WHERE clause of
     * every plain select it contains (recursing through set operations such as UNION).
     *
     * @param originalSql the SELECT statement to rewrite
     * @param key         column name to filter on
     * @param value       value compared against, rendered as a SQL string literal
     * @return the rewritten SQL
     * @throws JSQLParserException if {@code originalSql} cannot be parsed
     * @date 2023/8/7 14:17
     **/
    private String handleSql(String originalSql, String key, String value) throws JSQLParserException {
        CCJSqlParserManager parserManager = new CCJSqlParserManager();
        // The caller only passes SELECT statements, so the cast is expected to succeed.
        Select select = (Select) parserManager.parse(new StringReader(originalSql));
        processSelectBody(select.getSelectBody(), key, value);
        return select.toString();
    }

    /**
     * Applies {@link #setWhere} to a plain select, or to every branch of a set operation.
     * Other top-level select body types are left unchanged.
     */
    private void processSelectBody(SelectBody selectBody, String key, String value) {
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, key, value);
        } else if (selectBody instanceof SetOperationList) {
            List<SelectBody> branches = ((SetOperationList) selectBody).getSelects();
            for (SelectBody branch : branches) {
                if (branch instanceof SetOperationList) {
                    processSelectBody(branch, key, value);
                } else {
                    // Any branch that is not a plain select fails fast with ClassCastException.
                    this.setWhere((PlainSelect) branch, key, value);
                }
            }
        }
    }

    /**
     * ANDs {@code <table>.<key> = '<value>'} into the WHERE clause of {@code plainSelect}, where
     * {@code <table>} is the FROM item's alias if present, otherwise its table name.
     *
     * @param plainSelect the select to modify in place
     * @param key         column name to filter on
     * @param value       value compared against, rendered as a SQL string literal
     * @date 2023/8/7 14:16
     **/
    @SneakyThrows(Exception.class)
    protected void setWhere(PlainSelect plainSelect, String key, String value) {
        // NOTE: assumes the FROM item is a plain table; a sub-select FROM fails with ClassCastException.
        Table fromItem = (Table) plainSelect.getFromItem();
        if (fromItem == null) {
            // e.g. "SELECT 1": no table, so there is nothing to filter.
            return;
        }
        // Qualify the column with the alias (or the table name) to avoid ambiguity in joins.
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? fromItem.getName() : fromItemAlias.getName();

        EqualsTo projectCondition = new EqualsTo();
        projectCondition.setLeftExpression(new Column(mainTableName + "." + key));
        projectCondition.setRightExpression(new StringValue(value));

        if (plainSelect.getWhere() == null) {
            plainSelect.setWhere(projectCondition);
        } else {
            // Wrap the existing predicate in parentheses. JSqlParser prints AND/OR without adding
            // any, so "a = ? OR b = ?" would otherwise become "a = ? OR b = ? AND t.project_id = '1'".
            // Because AND binds tighter than OR, that would let rows of other projects through.
            plainSelect.setWhere(new AndExpression(
                    CCJSqlParserUtil.parseCondExpression("(" + plainSelect.getWhere() + ")"),
                    projectCondition));
        }
    }
}

;