diff --git a/go/vt/external/golib/sqlutils/sqlite_dialect.go b/go/vt/external/golib/sqlutils/sqlite_dialect.go index 55e61800b40..9b0888bdda1 100644 --- a/go/vt/external/golib/sqlutils/sqlite_dialect.go +++ b/go/vt/external/golib/sqlutils/sqlite_dialect.go @@ -135,25 +135,30 @@ func ToSqlite3Insert(statement string) string { // ToSqlite3Dialect converts a statement to sqlite3 dialect. The statement // is checked in this order: -// 1. If an insert/replace, convert with ToSqlite3Insert. -// 2. If a create table, convert with IsCreateTable. -// 3. If an alter table, convert with IsAlterTable. -// 4. As fallback, return the statement with sqlite3GeneralConversions applied. -func ToSqlite3Dialect(statement string) (translated string) { - if IsInsert(statement) { - return ToSqlite3Insert(statement) - } - if IsCreateIndex(statement) { - return ToSqlite3CreateIndex(statement) - } - if IsDropIndex(statement) { - return ToSqlite3DropIndex(statement) - } - if IsCreateTable(statement) { - return ToSqlite3CreateTable(statement) - } - if IsAlterTable(statement) { - return ToSqlite3CreateTable(statement) +// 1. If a query, return the statement with sqlite3GeneralConversions applied. +// 2. If an insert/replace, convert with ToSqlite3Insert. +// 3. If a create index, convert with IsCreateIndex. +// 4. If an drop table, convert with IsDropIndex. +// 5. If a create table, convert with IsCreateTable. +// 6. If an alter table, convert with IsAlterTable. +// 7. As fallback, return the statement with sqlite3GeneralConversions applied. +func ToSqlite3Dialect(statement string, potentiallyDMLOrDDL bool) (translated string) { + if potentiallyDMLOrDDL { + if IsInsert(statement) { + return ToSqlite3Insert(statement) + } + if IsCreateIndex(statement) { + return ToSqlite3CreateIndex(statement) + } + if IsDropIndex(statement) { + return ToSqlite3DropIndex(statement) + } + if IsCreateTable(statement) { + return ToSqlite3CreateTable(statement) + } + if IsAlterTable(statement) { + return ToSqlite3CreateTable(statement) + } } return applyConversions(statement, sqlite3GeneralConversions) } diff --git a/go/vt/external/golib/sqlutils/sqlite_dialect_test.go b/go/vt/external/golib/sqlutils/sqlite_dialect_test.go index 4a8512d35a1..cd6d1477d04 100644 --- a/go/vt/external/golib/sqlutils/sqlite_dialect_test.go +++ b/go/vt/external/golib/sqlutils/sqlite_dialect_test.go @@ -92,7 +92,7 @@ func TestToSqlite3AlterTable(t *testing.T) { database_instance ADD COLUMN sql_delay INT UNSIGNED NOT NULL AFTER replica_lag_seconds ` - result := stripSpaces(ToSqlite3Dialect(statement)) + result := stripSpaces(ToSqlite3Dialect(statement, true)) require.Equal(t, result, stripSpaces(` ALTER TABLE database_instance @@ -105,7 +105,7 @@ func TestToSqlite3AlterTable(t *testing.T) { database_instance ADD INDEX source_host_port_idx (source_host, source_port) ` - result := stripSpaces(ToSqlite3Dialect(statement)) + result := stripSpaces(ToSqlite3Dialect(statement, true)) require.Equal(t, result, stripSpaces(` create index source_host_port_idx_database_instance @@ -118,7 +118,7 @@ func TestToSqlite3AlterTable(t *testing.T) { topology_recovery ADD KEY last_detection_idx (last_detection_id) ` - result := stripSpaces(ToSqlite3Dialect(statement)) + result := stripSpaces(ToSqlite3Dialect(statement, true)) require.Equal(t, result, stripSpaces(` create index last_detection_idx_topology_recovery @@ -135,7 +135,7 @@ func TestCreateIndex(t *testing.T) { source_host_port_idx_database_instance on database_instance (source_host(128), source_port) ` - result := stripSpaces(ToSqlite3Dialect(statement)) + result := stripSpaces(ToSqlite3Dialect(statement, true)) require.Equal(t, result, stripSpaces(` create index source_host_port_idx_database_instance @@ -174,7 +174,7 @@ func TestToSqlite3Insert(t *testing.T) { domain_name=values(domain_name), last_registered=values(last_registered) ` - result := stripSpaces(ToSqlite3Dialect(statement)) + result := stripSpaces(ToSqlite3Dialect(statement, true)) require.Equal(t, result, stripSpaces(` replace into cluster_domain_name (cluster_name, domain_name, last_registered) @@ -187,62 +187,62 @@ func TestToSqlite3Insert(t *testing.T) { func TestToSqlite3GeneralConversions(t *testing.T) { { statement := "select now()" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select datetime('now')") } { statement := "select now() - interval ? second" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select datetime('now', printf('-%d second', ?))") } { statement := "select now() + interval ? minute" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select datetime('now', printf('+%d minute', ?))") } { statement := "select now() + interval 5 minute" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select datetime('now', '+5 minute')") } { statement := "select some_table.some_column + interval ? minute" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select datetime(some_table.some_column, printf('+%d minute', ?))") } { statement := "AND primary_instance.last_attempted_check <= primary_instance.last_seen + interval ? minute" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "AND primary_instance.last_attempted_check <= datetime(primary_instance.last_seen, printf('+%d minute', ?))") } { statement := "select concat(primary_instance.port, '') as port" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select (primary_instance.port || '') as port") } { statement := "select concat( 'abc' , 'def') as s" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select ('abc' || 'def') as s") } { statement := "select concat( 'abc' , 'def', last.col) as s" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select ('abc' || 'def' || last.col) as s") } { statement := "select concat(myself.only) as s" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select concat(myself.only) as s") } { statement := "select concat(1, '2', 3, '4') as s" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select concat(1, '2', 3, '4') as s") } { statement := "select group_concat( 'abc' , 'def') as s" - result := ToSqlite3Dialect(statement) + result := ToSqlite3Dialect(statement, false) require.Equal(t, result, "select group_concat( 'abc' , 'def') as s") } } @@ -308,7 +308,7 @@ func TestToSqlite3Dialect(t *testing.T) { for _, test := range tests { t.Run(test.input, func(t *testing.T) { - result := ToSqlite3Dialect(test.input) + result := ToSqlite3Dialect(test.input, true) assert.Equal(t, test.expected, result) }) } @@ -340,6 +340,13 @@ func BenchmarkToSqlite3Dialect_Insert1000(b *testing.B) { b.StopTimer() statement := buildToSqlite3Dialect_Insert(1000) b.StartTimer() - ToSqlite3Dialect(statement) + ToSqlite3Dialect(statement, true) + } +} + +func BenchmarkToSqlite3Dialect_Select(b *testing.B) { + for i := 0; i < b.N; i++ { + statement := "select now() - interval ? second" + ToSqlite3Dialect(statement, false) } } diff --git a/go/vt/vtorc/db/db.go b/go/vt/vtorc/db/db.go index 92657eddc3f..097d3732797 100644 --- a/go/vt/vtorc/db/db.go +++ b/go/vt/vtorc/db/db.go @@ -69,7 +69,11 @@ func OpenVTOrc() (db *sql.DB, err error) { } func translateStatement(statement string) string { - return sqlutils.ToSqlite3Dialect(statement) + return sqlutils.ToSqlite3Dialect(statement, true /* potentiallyDMLOrDDL */) +} + +func translateQueryStatement(statement string) string { + return sqlutils.ToSqlite3Dialect(statement, false /* potentiallyDMLOrDDL */) } // registerVTOrcDeployment updates the vtorc_db_deployments table upon successful deployment @@ -162,7 +166,7 @@ func ExecVTOrc(query string, args ...any) (sql.Result, error) { // QueryVTOrcRowsMap func QueryVTOrcRowsMap(query string, onRow func(sqlutils.RowMap) error) error { - query = translateStatement(query) + query = translateQueryStatement(query) db, err := OpenVTOrc() if err != nil { return err @@ -173,7 +177,7 @@ func QueryVTOrcRowsMap(query string, onRow func(sqlutils.RowMap) error) error { // QueryVTOrc func QueryVTOrc(query string, argsArray []any, onRow func(sqlutils.RowMap) error) error { - query = translateStatement(query) + query = translateQueryStatement(query) db, err := OpenVTOrc() if err != nil { return err