Skip to content

Commit

Permalink
Add ConnectionContext and ShardingSphereMetaData in QueryContext, cal…
Browse files Browse the repository at this point in the history
…l setCurrentDatabaseName in ConnectionSession (#31971)

* Add ConnectionContext and ShardingSphereMetaData in QueryContext, call setCurrentDatabaseName in ConnectionSession

* fix unit test

* fix unit test
  • Loading branch information
strongduanmu authored Jul 3, 2024
1 parent a022bc9 commit faad29e
Show file tree
Hide file tree
Showing 52 changed files with 260 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.agent.plugin.metrics.core.fixture.collector.MetricsCollectorFixture;
import org.apache.shardingsphere.infra.binder.context.statement.UnknownSQLStatementContext;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sql.parser.statement.mysql.dml.MySQLDeleteStatement;
Expand Down Expand Up @@ -53,25 +54,29 @@ void reset() {

@Test
void assertInsertRoute() {
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLInsertStatement()), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLInsertStatement()), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class),
mock(ShardingSphereMetaData.class));
assertRoute(queryContext, "INSERT=1");
}

@Test
void assertUpdateRoute() {
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLUpdateStatement()), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLUpdateStatement()), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class),
mock(ShardingSphereMetaData.class));
assertRoute(queryContext, "UPDATE=1");
}

@Test
void assertDeleteRoute() {
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLDeleteStatement()), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLDeleteStatement()), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class),
mock(ShardingSphereMetaData.class));
assertRoute(queryContext, "DELETE=1");
}

@Test
void assertSelectRoute() {
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLSelectStatement()), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLSelectStatement()), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class),
mock(ShardingSphereMetaData.class));
assertRoute(queryContext, "SELECT=1");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNode;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
Expand Down Expand Up @@ -132,7 +133,8 @@ private ShardingSphereDatabase mockSingleDatabase() {
private QueryContext createQueryContext() {
CreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
return new QueryContext(new CreateTableStatementContext(createTableStatement, DefaultDatabase.LOGIC_NAME), "CREATE TABLE", new LinkedList<>(), new HintValueContext());
return new QueryContext(new CreateTableStatementContext(createTableStatement, DefaultDatabase.LOGIC_NAME), "CREATE TABLE", new LinkedList<>(), new HintValueContext(),
mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
}

private Map<String, DataSource> createMultiDataSourceMap() throws SQLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.instance.ComputeNodeInstanceContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
Expand Down Expand Up @@ -84,7 +85,8 @@ void setUp() {
@Test
void assertDecorateRouteContextToPrimaryDataSource() {
RouteContext actual = mockRouteContext();
QueryContext queryContext = new QueryContext(mock(SQLStatementContext.class), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext =
new QueryContext(mock(SQLStatementContext.class), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
Expand All @@ -100,7 +102,7 @@ void assertDecorateRouteContextToReplicaDataSource() {
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
when(selectStatement.getLock()).thenReturn(Optional.empty());
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
Expand All @@ -116,7 +118,7 @@ void assertDecorateRouteContextToPrimaryDataSourceWithLock() {
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
when(selectStatement.getLock()).thenReturn(Optional.of(mock(LockSegment.class)));
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.shardingsphere.shadow.route.engine;

import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.DeleteStatementContext;
Expand Down Expand Up @@ -46,13 +48,17 @@ class ShadowRouteEngineFactoryTest {

@Test
void assertNewInstance() {
ShadowRouteEngine shadowInsertRouteEngine = ShadowRouteEngineFactory.newInstance(new QueryContext(createInsertSqlStatementContext(), "", Collections.emptyList(), new HintValueContext()));
ShadowRouteEngine shadowInsertRouteEngine = ShadowRouteEngineFactory.newInstance(
new QueryContext(createInsertSqlStatementContext(), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)));
assertThat(shadowInsertRouteEngine, instanceOf(ShadowInsertStatementRoutingEngine.class));
ShadowRouteEngine shadowUpdateRouteEngine = ShadowRouteEngineFactory.newInstance(new QueryContext(createUpdateSqlStatementContext(), "", Collections.emptyList(), new HintValueContext()));
ShadowRouteEngine shadowUpdateRouteEngine = ShadowRouteEngineFactory.newInstance(
new QueryContext(createUpdateSqlStatementContext(), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)));
assertThat(shadowUpdateRouteEngine, instanceOf(ShadowUpdateStatementRoutingEngine.class));
ShadowRouteEngine shadowDeleteRouteEngine = ShadowRouteEngineFactory.newInstance(new QueryContext(createDeleteSqlStatementContext(), "", Collections.emptyList(), new HintValueContext()));
ShadowRouteEngine shadowDeleteRouteEngine = ShadowRouteEngineFactory.newInstance(
new QueryContext(createDeleteSqlStatementContext(), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)));
assertThat(shadowDeleteRouteEngine, instanceOf(ShadowDeleteStatementRoutingEngine.class));
ShadowRouteEngine shadowSelectRouteEngine = ShadowRouteEngineFactory.newInstance(new QueryContext(createSelectSqlStatementContext(), "", Collections.emptyList(), new HintValueContext()));
ShadowRouteEngine shadowSelectRouteEngine = ShadowRouteEngineFactory.newInstance(
new QueryContext(createSelectSqlStatementContext(), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)));
assertThat(shadowSelectRouteEngine, instanceOf(ShadowSelectStatementRoutingEngine.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sharding.api.config.strategy.audit.ShardingAuditStrategyConfiguration;
import org.apache.shardingsphere.sharding.exception.audit.DMLWithoutShardingKeyException;
Expand Down Expand Up @@ -79,15 +81,17 @@ void setUp() {
@Test
void assertCheckSuccess() {
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), globalRuleMetaData, databases.get("foo_db"), rule);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext, mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)),
globalRuleMetaData, databases.get("foo_db"), rule);
verify(rule.getAuditors().get("auditor_1")).check(sqlStatementContext, Collections.emptyList(), globalRuleMetaData, databases.get("foo_db"));
}

@Test
void assertCheckSuccessByDisableAuditNames() {
when(auditStrategy.isAllowHintDisable()).thenReturn(true);
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), globalRuleMetaData, databases.get("foo_db"), rule);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext, mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)),
globalRuleMetaData, databases.get("foo_db"), rule);
verify(rule.getAuditors().get("auditor_1"), times(0)).check(sqlStatementContext, Collections.emptyList(), globalRuleMetaData, databases.get("foo_db"));
}

