Skip to content

Commit

Permalink
skip dml/ddl regex on select
Browse files Browse the repository at this point in the history
Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>
  • Loading branch information
timvaillancourt committed Oct 29, 2024
1 parent d7e7280 commit 157326f
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 41 deletions.
43 changes: 24 additions & 19 deletions go/vt/external/golib/sqlutils/sqlite_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,30 @@ func ToSqlite3Insert(statement string) string {

// ToSqlite3Dialect converts a statement to sqlite3 dialect. The statement
// is checked in this order:
// 1. If an insert/replace, convert with ToSqlite3Insert.
// 2. If a create table, convert with IsCreateTable.
// 3. If an alter table, convert with IsAlterTable.
// 4. As fallback, return the statement with sqlite3GeneralConversions applied.
func ToSqlite3Dialect(statement string) (translated string) {
if IsInsert(statement) {
return ToSqlite3Insert(statement)
}
if IsCreateIndex(statement) {
return ToSqlite3CreateIndex(statement)
}
if IsDropIndex(statement) {
return ToSqlite3DropIndex(statement)
}
if IsCreateTable(statement) {
return ToSqlite3CreateTable(statement)
}
if IsAlterTable(statement) {
return ToSqlite3CreateTable(statement)
// 1. If a query, return the statement with sqlite3GeneralConversions applied.
// 2. If an insert/replace, convert with ToSqlite3Insert.
// 3. If a create index, convert with IsCreateIndex.
// 4. If an drop table, convert with IsDropIndex.
// 5. If a create table, convert with IsCreateTable.
// 6. If an alter table, convert with IsAlterTable.
// 7. As fallback, return the statement with sqlite3GeneralConversions applied.
func ToSqlite3Dialect(statement string, potentiallyDMLOrDDL bool) (translated string) {
if potentiallyDMLOrDDL {
if IsInsert(statement) {
return ToSqlite3Insert(statement)
}
if IsCreateIndex(statement) {
return ToSqlite3CreateIndex(statement)
}
if IsDropIndex(statement) {
return ToSqlite3DropIndex(statement)
}
if IsCreateTable(statement) {
return ToSqlite3CreateTable(statement)
}
if IsAlterTable(statement) {
return ToSqlite3CreateTable(statement)
}
}
return applyConversions(statement, sqlite3GeneralConversions)
}
45 changes: 26 additions & 19 deletions go/vt/external/golib/sqlutils/sqlite_dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestToSqlite3AlterTable(t *testing.T) {
database_instance
ADD COLUMN sql_delay INT UNSIGNED NOT NULL AFTER replica_lag_seconds
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
ALTER TABLE
database_instance
Expand All @@ -105,7 +105,7 @@ func TestToSqlite3AlterTable(t *testing.T) {
database_instance
ADD INDEX source_host_port_idx (source_host, source_port)
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
create index
source_host_port_idx_database_instance
Expand All @@ -118,7 +118,7 @@ func TestToSqlite3AlterTable(t *testing.T) {
topology_recovery
ADD KEY last_detection_idx (last_detection_id)
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
create index
last_detection_idx_topology_recovery
Expand All @@ -135,7 +135,7 @@ func TestCreateIndex(t *testing.T) {
source_host_port_idx_database_instance
on database_instance (source_host(128), source_port)
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
create index
source_host_port_idx_database_instance
Expand Down Expand Up @@ -174,7 +174,7 @@ func TestToSqlite3Insert(t *testing.T) {
domain_name=values(domain_name),
last_registered=values(last_registered)
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
replace into
cluster_domain_name (cluster_name, domain_name, last_registered)
Expand All @@ -187,62 +187,62 @@ func TestToSqlite3Insert(t *testing.T) {
func TestToSqlite3GeneralConversions(t *testing.T) {
{
statement := "select now()"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select datetime('now')")
}
{
statement := "select now() - interval ? second"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select datetime('now', printf('-%d second', ?))")
}
{
statement := "select now() + interval ? minute"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select datetime('now', printf('+%d minute', ?))")
}
{
statement := "select now() + interval 5 minute"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select datetime('now', '+5 minute')")
}
{
statement := "select some_table.some_column + interval ? minute"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select datetime(some_table.some_column, printf('+%d minute', ?))")
}
{
statement := "AND primary_instance.last_attempted_check <= primary_instance.last_seen + interval ? minute"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "AND primary_instance.last_attempted_check <= datetime(primary_instance.last_seen, printf('+%d minute', ?))")
}
{
statement := "select concat(primary_instance.port, '') as port"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select (primary_instance.port || '') as port")
}
{
statement := "select concat( 'abc' , 'def') as s"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select ('abc' || 'def') as s")
}
{
statement := "select concat( 'abc' , 'def', last.col) as s"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select ('abc' || 'def' || last.col) as s")
}
{
statement := "select concat(myself.only) as s"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select concat(myself.only) as s")
}
{
statement := "select concat(1, '2', 3, '4') as s"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select concat(1, '2', 3, '4') as s")
}
{
statement := "select group_concat( 'abc' , 'def') as s"
result := ToSqlite3Dialect(statement)
result := ToSqlite3Dialect(statement, false)
require.Equal(t, result, "select group_concat( 'abc' , 'def') as s")
}
}
Expand Down Expand Up @@ -308,7 +308,7 @@ func TestToSqlite3Dialect(t *testing.T) {

for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
result := ToSqlite3Dialect(test.input)
result := ToSqlite3Dialect(test.input, true)
assert.Equal(t, test.expected, result)
})
}
Expand Down Expand Up @@ -340,6 +340,13 @@ func BenchmarkToSqlite3Dialect_Insert1000(b *testing.B) {
b.StopTimer()
statement := buildToSqlite3Dialect_Insert(1000)
b.StartTimer()
ToSqlite3Dialect(statement)
ToSqlite3Dialect(statement, true)
}
}

func BenchmarkToSqlite3Dialect_Select(b *testing.B) {
for i := 0; i < b.N; i++ {
statement := "select now() - interval ? second"
ToSqlite3Dialect(statement, false)
}
}
10 changes: 7 additions & 3 deletions go/vt/vtorc/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ func OpenVTOrc() (db *sql.DB, err error) {
}

func translateStatement(statement string) string {
return sqlutils.ToSqlite3Dialect(statement)
return sqlutils.ToSqlite3Dialect(statement, true /* potentiallyDMLOrDDL */)
}

func translateQueryStatement(statement string) string {
return sqlutils.ToSqlite3Dialect(statement, false /* potentiallyDMLOrDDL */)
}

// registerVTOrcDeployment updates the vtorc_db_deployments table upon successful deployment
Expand Down Expand Up @@ -162,7 +166,7 @@ func ExecVTOrc(query string, args ...any) (sql.Result, error) {

// QueryVTOrcRowsMap
func QueryVTOrcRowsMap(query string, onRow func(sqlutils.RowMap) error) error {
query = translateStatement(query)
query = translateQueryStatement(query)
db, err := OpenVTOrc()
if err != nil {
return err
Expand All @@ -173,7 +177,7 @@ func QueryVTOrcRowsMap(query string, onRow func(sqlutils.RowMap) error) error {

// QueryVTOrc
func QueryVTOrc(query string, argsArray []any, onRow func(sqlutils.RowMap) error) error {
query = translateStatement(query)
query = translateQueryStatement(query)
db, err := OpenVTOrc()
if err != nil {
return err
Expand Down

0 comments on commit 157326f

Please sign in to comment.