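接上一篇：《mybatis Interceptor 拦截器实现自定义扩展查询（兼容 mybatis plus）》（xqlee，blog.xqlee.com）
本篇在此基础上实现自定义分页查询扩展。示例基于 MyBatis-Plus，同样适用于原生 MyBatis。思路：当 Mapper 方法的参数为 Paging 时拦截 Executor.query，先执行 count 统计总数，再为原 SQL 追加过滤条件和 limit 分页子句，最后把查询结果封装成 PageData 返回。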
MyBatis (Plus) 自定义分页拦截器：
/**
 * MyBatis plugin that turns a mapper call whose single argument is a {@code Paging} into a paged query.
 *
 * <p>For a {@code Paging} parameter the plugin:
 * <ol>
 *   <li>runs a {@code count(0)} query over the statement's SQL plus the {@code SelectFilter} WHERE fragment;</li>
 *   <li>if the requested page exists, rewrites the statement with the filter fragment and a MySQL-style
 *       {@code limit offset,size} clause, mapping rows to the filter's entity class;</li>
 *   <li>returns a single-element list holding a {@code PageData}, so a mapper method declared to return
 *       {@code PageData<T>} receives the page.</li>
 * </ol>
 * Any other parameter type is passed through unchanged.
 *
 * <p>Limitation: only the {@code SelectFilter}'s parameters are bound. {@code #{}} placeholders in the
 * mapper's own SQL are not bound by either the count or the data query.
 */
@Intercepts({
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
        )
})
@Slf4j
public class MybatisPageSelectFilterInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // args[0] is the statement being executed (its id is the mapper/XML statement id)
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        Object parameter = invocation.getArgs()[1];
        if (!(parameter instanceof Paging)) {
            // Not a paged call: run the statement unchanged.
            return invocation.proceed();
        }
        Paging<?> paging = (Paging<?>) parameter;
        BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        long pageNum = paging.getPageNum();
        long pageSize = paging.getPageSize();

        PageData<?> page = new PageData<>();
        page.setCurrent(paging.getPageNum());
        page.setSize(paging.getPageSize());

        // Use the executor's own transactional connection. It belongs to the current SqlSession,
        // which releases it, so it must NOT be closed here. Opening a fresh DataSource connection
        // (and never closing it) leaked one connection per paged query and bypassed the transaction.
        Connection connection = ((Executor) invocation.getTarget()).getTransaction().getConnection();
        long total = count(connection, boundSql, paging.getSelectFilter());
        page.setTotal(total);
        // Ceiling division; a non-positive page size yields zero pages instead of an ArithmeticException.
        long pages = pageSize > 0 ? (total + pageSize - 1) / pageSize : 0;
        page.setPages(pages);

        if (pageNum > 0 && pageNum <= pages) {
            // Rewrite the statement (filter + limit), then let the chain execute it.
            dealSelectFilter(mappedStatement, parameter, paging.getSelectFilter(), invocation, paging);
            List result = (List) invocation.proceed();
            page.setRecords(result);
        } else {
            page.setRecords(new ArrayList<>());
        }
        return Arrays.asList(page);
    }

    /**
     * Counts the rows matched by {@code boundSql} combined with the filter's WHERE fragment.
     *
     * <p>As a side effect, if the filter's entity has a {@code @TableField} + {@code @TableLogic} column
     * that the filter does not mention yet, {@code <column> = 0} is added to the filter. The later data
     * query reuses the same filter and so sees that condition too.
     *
     * @param connection   connection to run the count on; not closed by this method
     * @param boundSql     the original statement's SQL; the filter fragment is appended to it
     * @param selectFilter filter supplying the WHERE fragment and its positional parameters
     * @return the row count, or 0 if the count query returns no row
     * @throws IllegalStateException if the count query fails (the {@link SQLException} is the cause)
     */
    public int count(Connection connection, BoundSql boundSql, SelectFilter selectFilter) {
        applyLogicDeleteCondition(selectFilter);
        selectFilter.buildWithWhere();
        String filterSql = selectFilter.getSql();
        // Wrap instead of regex-replacing the select list. This works for any case, multi-line SQL
        // and subqueries, whereas "select.*from" only matched single-line lower-case SQL.
        String countSql = "select count(0) from (" + boundSql.getSql() + filterSql + ") tmp_count";
        Map<?, ?> sqlParamsMap = selectFilter.getSqlParamsMap();
        log.debug("count sql: {} params: {}", countSql, sqlParamsMap == null ? null : sqlParamsMap.values());

        try (PreparedStatement countStmt = connection.prepareStatement(countSql)) {
            if (!CollectionUtils.isEmpty(sqlParamsMap)) {
                // Assumes the map iterates in the same order as the '?' placeholders in filterSql
                // (e.g. a LinkedHashMap) -- TODO confirm in SelectFilter.
                int index = 1; // JDBC parameter indexes start at 1
                for (Object value : sqlParamsMap.values()) {
                    countStmt.setObject(index++, value);
                }
            }
            try (ResultSet rs = countStmt.executeQuery()) {
                return rs.next() ? rs.getInt(1) : 0;
            }
        } catch (SQLException e) {
            // Previously swallowed and reported as 0 rows, which silently produced an empty page.
            throw new IllegalStateException("Paging count query failed: " + countSql, e);
        }
    }

    /**
     * Adds {@code <column> = 0} for each {@code @TableField} field that is also {@code @TableLogic},
     * unless the filter's current SQL already mentions that column.
     */
    private static void applyLogicDeleteCondition(SelectFilter selectFilter) {
        Class<?> entityClass = selectFilter.getEntityClass();
        if (entityClass == null) {
            return;
        }
        for (Field field : entityClass.getDeclaredFields()) {
            TableField tableField = field.getAnnotation(TableField.class);
            if (tableField == null || field.getAnnotation(TableLogic.class) == null) {
                continue;
            }
            String column = StringUtils.isEmpty(tableField.value()) ? field.getName() : tableField.value();
            selectFilter.buildWithWhere();
            if (!selectFilter.getSql().contains(column)) {
                selectFilter.eq(column, 0);
            }
        }
    }

    /**
     * Replaces {@code invocation.getArgs()[0]} with a copy of the statement whose SQL is the original
     * (with {@code select *} expanded when possible), then the filter fragment, then
     * {@code limit offset,size}.
     */
    private void dealSelectFilter(MappedStatement mappedStatement, Object parameter, SelectFilter selectFilter,
                                  Invocation invocation, Paging paging) {
        BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        String baseSql = boundSql.getSql();
        log.debug("old:{}", baseSql);
        Class<?> entityClass = selectFilter.getEntityClass();
        if (entityClass != null) {
            baseSql = expandSelectStar(baseSql, entityClass);
        }

        selectFilter.buildWithWhere();
        // Each filter parameter is read back by MyBatis from the Paging parameter object via this path.
        List<ParameterMapping> parameterMappings = new ArrayList<>();
        for (Object key : selectFilter.getSqlParamsMap().keySet()) {
            parameterMappings.add(new ParameterMapping.Builder(mappedStatement.getConfiguration(),
                    "selectFilter.sqlParamsMap." + key, Object.class).build());
        }

        long pageNum = paging.getPageNum();
        long pageSize = paging.getPageSize();
        long offset = (pageNum - 1) * pageSize; // long arithmetic avoids int overflow on deep pages
        String dealSql = baseSql + selectFilter.getSql() + " limit " + offset + "," + pageSize;

        BoundSql newBoundSql = new BoundSql(mappedStatement.getConfiguration(), dealSql,
                parameterMappings, boundSql.getParameterObject());
        MappedStatement newMs = copyFromMappedStatement(mappedStatement, new BoundSqlSqlSource(newBoundSql), entityClass);
        invocation.getArgs()[0] = newMs;
    }

    /**
     * If {@code sql} starts with {@code select * from <@TableName>} (any case, leading whitespace allowed),
     * replaces the {@code *} with the entity's mapped columns aliased to their field names. Whatever follows
     * the table name (alias, joins, ...) is kept. Returns {@code sql} unchanged when there is no match,
     * no usable {@code @TableName}, or no mapped columns.
     */
    private static String expandSelectStar(String sql, Class<?> entityClass) {
        TableName tableNameAnnotation = entityClass.getDeclaredAnnotation(TableName.class);
        if (tableNameAnnotation == null || StringUtils.isEmpty(tableNameAnnotation.value())) {
            return sql;
        }
        String tableName = tableNameAnnotation.value();
        // Quote the table name (it may contain '.' etc.) and require that no identifier character follows,
        // so "cy_xlxw" does not match "cy_xlxw_2".
        Matcher matcher = Pattern.compile(
                "^\\s*select\\s+\\*\\s+from\\s+" + Pattern.quote(tableName) + "(?!\\w)",
                Pattern.CASE_INSENSITIVE).matcher(sql);
        if (!matcher.find()) {
            return sql;
        }
        List<String> columns = new ArrayList<>();
        for (Field field : entityClass.getDeclaredFields()) {
            String fieldName = field.getName();
            TableField tableField = field.getAnnotation(TableField.class);
            if (tableField != null && tableField.exist()) {
                String column = StringUtils.isEmpty(tableField.value()) ? fieldName : tableField.value();
                columns.add(column + " " + fieldName);
            }
            TableId tableId = field.getAnnotation(TableId.class);
            if (tableId != null) {
                String column = StringUtils.isEmpty(tableId.value()) ? fieldName : tableId.value();
                columns.add(column + " " + fieldName);
            }
        }
        if (columns.isEmpty()) {
            return sql; // "select  from t" would be invalid; keep the original select *
        }
        return "select " + String.join(",", columns) + " from " + tableName + " " + sql.substring(matcher.end());
    }

    /**
     * Copies the original MappedStatement with a new SQL source.
     * Results map to {@code entityClass}, or to the original result maps when it is {@code null}.
     *
     * @param ms           statement to copy
     * @param newSqlSource replacement SQL source
     * @param entityClass  row type for the result; may be {@code null}
     * @return the copied statement
     */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource, Class<?> entityClass) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource,
                ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            // keyProperty() takes a comma-separated list and overwrites the previous value,
            // so calling it once per property kept only the last one.
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultSetType(ms.getResultSetType());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        if (entityClass != null) {
            // Map each row to the entity type rather than the mapper method's declared return type.
            ResultMap resultMap = new ResultMap.Builder(ms.getConfiguration(), ms.getId(), entityClass,
                    new ArrayList<>()).build();
            List<ResultMap> resultMaps = new ArrayList<>();
            resultMaps.add(resultMap);
            builder.resultMaps(resultMaps);
        } else {
            builder.resultMaps(ms.getResultMaps());
        }
        builder.cache(ms.getCache());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /** SqlSource that always returns one precomputed BoundSql, regardless of the parameter object. */
    public static class BoundSqlSqlSource implements SqlSource {
        BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }
}
拦截器注册（在每个 SqlSessionFactory 上添加拦截器）：
/**
 * Registers the select-filter interceptors on every {@link SqlSessionFactory} bean in the context.
 */
@Configuration
public class MybatisSelectFilterConfiguration {

    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    /**
     * Creates one list interceptor and one page interceptor and adds both instances to every factory's
     * MyBatis configuration. On each factory the list interceptor is added first, then the page interceptor.
     */
    @PostConstruct
    public void addInterceptor() {
        MybatisListSelectFilterInterceptor listInterceptor = new MybatisListSelectFilterInterceptor();
        MybatisPageSelectFilterInterceptor pageInterceptor = new MybatisPageSelectFilterInterceptor();
        for (SqlSessionFactory factory : sqlSessionFactoryList) {
            factory.getConfiguration().addInterceptor(listInterceptor);
            factory.getConfiguration().addInterceptor(pageInterceptor);
        }
    }
}
Mapper 代码（参数为 Paging，返回 PageData）：
/**
 * Paged query over {@code cy_xlxw}.
 * Because the argument is a {@code Paging}, MybatisPageSelectFilterInterceptor intercepts this statement.
 * It runs a count query and, if the requested page exists, a filtered query with a {@code limit} clause.
 * It returns a single-element list holding the {@code PageData}, which MyBatis hands back as this method's result.
 */
@Select("select * from cy_xlxw")
PageData<Xlxw> page(Paging<Xlxw> paging);
原文链接：http://blog.xqlee.com/article/1029.html