Expand All @@ -97,7 +101,8 @@ void assertCheckFailed() {
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
doThrow(new DMLWithoutShardingKeyException()).when(auditAlgorithm).check(sqlStatementContext, Collections.emptyList(), globalRuleMetaData, databases.get("foo_db"));
DMLWithoutShardingKeyException ex = assertThrows(DMLWithoutShardingKeyException.class, () -> new ShardingSQLAuditor().audit(
new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), globalRuleMetaData, databases.get("foo_db"), rule));
new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext, mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)), globalRuleMetaData,
databases.get("foo_db"), rule));
assertThat(ex.getMessage(), is("Not allow DML operation without sharding conditions."));
verify(rule.getAuditors().get("auditor_1")).check(sqlStatementContext, Collections.emptyList(), globalRuleMetaData, databases.get("foo_db"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
import org.apache.shardingsphere.infra.parser.sql.SQLStatementParserEngine;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sharding.api.config.ShardingRuleConfiguration;
Expand Down Expand Up @@ -145,7 +146,7 @@ private ShardingSphereDatabase createDatabase(final ShardingRule shardingRule, f

private QueryContext createQueryContext(final ShardingSphereDatabase database, final String sql, final List<Object> params) {
SQLStatementContext sqlStatementContext = new SQLBindEngine(createShardingSphereMetaData(database), DATABASE_NAME, new HintValueContext()).bind(parse(sql), params);
return new QueryContext(sqlStatementContext, sql, params, new HintValueContext());
return new QueryContext(sqlStatementContext, sql, params, new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
}

private ShardingSphereMetaData createShardingSphereMetaData(final ShardingSphereDatabase database) {
Expand Down
Loading

0 comments on commit faad29e

Please sign in to comment.