java

关注公众号 jb51net

关闭
首页 > 软件编程 > java > mybatis拦截器数据库数据权限隔离

mybatis拦截器实现数据库数据权限隔离方式

作者:不要停下脚步

通过 MyBatis 拦截器，在 SQL 执行前自动追加过滤条件，实现数据权限隔离。对于以用户 ID 区分数据的表，拦截器会自动加上类似 user_id = #{userId} 的条件，确保查询只返回指定用户的数据。MyBatis 拦截器可以作用于四个阶段，本文的实现拦截的是 Executor 的 query 方法。

原理

使用拦截器在 MyBatis 执行 SQL 之前，

在 SQL 上追加指定的查询条件。

比如，你的表以 user_id 作为区分字段，

那么就需要在 SQL 拦截器中加上 user_id = #{userId} 的逻辑。

实现

MyBatis 拦截器的相关知识不再赘述，它可以拦截以下四类对象的方法调用：

Executor、StatementHandler、ParameterHandler 以及 ResultSetHandler。本文拦截的是 Executor 的 query 方法。

每个阶段的具体作用可以参考 MyBatis 官方文档中插件（plugins）一节。

 // Example mapper method: AuthInterceptor adds a "user_id = <current login id>" condition to this
 // query; ignoreOrgFiled = true means no organization-id condition is added.
 @AuthFilter(userFiled = "user_id" , ignoreOrgFiled = true)
    Page getUserMsgPage(@Param("page")Page page , @Param("param") MsgUserRefDto param , @Param("loginId") String loginId , @Param("orderBy")String orderBy);

具体效果是：希望上面的 SQL 在执行时自动追加 user_id = 1 这样的条件（实现中会先把原 SQL 包装为子查询 select * from (...) a，再在外层追加 WHERE 条件），从而只查询指定用户的数据。

配置文件

/**
 * Registers the data-permission interceptor on every SqlSessionFactory in the context.
 *
 * <p>NOTE(review): {@code @AutoConfigureAfter(PageHelperAutoConfiguration.class)} is presumably
 * there so this interceptor is added after PageHelper's; MyBatis plugin order depends on the
 * order of addInterceptor calls — confirm the intended ordering.
 */
@Configuration
@AutoConfigureAfter(PageHelperAutoConfiguration.class)
public class MybatisConfig {

    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    /** Adds one shared AuthInterceptor instance to the configuration of each factory. */
    @PostConstruct
    void mybatisConfigurationCustomizer() {
        final AuthInterceptor sharedInterceptor = new AuthInterceptor();
        for (SqlSessionFactory factory : sqlSessionFactoryList) {
            factory.getConfiguration().addInterceptor(sharedInterceptor);
        }
    }
}

自定义注解

/**
 * Marks a mapper method for row-level data filtering by AuthInterceptor.
 *
 * <p>Each {@code *Filed} member names the SQL column to filter on, and the matching
 * {@code ignore*Filed} flag switches that condition off. The "Filed" spelling is part of the
 * public API and is kept for compatibility.
 *
 * <p>NOTE(review): TYPE is an allowed target, but AuthInterceptor only reads the annotation from
 * the mapper method — a type-level annotation currently has no effect.
 */
@Documented
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface AuthFilter {

    /** Column compared against the current login id. */
    String userFiled() default "userId";

    /** When {@code true}, no login-id condition is added. */
    boolean ignoreUserFiled() default false;

    /** Column compared against the current organization id. */
    String orgFiled() default "orgId";

    /** When {@code true}, no organization-id condition is added. */
    boolean ignoreOrgFiled() default false;
}

具体拦截器逻辑

其中，GlobalHolder 是各系统自己用来存储用户登录信息（登录 ID、组织 ID）的容器，需要按实际项目替换。

/**
 * MyBatis plugin that enforces row-level data isolation for mapper methods annotated with
 * {@link AuthFilter}.
 *
 * <p>For each intercepted {@code Executor.query} call it resolves the mapper method behind the
 * statement id. If that method carries {@code @AuthFilter}, the original SQL is wrapped as
 * {@code select * from (<sql>) a} and {@code <column> = <value>} conditions for the current
 * organization id / login id (read from {@code GlobalHolder}) are appended to the outer WHERE.
 *
 * <p>NOTE(review): this class is a {@code @Component} and MybatisConfig also registers a
 * separately constructed instance; if the MyBatis starter in use auto-registers Interceptor
 * beans too, the filter may be applied twice — confirm.
 */
