Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize: When the number of primary keys exceeds 1000, use union to concatenate the SQL #6957 #6987

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

import io.seata.rm.datasource.sql.struct.Field;
import io.seata.sqlparser.util.ColumnUtils;
Expand All @@ -36,32 +38,53 @@ private SqlGenerateUtils() {

}

public static String buildWhereConditionByPKs(List<String> pkNameList, int rowSize, String dbType)
throws SQLException {
return buildWhereConditionByPKs(pkNameList, rowSize, dbType, MAX_IN_SIZE);
/**
* build sql by pks.
* @param sqlPrefix
* @param suffix
* @param pkNameList
* @param rowSize
* @param dbType
* @return
*/
public static String buildSQLByPKs(String sqlPrefix, String suffix, List<String> pkNameList, int rowSize, String dbType) {
List<WhereSql> whereList = buildWhereConditionListByPKs(pkNameList, rowSize, dbType, MAX_IN_SIZE);
StringJoiner sqlJoiner = new StringJoiner(" UNION ");
whereList.forEach(whereSql -> sqlJoiner.add(sqlPrefix + " " + whereSql.getSql() + " " + suffix));
return sqlJoiner.toString();
}

/**
* each pk is a condition.the result will like :" [(id,userCode) in ((?,?),(?,?)), (id,userCode) in ((?,?),(?,?)
* ), (id,userCode) in ((?,?))]"
* Build where condition by pks string. size default MAX_IN_SIZE
*
* @param pkNameList pk column name list
* @param rowSize the row size of records
* @param dbType the type of database
* @return return where condition sql list.the sql can search all related records not just one.
*/
public static List<WhereSql> buildWhereConditionListByPKs(List<String> pkNameList, int rowSize, String dbType) {
return buildWhereConditionListByPKs(pkNameList, rowSize, dbType, MAX_IN_SIZE);
}

/**
* each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?)
* ) or (id,userCode) in ((?,?))"
* each pk is a condition.the result will like :" [(id,userCode) in ((?,?),(?,?)), (id,userCode) in ((?,?),(?,?)
* ), (id,userCode) in ((?,?))]"
* Build where condition by pks string.
*
* @param pkNameList pk column name list
* @param rowSize the row size of records
* @param dbType the type of database
* @param maxInSize the max in size
* @return return where condition sql string.the sql can search all related records not just one.
* @throws SQLException the sql exception
* @return return where condition sql list.the sql can search all related records not just one.
*/
public static String buildWhereConditionByPKs(List<String> pkNameList, int rowSize, String dbType, int maxInSize)
throws SQLException {
StringBuilder whereStr = new StringBuilder();
public static List<WhereSql> buildWhereConditionListByPKs(List<String> pkNameList, int rowSize, String dbType, int maxInSize) {
List<WhereSql> whereSqls = new ArrayList<>();
//we must consider the situation of composite primary key
int batchSize = rowSize % maxInSize == 0 ? rowSize / maxInSize : (rowSize / maxInSize) + 1;
for (int batch = 0; batch < batchSize; batch++) {
if (batch > 0) {
whereStr.append(" or ");
}
StringBuilder whereStr = new StringBuilder();
whereStr.append("(");
for (int i = 0; i < pkNameList.size(); i++) {
if (i > 0) {
Expand All @@ -88,9 +111,10 @@ public static String buildWhereConditionByPKs(List<String> pkNameList, int rowSi
whereStr.append(")");
}
whereStr.append(" )");
whereSqls.add(new WhereSql(whereStr.toString(), eachSize, pkNameList.size()));
}

return whereStr.toString();
return whereSqls;
}

/**
Expand Down Expand Up @@ -135,4 +159,38 @@ public static String buildWhereConditionByPKs(List<String> pkNameList, String db
return whereStr.toString();
}

public static class WhereSql {
/**
* sql
*/
private final String sql;

/**
* row size
*/
private final int rowSize;

/**
* pk size
*/
private final int pkSize;

public WhereSql(String sql, int rowSize, int pkSize) {
this.sql = sql;
this.rowSize = rowSize;
this.pkSize = pkSize;
}

public String getSql() {
return sql;
}

public int getRowSize() {
return rowSize;
}

public int getPkSize() {
return pkSize;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,15 @@ protected TableRecords buildTableRecords(Map<String, List<Object>> pkValuesMap)
// build check sql
String firstKey = pkValuesMap.keySet().stream().findFirst().get();
int rowSize = pkValuesMap.get(firstKey).size();
suffix.append(WHERE).append(SqlGenerateUtils.buildWhereConditionByPKs(pkColumnNameList, rowSize, getDbType()));
suffix.append(WHERE);
StringJoiner selectSQLJoin = new StringJoiner(", ", prefix, suffix.toString());
List<String> insertColumnsUnEscape = recognizer.getInsertColumnsUnEscape();
List<String> needColumns =
getNeedColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), insertColumnsUnEscape);
needColumns.forEach(selectSQLJoin::add);
String sqlStr = SqlGenerateUtils.buildSQLByPKs(selectSQLJoin.toString(), "", pkColumnNameList, rowSize, getDbType());
ResultSet rs = null;
try (PreparedStatement ps = statementProxy.getConnection().prepareStatement(selectSQLJoin.toString())) {
try (PreparedStatement ps = statementProxy.getConnection().prepareStatement(sqlStr)) {

int paramIndex = 1;
for (int r = 0; r < rowSize; r++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,8 @@ private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage)
SQLUpdateRecognizer sqlUpdateRecognizer = (SQLUpdateRecognizer) sqlRecognizer;
updateColumnsSet.addAll(sqlUpdateRecognizer.getUpdateColumnsUnEscape());
}
StringBuilder prefix = new StringBuilder("SELECT ");
String suffix = " FROM " + getFromTableInSQL() + " WHERE " + SqlGenerateUtils.buildWhereConditionByPKs(tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix.toString(), suffix);
StringJoiner selectSQLJoiner = new StringJoiner(", ", "SELECT ",
" FROM " + getFromTableInSQL() + " WHERE ");
if (ONLY_CARE_UPDATE_COLUMNS) {
if (!containsPK(new ArrayList<>(updateColumnsSet))) {
selectSQLJoiner.add(getColumnNamesInSQL(tableMeta.getEscapePkNameList(getDbType())));
Expand All @@ -163,6 +162,6 @@ private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage)
selectSQLJoiner.add(ColumnUtils.addEscape(columnName, getDbType()));
}
}
return selectSQLJoiner.toString();
return SqlGenerateUtils.buildSQLByPKs(selectSQLJoiner.toString(), "", tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,12 @@ protected TableRecords afterImage(TableRecords beforeImage) throws SQLException

private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage) throws SQLException {
String prefix = "SELECT ";
String whereSql = SqlGenerateUtils.buildWhereConditionByPKs(tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
String suffix = " FROM " + getFromTableInSQL() + " WHERE " + whereSql;
String suffix = " FROM " + getFromTableInSQL() + " WHERE ";
StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix, suffix);
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer;
List<String> needUpdateColumns = getNeedColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), recognizer.getUpdateColumnsUnEscape());
needUpdateColumns.forEach(selectSQLJoiner::add);
return selectSQLJoiner.toString();
return SqlGenerateUtils.buildSQLByPKs(selectSQLJoiner.toString(), "", tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,12 @@ private String buildAfterImageSQL(String joinTable, String itemTable,
TableMeta itemTableMeta = getTableMeta(itemTable);
StringBuilder prefix = new StringBuilder("SELECT ");
List<String> pkColumns = getColumnNamesWithTablePrefixList(itemTable, recognizer.getTableAlias(itemTable), itemTableMeta.getPrimaryKeyOnlyName());
String whereSql = SqlGenerateUtils.buildWhereConditionByPKs(pkColumns, beforeImage.pkRows().size(), getDbType());
String suffix = " FROM " + joinTable + " WHERE " + whereSql;
//maybe duplicate row for select join sql.remove duplicate row by 'group by' condition
suffix += GROUP_BY;
String suffix = " FROM " + joinTable + " WHERE ";
List<String> itemTableUpdateColumns = getItemUpdateColumns(itemTableMeta, recognizer.getUpdateColumns());
List<String> needUpdateColumns = getNeedColumns(itemTable, recognizer.getTableAlias(itemTable), itemTableUpdateColumns);
suffix += buildGroupBy(pkColumns, needUpdateColumns);
StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix.toString(), suffix);
needUpdateColumns.forEach(selectSQLJoiner::add);
return selectSQLJoiner.toString();
return SqlGenerateUtils.buildSQLByPKs(selectSQLJoiner.toString(), GROUP_BY + buildGroupBy(pkColumns, needUpdateColumns), pkColumns, beforeImage.pkRows().size(), getDbType());
}

private List<String> getItemUpdateColumns(TableMeta itemTableMeta, List<String> updateColumns) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,30 +302,36 @@ protected TableRecords queryCurrentRecords(ConnectionProxy connectionProxy) thro
// build check sql
String firstKey = pkRowValues.keySet().stream().findFirst().get();
int pkRowSize = pkRowValues.get(firstKey).size();
String checkSQL = String.format(CHECK_SQL_TEMPLATE, sqlUndoLog.getTableName(),
SqlGenerateUtils.buildWhereConditionByPKs(pkNameList, pkRowSize, connectionProxy.getDbType()));

PreparedStatement statement = null;
ResultSet checkSet = null;
TableRecords currentRecords;
try {
statement = conn.prepareStatement(checkSQL);
int paramIndex = 1;
int rowSize = pkRowValues.get(pkNameList.get(0)).size();
for (int r = 0; r < rowSize; r++) {
for (int c = 0; c < pkNameList.size(); c++) {
List<Field> pkColumnValueList = pkRowValues.get(pkNameList.get(c));
Field field = pkColumnValueList.get(r);
int dataType = tableMeta.getColumnMeta(field.getName()).getDataType();
statement.setObject(paramIndex, field.getValue(), dataType);
paramIndex++;
List<SqlGenerateUtils.WhereSql> sqlConditions = SqlGenerateUtils.buildWhereConditionListByPKs(pkNameList, pkRowSize, connectionProxy.getDbType());
List<TableRecords> currentRecordsList = new ArrayList<>();
int totalRowIndex = 0;
for (SqlGenerateUtils.WhereSql sqlCondition : sqlConditions) {
String checkSQL = String.format(CHECK_SQL_TEMPLATE, sqlUndoLog.getTableName(), sqlCondition.getSql());
PreparedStatement statement = null;
ResultSet checkSet = null;
try {
statement = conn.prepareStatement(checkSQL);
int paramIndex = 1;
for (int r = 0; r < sqlCondition.getRowSize(); r++) {
for (int c = 0; c < sqlCondition.getPkSize(); c++) {
List<Field> pkColumnValueList = pkRowValues.get(pkNameList.get(c));
Field field = pkColumnValueList.get(totalRowIndex + r);
int dataType = tableMeta.getColumnMeta(field.getName()).getDataType();
statement.setObject(paramIndex, field.getValue(), dataType);
paramIndex++;
}
}
}
totalRowIndex += sqlCondition.getRowSize();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段代码的意义是什么?
What is the significance of this code?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当主键超过1000时,sql会拆分成多条,并且按拆分后的sql循环执行,但是所有pk参数都在pkRowValues中,totalRowIndex是为了记录每条sql执行完成之后,在pkRowValues已经取到的参数行位置,做为下一条sql执行时取参数的起始值,其中sqlCondition.getRowSize()为当前sql参数行数,该行代码就为移动参数行。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当主键超过1000时,sql会拆分成多条,并且按拆分后的sql循环执行,但是所有pk参数都在pkRowValues中,totalRowIndex是为了记录每条sql执行完成之后,在pkRowValues已经取到的参数行位置,做为下一条sql执行时取参数的起始值,其中sqlCondition.getRowSize()为当前sql参数行数,该行代码就为移动参数行。

ok,再github review页没看出来,整体pr改动我看了下应该问题不大, 我拉到本地测试一下,测试通过的话就可以合并进去了

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果没问题了,能否也放到1.8.1

Copy link
Contributor

@funky-eyes funky-eyes Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果没问题了,能否也放到1.8.1

1.8.1应该不会进行发布了

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那要用这个只能升级到2.x了吗?后续1.x都不维护了吗?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那要用这个只能升级到2.x了吗?后续1.x都不维护了吗?

没有维护的必要性,因为1.x的问题已经在2.x上修复了,已经是一种迭代了

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那要用这个只能升级到2.x了吗?后续1.x都不维护了吗?

这个pr能否根据2.x的代码进行提交一个新pr?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


checkSet = statement.executeQuery();
currentRecords = TableRecords.buildRecords(tableMeta, checkSet);
} finally {
IOUtil.close(checkSet, statement);
checkSet = statement.executeQuery();
currentRecordsList.add(TableRecords.buildRecords(tableMeta, checkSet));
} finally {
IOUtil.close(checkSet, statement);
}
}
TableRecords currentRecords = new TableRecords(tableMeta);
for (TableRecords tableRecords : currentRecordsList) {
tableRecords.getRows().forEach(currentRecords::add);
remind marked this conversation as resolved.
Show resolved Hide resolved
}
return currentRecords;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
*/
package io.seata.rm.datasource;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringJoiner;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
Expand All @@ -27,17 +27,38 @@
*/
class SqlGenerateUtilsTest {

@Test
void testBuildWhereConditionListByPKs() {
List<String> pkNameList = new ArrayList<>();
pkNameList.add("id");
pkNameList.add("name");
List<SqlGenerateUtils.WhereSql> results1 = SqlGenerateUtils.buildWhereConditionListByPKs(pkNameList, 4, "mysql", 2);
Assertions.assertEquals(2, results1.size());
results1.forEach(result -> {
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) )", result.getSql());
Assertions.assertEquals(2, result.getRowSize());
Assertions.assertEquals(2, result.getPkSize());
});
List<SqlGenerateUtils.WhereSql> results2 = SqlGenerateUtils.buildWhereConditionListByPKs(pkNameList, 5, "mysql", 2);
Assertions.assertEquals(3, results2.size());
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) )", results2.get(0).getSql());
Assertions.assertEquals(2, results2.get(0).getRowSize());
Assertions.assertEquals(2, results2.get(0).getPkSize());
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) )", results2.get(1).getSql());
Assertions.assertEquals("(id,name) in ( (?,?) )", results2.get(2).getSql());
Assertions.assertEquals(1, results2.get(2).getRowSize());
Assertions.assertEquals(2, results2.get(2).getPkSize());
}

