mybatis拦截器,给所有的mybatis plus中的条件构造器对象添加一个条件。
由于数据库设计之初没有项目id字段,是后来才加上的,而此时业务代码已经写了很多,逐个修改既不现实又容易遗漏,于是编写了如下拦截器。
要让该拦截器生效,需同时满足以下条件:实体类中必须声明指定的字段(这里指定的是项目id,即 projectId);Mapper 必须继承 BaseMapper(继承之后拦截器才会处理);查询必须通过 Mapper 接口发出,例如使用 LambdaQueryWrapper 等条件构造器(也可以自行修改代码,让自己编写的 SQL 也进入该逻辑)。
满足这些条件后,查询 SQL 会默认经过这个拦截器(这里只处理查询语句,拦截器入口处已经判断只有查询才会继续往下处理)。
项目id是从请求头中获取并放入 ThreadLocal 的;如果你的获取方式不同,需要自行想办法拿到该参数,并替换 handleMybatisPlusSql 方法中调用 GengraLcontextHolder.getProjectId() 的那一处代码。
其实代码本质上就是拦截所有的sql,判断是查询,就往下走,反射获取实体类,判断有没有对应的参数,如果有,继续往下走,判断是否能获取到要额外添加的查询参数,如果有就继续,将这个参数,塞入对应的sql中。
package com.bs.common.config;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.springframework.stereotype.Component;
import java.io.StringReader;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.util.List;
/**
 * MyBatis interceptor that appends a {@code project_id} filter to SELECT statements issued
 * through MyBatis-Plus {@link BaseMapper} mappers.
 *
 * <p>It hooks {@link StatementHandler#prepare(Connection, Integer)}. A statement is rewritten only
 * when all of the following hold:
 * <ul>
 *   <li>it is a SELECT;</li>
 *   <li>its {@link MappedStatement#getResource() resource} names a {@code .java} mapper interface
 *       (XML-mapped statements are left alone);</li>
 *   <li>that mapper extends {@link BaseMapper} and its entity class declares a
 *       {@value #PROJECT_ID} field;</li>
 *   <li>the entity is not in the exclusion list;</li>
 *   <li>a non-null project id is available from {@code GengraLcontextHolder.getProjectId()}.</li>
 * </ul>
 * The condition {@code <main table or alias>.project_id = '<id>'} is then AND-ed onto the WHERE
 * clause of the bound SQL.
 *
 * @author 汪彪
 * @date 2023-08-02
 */
@Slf4j
@AllArgsConstructor
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Component
public class MybatisProjectInterceptor implements Interceptor {

    /** Entity field whose presence (declared directly on the entity class) enables filtering. */
    private static final String PROJECT_ID = "projectId";

    /** Database column the project filter is applied to. */
    public static final String PROJECT_ID_DATA_PARAM = "project_id";