@Slf4j
@Component
@Intercepts({@Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
public class AuthInterceptor implements Interceptor {

    /** Mapper interface -> (method name -> public methods with that name). */
    private static final Map<Class<?>, Map<String, List<Method>>> mapperCache = new ConcurrentHashMap<>();

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameter = args[1];

        // Statements without the annotation, or that don't map to a mapper method, run unchanged.
        AuthFilter authFilter = resolveAuthFilter(ms.getId(), parameter);
        if (authFilter == null) {
            return invocation.proceed();
        }

        Map<String, Object> conditions = buildConditions(authFilter);
        if (conditions.isEmpty()) {
            return invocation.proceed();
        }

        RowBounds rowBounds = (RowBounds) args[2];
        ResultHandler<?> resultHandler = (ResultHandler<?>) args[3];
        Executor executor = (Executor) invocation.getTarget();
        BoundSql boundSql = args.length == 4 ? ms.getBoundSql(parameter) : (BoundSql) args[5];

        String concatSql = contactConditions(wrapSql(boundSql.getSql()), conditions);
        log.debug("添加后的sql为: {}", concatSql);
        ReflectUtil.setFieldValue(boundSql, "sql", concatSql);

        // Continue through the 6-arg overload with the rewritten BoundSql. The 4-arg overload
        // would build its own BoundSql from ms and discard the rewrite. The cache key is
        // recomputed so cached results are keyed on the filtered SQL, not the original one.
        CacheKey cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
        return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Finds the {@code @AuthFilter} on the mapper method behind a statement id.
     *
     * @return the annotation, or {@code null} when the namespace is not a loadable class, the
     *     id has no matching method (e.g. generated ids such as {@code xxx!selectKey}), or the
     *     method is not annotated
     */
    private AuthFilter resolveAuthFilter(String statementId, Object parameter) {
        int dot = statementId.lastIndexOf('.');
        if (dot < 0) {
            return null;
        }
        Class<?> clazz;
        try {
            clazz = Class.forName(statementId.substring(0, dot));
        } catch (ClassNotFoundException e) {
            // Namespace is not a Java mapper interface (e.g. XML-only mapper): nothing to filter.
            return null;
        }
        Method method = getMethod(clazz, statementId.substring(dot + 1), getParamArr(parameter));
        return method == null ? null : method.getAnnotation(AuthFilter.class);
    }

    /**
     * Builds the column -> value conditions required by the annotation.
     *
     * @throws IllegalStateException if a required organization id / login id is empty
     */
    private static Map<String, Object> buildConditions(AuthFilter authFilter) {
        Map<String, Object> conditions = new HashMap<>();
        if (!authFilter.ignoreOrgFiled()) {
            conditions.put(authFilter.orgFiled(), requireLoggedIn(GlobalHolder.getOrgId()));
        }
        if (!authFilter.ignoreUserFiled()) {
            conditions.put(authFilter.userFiled(), requireLoggedIn(GlobalHolder.getLoginId()));
        }
        return conditions;
    }

    /** Returns {@code value} if non-empty, otherwise throws the "not logged in" error. */
    private static String requireLoggedIn(String value) {
        if (StringUtils.isNotEmpty(value)) {
            return value;
        }
        throw new IllegalStateException("用户未登录!");
    }

    /** Wraps a non-empty SQL string as a derived table {@code select * from (sql) a}. */
    private String wrapSql(String sql) {
        if (StringUtils.isNotEmpty(sql)) {
            return "select * from ( " + sql + ") a";
        }
        return sql;
    }

    /**
     * Resolves the mapper method for a statement.
     *
     * <p>A unique method name is returned directly. For overloads, the first method whose
     * parameter count and types accept {@code paramArr} wins.
     *
     * @return the method, or {@code null} if none matches
     */
    private Method getMethod(Class<?> clazz, String mapperMethod, Object[] paramArr) {
        List<Method> candidates = mapperCache
                .computeIfAbsent(clazz, AuthInterceptor::indexMethodsByName)
                .get(mapperMethod);
        if (candidates == null || candidates.isEmpty()) {
            return null;
        }
        if (candidates.size() == 1) {
            return candidates.get(0);
        }
        for (Method candidate : candidates) {
            if (parametersMatch(candidate.getParameterTypes(), paramArr)) {
                return candidate;
            }
        }
        return null;
    }

    /** Groups a class's public methods by name (computed once per mapper interface). */
    private static Map<String, List<Method>> indexMethodsByName(Class<?> clazz) {
        Map<String, List<Method>> byName = new HashMap<>();
        for (Method method : clazz.getMethods()) {
            byName.computeIfAbsent(method.getName(), k -> new ArrayList<>()).add(method);
        }
        return byName;
    }

    /** True if arities match and each non-null argument is assignable to its declared type. */
    private static boolean parametersMatch(Class<?>[] declared, Object[] actual) {
        if (declared.length != actual.length) {
            return false;
        }
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] != null && !compareClass(declared[i], actual[i].getClass())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether a value of {@code paramType} can be passed where {@code returnType} is
     * declared, treating a primitive and its wrapper as interchangeable.
     */
    private static boolean compareClass(Class<?> returnType, Class<?> paramType) {
        return boxed(returnType).isAssignableFrom(boxed(paramType));
    }

    /** Maps a primitive type to its wrapper; other types are returned unchanged. */
    private static Class<?> boxed(Class<?> type) {
        if (!type.isPrimitive()) {
            return type;
        }
        if (type == int.class) {
            return Integer.class;
        }
        if (type == long.class) {
            return Long.class;
        }
        if (type == boolean.class) {
            return Boolean.class;
        }
        if (type == double.class) {
            return Double.class;
        }
        if (type == float.class) {
            return Float.class;
        }
        if (type == short.class) {
            return Short.class;
        }
        if (type == byte.class) {
            return Byte.class;
        }
        if (type == char.class) {
            return Character.class;
        }
        return Void.class;
    }

    /**
     * Extracts the mapper call's argument values from the MyBatis parameter object.
     *
     * <p>For a {@code ParamMap}, values are read from the positional keys {@code param1..paramN}.
     * They are probed with {@code containsKey} because {@code ParamMap.get} throws for missing
     * keys. Any other non-null parameter is the single argument.
     *
     * @return the argument values; never {@code null}
     */
    private Object[] getParamArr(Object parameter) {
        if (parameter instanceof MapperMethod.ParamMap) {
            Map<?, ?> map = (Map<?, ?>) parameter;
            List<Object> values = new ArrayList<>();
            for (int i = 1; map.containsKey("param" + i); i++) {
                values.add(map.get("param" + i));
            }
            return values.toArray();
        }
        return parameter == null ? new Object[0] : new Object[] {parameter};
    }

    /**
     * ANDs {@code column = value} conditions onto the WHERE clause of a SELECT statement.
     * Non-SELECT (or empty) SQL and an empty condition map are returned unchanged.
     */
    private static String contactConditions(String sql, Map<String, Object> columnMap) {
        SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, JdbcUtils.MYSQL);
        List<SQLStatement> stmtList = parser.parseStatementList();
        if (columnMap.isEmpty() || stmtList.isEmpty() || !(stmtList.get(0) instanceof SQLSelectStatement)) {
            return sql;
        }

        StringBuilder constraints = new StringBuilder();
        for (Map.Entry<String, Object> entry : columnMap.entrySet()) {
            if (constraints.length() > 0) {
                constraints.append(" AND ");
            }
            constraints.append(entry.getKey()).append(" = ").append(getSqlByClass(entry.getValue()));
        }
        SQLExprParser constraintsParser = SQLParserUtils.createExprParser(constraints.toString(), JdbcUtils.MYSQL);
        SQLExpr constraintsExpr = constraintsParser.expr();

        SQLSelect sqlselect = ((SQLSelectStatement) stmtList.get(0)).getSelect();
        SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
        SQLExpr whereExpr = query.getWhere();
        query.setWhere(whereExpr == null
                ? constraintsExpr
                : new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr));
        return sqlselect.toString();
    }

    /**
     * Renders a value as a SQL literal. Numbers are emitted as-is; everything else becomes a
     * quoted string with backslashes and single quotes escaped (MySQL syntax), so the value
     * cannot break out of the literal.
     */
    private static String getSqlByClass(Object value) {
        if (value instanceof Number) {
            return String.valueOf(value);
        }
        String text = value.toString();
        return "'" + text.replace("\\", "\\\\").replace("'", "''") + "'";
    }
}

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文