@Test
void testBuildWhereConditionByPKs() throws SQLException {
List<String> pkNameList=new ArrayList<>();
void testBuildSQLByPKs() {
String sqlPrefix = "select id,name from t_order where ";
List<String> pkNameList = new ArrayList<>();
pkNameList.add("id");
pkNameList.add("name");
String result = SqlGenerateUtils.buildWhereConditionByPKs(pkNameList,4,"mysql",2);
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) ) or (id,name) in ( (?,?),(?,?) )", result);
result = SqlGenerateUtils.buildWhereConditionByPKs(pkNameList,5,"mysql",2);
Assertions.assertEquals("(id,name) in ( (?,?),(?,?) ) or (id,name) in ( (?,?),(?,?) ) or (id,name) in ( (?,?)"
+ " )",
result);
List<SqlGenerateUtils.WhereSql> whereList = SqlGenerateUtils.buildWhereConditionListByPKs(pkNameList, 4, "mysql", 2);
StringJoiner sqlJoiner = new StringJoiner(" union ");
whereList.forEach(whereSql -> sqlJoiner.add(sqlPrefix + " " + whereSql.getSql()));
Assertions.assertEquals("select id,name from t_order where (id,name) in ( (?,?),(?,?) ) union select id,name from t_order where (id,name) in ( (?,?),(?,?) )", sqlJoiner.toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public void testUpdateJoinUndoLog() throws SQLException {
};
Object[][] beforeReturnValue = new Object[][]{
new Object[]{1, "Tom"},
new Object[]{2, "Tony"},
};
StatementProxy beforeMockStatementProxy = mockStatementProxy(returnValueColumnLabels, beforeReturnValue, columnMetas, indexMetas);
String sql = "update t1 inner join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL'";
Expand All @@ -68,6 +69,7 @@ public void testUpdateJoinUndoLog() throws SQLException {
TableRecords beforeImage = mySQLUpdateJoinExecutor.beforeImage();
Object[][] afterReturnValue = new Object[][]{
new Object[]{1, "WILL"},
new Object[]{2, "Tony"},
};
StatementProxy afterMockStatementProxy = mockStatementProxy(returnValueColumnLabels, afterReturnValue, columnMetas, indexMetas);
mySQLUpdateJoinExecutor.statementProxy = afterMockStatementProxy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public ResultSet executeQuery(MockStatementBase statement, String sql) throws SQ
List<Object[]> metas = new ArrayList<>();
if(asts.get(0) instanceof SQLSelectStatement) {
SQLSelectStatement ast = (SQLSelectStatement) asts.get(0);
SQLSelectQueryBlock queryBlock = ast.getSelect().getQueryBlock();
SQLSelectQueryBlock queryBlock = ast.getSelect().getFirstQueryBlock();
String tableName = "";
if (queryBlock.getFrom() instanceof SQLExprTableSource) {
MySQLSelectForUpdateRecognizer recognizer = new MySQLSelectForUpdateRecognizer(sql, ast);
Expand Down
Loading
Loading