    /**
     * Fragments of entity class names exempt from project filtering even though they declare
     * {@value #PROJECT_ID}. Matched with {@link String#contains(CharSequence)} against the fully
     * qualified class name.
     */
    private static final String[] IGNORED_ENTITY_NAME_FRAGMENTS = {
            "SysSubUser", "EngWarningPushRecord", "EngWarningPush"
    };

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        // Only SELECT statements are filtered; everything else passes through untouched.
        if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
            return invocation.proceed();
        }
        // Only statements declared on a Java mapper interface are candidates; XML-mapped
        // statements have a different resource and are not rewritten.
        String resource = mappedStatement.getResource();
        if (resource != null && resource.contains(".java")) {
            return handleMybatisPlusSql(resource, invocation, metaObject);
        }
        return invocation.proceed();
    }

    /**
     * Adds the project filter to a statement declared on a Java mapper interface, when that
     * mapper is a MyBatis-Plus {@link BaseMapper} whose entity carries {@value #PROJECT_ID}.
     *
     * @param resource   mapped-statement resource; must contain {@code ".java"}
     * @param invocation the intercepted {@code prepare} call
     * @param metaObject meta object wrapping the statement handler, used to read/write the SQL
     * @return the result of {@link Invocation#proceed()}
     * @throws Exception if the mapper class cannot be loaded, the SQL cannot be parsed, or the
     *                   invocation itself fails
     * @author 汪彪
     * @date 2023/8/7 14:19
     */
    private Object handleMybatisPlusSql(String resource, Invocation invocation, MetaObject metaObject) throws Exception {
        Class<?> mapperClass = Class.forName(toMapperClassName(resource));
        // Not a MyBatis-Plus mapper: nothing to filter.
        if (!BaseMapper.class.isAssignableFrom(mapperClass)) {
            return invocation.proceed();
        }
        Class<?> entityClass = resolveEntityClass(mapperClass);
        if (entityClass == null || !hasProjectIdField(entityClass) || isIgnoredEntity(entityClass)) {
            return invocation.proceed();
        }

        Integer projectIdValue;
        try {
            projectIdValue = GengraLcontextHolder.getProjectId();
        } catch (Exception e) {
            // Deliberate best-effort: without a project id in context the query runs unfiltered.
            log.debug("Project id not available, skipping project filter: {}", e.toString());
            return invocation.proceed();
        }
        // Same policy as above for an unset value; previously this crashed with an NPE below.
        if (projectIdValue == null) {
            return invocation.proceed();
        }

        String originalSql = (String) metaObject.getValue("delegate.boundSql.sql");
        String sql = handleSql(originalSql, PROJECT_ID_DATA_PARAM, projectIdValue.toString());
        metaObject.setValue("delegate.boundSql.sql", sql);
        return invocation.proceed();
    }

    /**
     * Converts a mapper resource such as {@code com/x/FooMapper.java...} into the binary class
     * name {@code com.x.FooMapper}.
     */
    private static String toMapperClassName(String resource) {
        return resource.substring(0, resource.indexOf(".java")).replace('/', '.');
    }

    /**
     * Returns the first type argument of the mapper's directly declared, parameterized
     * {@link BaseMapper}-derived super-interface (e.g. {@code Foo} for
     * {@code interface FooMapper extends BaseMapper<Foo>}).
     *
     * @return the entity class, or {@code null} when no such interface with a concrete class
     *         argument is declared directly on {@code mapperClass}
     */
    private static Class<?> resolveEntityClass(Class<?> mapperClass) {
        for (Type type : mapperClass.getGenericInterfaces()) {
            if (!(type instanceof ParameterizedType)) {
                continue;
            }
            ParameterizedType parameterized = (ParameterizedType) type;
            Type rawType = parameterized.getRawType();
            if (!(rawType instanceof Class) || !BaseMapper.class.isAssignableFrom((Class<?>) rawType)) {
                continue;
            }
            Type[] typeArguments = parameterized.getActualTypeArguments();
            if (typeArguments.length > 0 && typeArguments[0] instanceof Class) {
                return (Class<?>) typeArguments[0];
            }
        }
        return null;
    }

    /** Returns whether {@code entityClass} itself (not a superclass) declares {@value #PROJECT_ID}. */
    private static boolean hasProjectIdField(Class<?> entityClass) {
        try {
            entityClass.getDeclaredField(PROJECT_ID);
            return true;
        } catch (NoSuchFieldException | SecurityException ignored) {
            return false;
        }
    }

    /** Returns whether the entity's class name contains one of the excluded name fragments. */
    private static boolean isIgnoredEntity(Class<?> entityClass) {
        String className = entityClass.getName();
        for (String fragment : IGNORED_ENTITY_NAME_FRAGMENTS) {
            if (className.contains(fragment)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Parses a SELECT statement and adds {@code key = 'value'} to the WHERE clause of its plain
     * select, or of every branch of a set operation (UNION etc.). Other select body types are
     * returned unchanged.
     *
     * @param originalSql SQL text of the SELECT statement
     * @param key         column name to filter on
     * @param value       value compared against {@code key}, emitted as a string literal
     * @return the rewritten SQL
     * @throws JSQLParserException if {@code originalSql} cannot be parsed
     * @date 2023/8/7 14:17
     */
    private String handleSql(String originalSql, String key, String value) throws JSQLParserException {
        CCJSqlParserManager parserManager = new CCJSqlParserManager();
        Select select = (Select) parserManager.parse(new StringReader(originalSql));
        SelectBody selectBody = select.getSelectBody();
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, key, value);
        } else if (selectBody instanceof SetOperationList) {
            List<SelectBody> selectBodyList = ((SetOperationList) selectBody).getSelects();
            selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, key, value));
        }
        return select.toString();
    }

    /**
     * AND-s {@code <main table or alias>.key = 'value'} onto the WHERE clause of
     * {@code plainSelect}.
     *
     * @param plainSelect select to modify in place
     * @param key         column name to filter on
     * @param value       value compared against {@code key}, emitted as a string literal
     * @throws IllegalStateException if the FROM item is not a plain table (e.g. a subquery);
     *                               failing is preferred over silently skipping the filter
     * @date 2023/8/7 14:16
     */
    protected void setWhere(PlainSelect plainSelect, String key, String value) {
        if (plainSelect.getFromItem() == null) {
            // No FROM clause: no table rows to restrict.
            return;
        }
        if (!(plainSelect.getFromItem() instanceof Table)) {
            throw new IllegalStateException(
                    "Project filter requires a plain table in FROM, got: " + plainSelect.getFromItem());
        }
        Table fromItem = (Table) plainSelect.getFromItem();
        // Qualify the column with the alias if present, otherwise the table name, so it stays
        // unambiguous when the query joins other tables.
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? fromItem.getName() : fromItemAlias.getName();

        EqualsTo projectCondition = new EqualsTo();
        projectCondition.setLeftExpression(new Column(mainTableName + "." + key));
        projectCondition.setRightExpression(new StringValue(value));

        Expression where = plainSelect.getWhere();
        if (where == null) {
            plainSelect.setWhere(projectCondition);
        } else {
            // Parenthesize the existing predicate: AND binds tighter than OR, so
            // "a OR b AND project_id = x" would let rows of other projects through via "a".
            Expression guarded = where instanceof Parenthesis ? where : new Parenthesis(where);
            plainSelect.setWhere(new AndExpression(guarded, projectCondition));
        }
    }
}