Skip to content

Commit

Permalink
optimize: When the number of primary keys exceeds 1000, use union to …
Browse files Browse the repository at this point in the history
…concatenate the SQL #6957 (#7012)
  • Loading branch information
remind authored Dec 25, 2024
1 parent 748f50e commit 340ea92
Show file tree
Hide file tree
Showing 18 changed files with 165 additions and 89 deletions.
2 changes: 2 additions & 0 deletions changes/en-us/2.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Add changes here for all PR submitted to the 2.x branch.
### optimize:

- [[#6828](https://github.com/apache/incubator-seata/pull/6828)] spring boot compatible with file.conf and registry.conf
- [[#7012](https://github.com/apache/incubator-seata/pull/7012)] When the number of primary keys exceeds 1000, use union to concatenate the SQL

### security:

Expand All @@ -32,5 +33,6 @@ Thanks to these contributors for their code commits. Please report an unintended

- [slievrly](https://github.com/slievrly)
- [lyl2008dsg](https://github.com/lyl2008dsg)
- [remind](https://github.com/remind)

Also, we receive many valuable issues, questions and advices from our community. Thanks for you all.
2 changes: 2 additions & 0 deletions changes/zh-cn/2.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
### optimize:

- [[#6828](https://github.com/apache/incubator-seata/pull/6828)] seata-spring-boot-starter兼容file.conf和registry.conf
- [[#7012](https://github.com/apache/incubator-seata/pull/7012)] 当主键超过1000个时,使用union拼接sql,可以使用索引

### security:

Expand All @@ -32,5 +33,6 @@

- [slievrly](https://github.com/slievrly)
- [lyl2008dsg](https://github.com/lyl2008dsg)
- [remind](https://github.com/remind)

同时,我们收到了社区反馈的很多有价值的issue和建议,非常感谢大家。
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public ResultSet executeQuery(MockStatementBase statement, String sql) throws SQ
List<Object[]> metas = new ArrayList<>();
if(asts.get(0) instanceof SQLSelectStatement) {
SQLSelectStatement ast = (SQLSelectStatement) asts.get(0);
SQLSelectQueryBlock queryBlock = ast.getSelect().getQueryBlock();
SQLSelectQueryBlock queryBlock = ast.getSelect().getFirstQueryBlock();
String tableName = "";
if (queryBlock.getFrom() instanceof SQLExprTableSource) {
MySQLSelectForUpdateRecognizer recognizer = new MySQLSelectForUpdateRecognizer(sql, ast);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

import org.apache.seata.rm.datasource.sql.struct.Field;
import org.apache.seata.sqlparser.util.ColumnUtils;
Expand All @@ -36,32 +38,51 @@ private SqlGenerateUtils() {

}

public static String buildWhereConditionByPKs(List<String> pkNameList, int rowSize, String dbType)
throws SQLException {
return buildWhereConditionByPKs(pkNameList, rowSize, dbType, MAX_IN_SIZE);

/**
* build full sql by pks.
* @param sqlPrefix sql prefix
* @param suffix sql suffix
* @param pkNameList pk column name list
* @param rowSize the row size of records
* @param dbType the type of database
* @return full sql
*/
public static String buildSQLByPKs(String sqlPrefix, String suffix, List<String> pkNameList, int rowSize, String dbType) {
List<WhereSql> whereList = buildWhereConditionListByPKs(pkNameList, rowSize, dbType, MAX_IN_SIZE);
StringJoiner sqlJoiner = new StringJoiner(" UNION ");
whereList.forEach(whereSql -> sqlJoiner.add(sqlPrefix + " " + whereSql.getSql() + " " + suffix));
return sqlJoiner.toString();
}
/**
* each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?)
* ) or (id,userCode) in ((?,?))"
* each pk is a condition.the result will like :" [(id,userCode) in ((?,?),(?,?)), (id,userCode) in ((?,?),(?,?)
* ), (id,userCode) in ((?,?))]"
* Build where condition by pks string. size default MAX_IN_SIZE
*
* @param pkNameList pk column name list
* @param rowSize the row size of records
* @param dbType the type of database
* @return return where condition sql list.the sql can search all related records not just one.
*/
public static List<WhereSql> buildWhereConditionListByPKs(List<String> pkNameList, int rowSize, String dbType) {
return buildWhereConditionListByPKs(pkNameList, rowSize, dbType, MAX_IN_SIZE);
}
/**
* each pk is a condition.the result will like :" [(id,userCode) in ((?,?),(?,?)), (id,userCode) in ((?,?),(?,?)
* ), (id,userCode) in ((?,?))]"
* Build where condition by pks string.
*
* @param pkNameList pk column name list
* @param rowSize the row size of records
* @param dbType the type of database
* @param maxInSize the max in size
* @return return where condition sql string.the sql can search all related records not just one.
* @throws SQLException the sql exception
* @return return where condition sql list.the sql can search all related records not just one.
*/
public static String buildWhereConditionByPKs(List<String> pkNameList, int rowSize, String dbType, int maxInSize)
throws SQLException {
StringBuilder whereStr = new StringBuilder();
public static List<WhereSql> buildWhereConditionListByPKs(List<String> pkNameList, int rowSize, String dbType, int maxInSize) {
List<WhereSql> whereSqls = new ArrayList<>();
//we must consider the situation of composite primary key
int batchSize = rowSize % maxInSize == 0 ? rowSize / maxInSize : (rowSize / maxInSize) + 1;
for (int batch = 0; batch < batchSize; batch++) {
if (batch > 0) {
whereStr.append(" or ");
}
StringBuilder whereStr = new StringBuilder();
whereStr.append("(");
for (int i = 0; i < pkNameList.size(); i++) {
if (i > 0) {
Expand All @@ -88,9 +109,10 @@ public static String buildWhereConditionByPKs(List<String> pkNameList, int rowSi
whereStr.append(")");
}
whereStr.append(" )");
whereSqls.add(new WhereSql(whereStr.toString(), eachSize, pkNameList.size()));
}

return whereStr.toString();
return whereSqls;
}

/**
Expand Down Expand Up @@ -135,4 +157,38 @@ public static String buildWhereConditionByPKs(List<String> pkNameList, String db
return whereStr.toString();
}

public static class WhereSql {
/**
* sql
*/
private final String sql;

/**
* row size
*/
private final int rowSize;

/**
* pk size
*/
private final int pkSize;

public WhereSql(String sql, int rowSize, int pkSize) {
this.sql = sql;
this.rowSize = rowSize;
this.pkSize = pkSize;
}

public String getSql() {
return sql;
}

public int getRowSize() {
return rowSize;
}

public int getPkSize() {
return pkSize;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -520,16 +520,17 @@ protected TableRecords buildTableRecords(Map<String, List<Object>> pkValuesMap)
// build check sql
String firstKey = pkValuesMap.keySet().stream().findFirst().get();
int rowSize = pkValuesMap.get(firstKey).size();
suffix.append(WHERE).append(SqlGenerateUtils.buildWhereConditionByPKs(pkColumnNameList, rowSize, getDbType()));
suffix.append(WHERE);
StringJoiner selectSQLJoin = new StringJoiner(", ", prefix, suffix.toString());
List<String> insertColumnsUnEscape = recognizer.getInsertColumnsUnEscape();
List<String> needColumns =
getNeedColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), insertColumnsUnEscape);
needColumns.forEach(selectSQLJoin::add);
PreparedStatement ps = null;
String sqlStr = SqlGenerateUtils.buildSQLByPKs(selectSQLJoin.toString(), "", pkColumnNameList, rowSize, getDbType());
ResultSet rs = null;
try {
ps = statementProxy.getConnection().prepareStatement(selectSQLJoin.toString());
ps = statementProxy.getConnection().prepareStatement(sqlStr);
int paramIndex = 1;
for (int r = 0; r < rowSize; r++) {
for (int c = 0; c < pkColumnNameList.size(); c++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,8 @@ private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage)
SQLUpdateRecognizer sqlUpdateRecognizer = (SQLUpdateRecognizer) sqlRecognizer;
updateColumnsSet.addAll(sqlUpdateRecognizer.getUpdateColumnsUnEscape());
}
StringBuilder prefix = new StringBuilder("SELECT ");
String suffix = " FROM " + getFromTableInSQL() + " WHERE " + SqlGenerateUtils.buildWhereConditionByPKs(tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix.toString(), suffix);
StringJoiner selectSQLJoiner = new StringJoiner(", ", "SELECT ",
" FROM " + getFromTableInSQL() + " WHERE ");
if (ONLY_CARE_UPDATE_COLUMNS) {
if (!containsPK(new ArrayList<>(updateColumnsSet))) {
selectSQLJoiner.add(getColumnNamesInSQL(tableMeta.getEscapePkNameList(getDbType())));
Expand All @@ -162,7 +161,7 @@ private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage)
selectSQLJoiner.add(ColumnUtils.addEscape(columnName, getDbType()));
}
}
return selectSQLJoiner.toString();
return SqlGenerateUtils.buildSQLByPKs(selectSQLJoiner.toString(), "", tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
}

protected String buildSuffixSql(String whereCondition) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,12 @@ protected TableRecords afterImage(TableRecords beforeImage) throws SQLException
}

private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage) throws SQLException {
String prefix = "SELECT ";
String whereSql = SqlGenerateUtils.buildWhereConditionByPKs(tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
String suffix = " FROM " + getFromTableInSQL() + " WHERE " + whereSql;
StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix, suffix);
StringJoiner selectSQLJoiner = new StringJoiner(", ", "SELECT "
, " FROM " + getFromTableInSQL() + " WHERE ");
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer;
List<String> needUpdateColumns = getNeedColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), recognizer.getUpdateColumnsUnEscape());
needUpdateColumns.forEach(selectSQLJoiner::add);
return selectSQLJoiner.toString();
return SqlGenerateUtils.buildSQLByPKs(selectSQLJoiner.toString(), "", tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,12 @@ private String buildAfterImageSQL(String joinTable, String itemTable,
TableRecords beforeImage) throws SQLException {
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer;
TableMeta itemTableMeta = getTableMeta(itemTable);
StringBuilder prefix = new StringBuilder("SELECT ");
List<String> pkColumns = getColumnNamesWithTablePrefixList(itemTable, recognizer.getTableAlias(itemTable), itemTableMeta.getPrimaryKeyOnlyName());
String whereSql = SqlGenerateUtils.buildWhereConditionByPKs(pkColumns, beforeImage.pkRows().size(), getDbType());
String suffix = " FROM " + joinTable + " WHERE " + whereSql;
//maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
suffix += GROUP_BY;
List<String> itemTableUpdateColumns = getItemUpdateColumns(itemTableMeta, recognizer.getUpdateColumns());
List<String> needUpdateColumns = getNeedColumns(itemTable, recognizer.getTableAlias(itemTable), itemTableUpdateColumns);
suffix += buildGroupBy(pkColumns, needUpdateColumns);
StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix.toString(), suffix);
StringJoiner selectSQLJoiner = new StringJoiner(", ", "SELECT ", " FROM " + joinTable + " WHERE ");
needUpdateColumns.forEach(selectSQLJoiner::add);
return selectSQLJoiner.toString();
return SqlGenerateUtils.buildSQLByPKs(selectSQLJoiner.toString(), GROUP_BY + buildGroupBy(pkColumns, needUpdateColumns), pkColumns, beforeImage.pkRows().size(), getDbType());
}

private List<String> getItemUpdateColumns(TableMeta itemTableMeta, List<String> updateColumns) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,30 +309,32 @@ protected TableRecords queryCurrentRecords(ConnectionProxy connectionProxy) thro
// build check sql
String firstKey = pkRowValues.keySet().stream().findFirst().get();
int pkRowSize = pkRowValues.get(firstKey).size();
String checkSQL = buildCheckSql(sqlUndoLog.getTableName(),
SqlGenerateUtils.buildWhereConditionByPKs(pkNameList, pkRowSize, connectionProxy.getDbType()));

PreparedStatement statement = null;
ResultSet checkSet = null;
TableRecords currentRecords;
try {
statement = conn.prepareStatement(checkSQL);
int paramIndex = 1;
int rowSize = pkRowValues.get(pkNameList.get(0)).size();
for (int r = 0; r < rowSize; r++) {
for (int c = 0; c < pkNameList.size(); c++) {
List<Field> pkColumnValueList = pkRowValues.get(pkNameList.get(c));
Field field = pkColumnValueList.get(r);
int dataType = tableMeta.getColumnMeta(field.getName()).getDataType();
statement.setObject(paramIndex, field.getValue(), dataType);
paramIndex++;
List<SqlGenerateUtils.WhereSql> sqlConditions = SqlGenerateUtils.buildWhereConditionListByPKs(pkNameList, pkRowSize, connectionProxy.getDbType());
TableRecords currentRecords = new TableRecords(tableMeta);
int totalRowIndex = 0;
for (SqlGenerateUtils.WhereSql sqlCondition : sqlConditions) {
String checkSQL = buildCheckSql(sqlUndoLog.getTableName(), sqlCondition.getSql());
PreparedStatement statement = null;
ResultSet checkSet = null;
try {
statement = conn.prepareStatement(checkSQL);
int paramIndex = 1;
for (int r = 0; r < sqlCondition.getRowSize(); r++) {
for (int c = 0; c < sqlCondition.getPkSize(); c++) {
List<Field> pkColumnValueList = pkRowValues.get(pkNameList.get(c));
Field field = pkColumnValueList.get(totalRowIndex + r);
int dataType = tableMeta.getColumnMeta(field.getName()).getDataType();
statement.setObject(paramIndex, field.getValue(), dataType);
paramIndex++;
}
}
}
totalRowIndex += sqlCondition.getRowSize();

checkSet = statement.executeQuery();
currentRecords = TableRecords.buildRecords(tableMeta, checkSet);
} finally {
IOUtil.close(checkSet, statement);
checkSet = statement.executeQuery();
currentRecords.getRows().addAll(TableRecords.buildRecords(tableMeta, checkSet).getRows());
} finally {
IOUtil.close(checkSet, statement);
}
}
return currentRecords;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,49 @@
*/
package org.apache.seata.rm.datasource;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

import org.apache.seata.rm.datasource.SqlGenerateUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.StringJoiner;


class SqlGenerateUtilsTest {


@Test
void testBuildWhereConditionByPKs() throws SQLException {
List<String> pkNameList=new ArrayList<>();
void testBuildWhereConditionListByPKs() {
List<String> pkNameList = new ArrayList<>();
pkNameList.add("id");
pkNameList.add("name");
List<SqlGenerateUtils.WhereSql> results1 = SqlGenerateUtils.buildWhereConditionListByPKs(pkNameList, 4, "mysql", 2);
Assertions.assertEquals(2, results1.size());
results1.forEach(result -> {
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) )", result.getSql());
Assertions.assertEquals(2, result.getRowSize());
Assertions.assertEquals(2, result.getPkSize());
});
List<SqlGenerateUtils.WhereSql> results2 = SqlGenerateUtils.buildWhereConditionListByPKs(pkNameList, 5, "mysql", 2);
Assertions.assertEquals(3, results2.size());
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) )", results2.get(0).getSql());
Assertions.assertEquals(2, results2.get(0).getRowSize());
Assertions.assertEquals(2, results2.get(0).getPkSize());
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) )", results2.get(1).getSql());
Assertions.assertEquals("(id,name) in ( (?,?) )", results2.get(2).getSql());
Assertions.assertEquals(1, results2.get(2).getRowSize());
Assertions.assertEquals(2, results2.get(2).getPkSize());
}

@Test
void testBuildSQLByPKs() {
String sqlPrefix = "select id,name from t_order where ";
List<String> pkNameList = new ArrayList<>();
pkNameList.add("id");
pkNameList.add("name");
String result = SqlGenerateUtils.buildWhereConditionByPKs(pkNameList,4,"mysql",2);
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) ) or (id,name) in ( (?,?),(?,?) )", result);
result = SqlGenerateUtils.buildWhereConditionByPKs(pkNameList,5,"mysql",2);
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) ) or (id,name) in ( (?,?),(?,?) ) or (id,name) in ( (?,?)"
+ " )",
result);
List<SqlGenerateUtils.WhereSql> whereList = SqlGenerateUtils.buildWhereConditionListByPKs(pkNameList, 4, "mysql", 2);
StringJoiner sqlJoiner = new StringJoiner(" union ");
whereList.forEach(whereSql -> sqlJoiner.add(sqlPrefix + " " + whereSql.getSql()));
Assertions.assertEquals("select id,name from t_order where (id,name) in ( (?,?),(?,?) ) union select id,name from t_order where (id,name) in ( (?,?),(?,?) )", sqlJoiner.toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.seata.rm.datasource.DataSourceProxy;
import org.apache.seata.rm.datasource.DataSourceProxyTest;
import org.apache.seata.rm.datasource.StatementProxy;
import org.apache.seata.rm.datasource.exec.UpdateExecutor;
import org.apache.seata.rm.datasource.exec.mysql.MySQLUpdateJoinExecutor;
import org.apache.seata.rm.datasource.mock.MockDriver;
import org.apache.seata.rm.datasource.sql.struct.TableRecords;
Expand All @@ -58,6 +57,7 @@ public void testUpdateJoinUndoLog() throws SQLException {
};
Object[][] beforeReturnValue = new Object[][]{
new Object[]{1, "Tom"},
new Object[]{2, "Tony"},
};
StatementProxy beforeMockStatementProxy = mockStatementProxy(returnValueColumnLabels, beforeReturnValue, columnMetas, indexMetas);
String sql = "update t1 inner join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL'";
Expand All @@ -69,6 +69,7 @@ public void testUpdateJoinUndoLog() throws SQLException {
TableRecords beforeImage = mySQLUpdateJoinExecutor.beforeImage();
Object[][] afterReturnValue = new Object[][]{
new Object[]{1, "WILL"},
new Object[]{2, "Tony"},
};
StatementProxy afterMockStatementProxy = mockStatementProxy(returnValueColumnLabels, afterReturnValue, columnMetas, indexMetas);
mySQLUpdateJoinExecutor.statementProxy = afterMockStatementProxy;
Expand Down
Loading

0 comments on commit 340ea92

Please sign in to comment.