Skip to content

Commit

Permalink
slack-19.0: optimize sqlutils.ToSqlite3Dialect, part 2 (#546)
Browse files Browse the repository at this point in the history
* Optimize `sqlutils.ToSqlite3Dialect` for inserts

Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>

* gofmt

Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>

* separate index matches

Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>

* skip dml/ddl regex on select

Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>

* switch block instead

Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>

---------

Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>
  • Loading branch information
timvaillancourt authored Oct 29, 2024
1 parent 32f02b7 commit a768745
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 43 deletions.
45 changes: 36 additions & 9 deletions go/vt/external/golib/sqlutils/sqlite_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,13 @@ var sqlite3GeneralConversions = []regexpMap{
rmap(`(?i)\bconcat[(][\s]*([^,)]+)[\s]*,[\s]*([^,)]+)[\s]*,[\s]*([^,)]+)[\s]*[)]`, `($1 || $2 || $3)`),

rmap(`(?i) rlike `, ` like `),
}

// sqlite3CreateIndexConversions holds the pattern/replacement pairs that
// ToSqlite3CreateIndex passes to applyConversions for "create index"
// statements.
var sqlite3CreateIndexConversions = []regexpMap{
// Matches (case-insensitively) a "create index" statement containing a
// parenthesized integer, such as the "(128)" in "source_host(128)", and
// rewrites it as "create index " followed by the text before and after that
// parenthesized integer, i.e. with the integer and its parentheses removed.
rmap(`(?i)create index([\s\S]+)[(][\s]*[0-9]+[\s]*[)]([\s\S]+)`, `create index ${1}${2}`),
}

// sqlite3DropIndexConversions holds the pattern/replacement pairs that
// ToSqlite3DropIndex passes to applyConversions for "drop index" statements.
var sqlite3DropIndexConversions = []regexpMap{
// Matches (case-insensitively) "drop index <name> on <table>", where both
// <name> and <table> are runs of non-whitespace characters, and rewrites it
// as "drop index if exists <name>". The captured table name is not used in
// the replacement.
rmap(`(?i)drop index ([\S]+) on ([\S]+)`, `drop index if exists $1`),
}

Expand Down Expand Up @@ -115,20 +120,42 @@ func ToSqlite3CreateTable(statement string) string {
return applyConversions(statement, sqlite3CreateTableConversions)
}

// ToSqlite3CreateIndex returns statement after applying
// sqlite3CreateIndexConversions to it. It does not apply
// sqlite3GeneralConversions.
func ToSqlite3CreateIndex(statement string) string {
return applyConversions(statement, sqlite3CreateIndexConversions)
}

// ToSqlite3DropIndex returns statement after applying
// sqlite3DropIndexConversions to it. It does not apply
// sqlite3GeneralConversions.
func ToSqlite3DropIndex(statement string) string {
return applyConversions(statement, sqlite3DropIndexConversions)
}

// ToSqlite3Insert rewrites an insert statement in two passes: the general
// conversions run first, and the insert-specific conversions are then applied
// to the output of that first pass.
func ToSqlite3Insert(statement string) string {
	generalized := applyConversions(statement, sqlite3GeneralConversions)
	return applyConversions(generalized, sqlite3InsertConversions)
}

func ToSqlite3Dialect(statement string) (translated string) {
if IsInsert(statement) {
return ToSqlite3Insert(statement)
}
if IsCreateTable(statement) {
return ToSqlite3CreateTable(statement)
}
if IsAlterTable(statement) {
return ToSqlite3CreateTable(statement)
// ToSqlite3Dialect converts a statement to sqlite3 dialect. The statement
// is checked in this order:
// 1. If a query, return the statement with sqlite3GeneralConversions applied.
// 2. If an insert/replace, convert with ToSqlite3Insert.
// 3. If a create index, convert with IsCreateIndex.
// 4. If an drop table, convert with IsDropIndex.
// 5. If a create table, convert with IsCreateTable.
// 6. If an alter table, convert with IsAlterTable.
// 7. As fallback, return the statement with sqlite3GeneralConversions applied.
func ToSqlite3Dialect(statement string, potentiallyDMLOrDDL bool) (translated string) {
if potentiallyDMLOrDDL {
switch {
case IsInsert(statement):
return ToSqlite3Insert(statement)
case IsCreateIndex(statement):
return ToSqlite3CreateIndex(statement)
case IsDropIndex(statement):
return ToSqlite3DropIndex(statement)
case IsCreateTable(statement):
return ToSqlite3CreateTable(statement)
case IsAlterTable(statement):
return ToSqlite3CreateTable(statement)
}
}
return applyConversions(statement, sqlite3GeneralConversions)
}
81 changes: 50 additions & 31 deletions go/vt/external/golib/sqlutils/sqlite_dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package sqlutils

import (
"fmt"
"regexp"
"strings"
"testing"
Expand Down Expand Up @@ -91,7 +92,7 @@ func TestToSqlite3AlterTable(t *testing.T) {
database_instance
ADD COLUMN sql_delay INT UNSIGNED NOT NULL AFTER replica_lag_seconds
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
ALTER TABLE
database_instance
Expand All @@ -104,7 +105,7 @@ func TestToSqlite3AlterTable(t *testing.T) {
database_instance
ADD INDEX source_host_port_idx (source_host, source_port)
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
create index
source_host_port_idx_database_instance
Expand All @@ -117,7 +118,7 @@ func TestToSqlite3AlterTable(t *testing.T) {
topology_recovery
ADD KEY last_detection_idx (last_detection_id)
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
create index
last_detection_idx_topology_recovery
Expand All @@ -134,7 +135,7 @@ func TestCreateIndex(t *testing.T) {
source_host_port_idx_database_instance
on database_instance (source_host(128), source_port)
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
create index
source_host_port_idx_database_instance
Expand Down Expand Up @@ -173,7 +174,7 @@ func TestToSqlite3Insert(t *testing.T) {
domain_name=values(domain_name),
last_registered=values(last_registered)
`
result := stripSpaces(ToSqlite3Dialect(statement))
result := stripSpaces(ToSqlite3Dialect(statement, true))
require.Equal(t, result, stripSpaces(`
replace into
cluster_domain_name (cluster_name, domain_name, last_registered)
Expand All @@ -186,62 +187,62 @@ func TestToSqlite3Insert(t *testing.T) {
// TestToSqlite3GeneralConversions checks the rewrites ToSqlite3Dialect applies
// on the query path (potentiallyDMLOrDDL=false): now()/interval arithmetic,
// two- and three-argument concat, and statements that must pass through
// unchanged.
func TestToSqlite3GeneralConversions(t *testing.T) {
	tests := []struct {
		statement string
		expected  string
	}{
		{
			statement: "select now()",
			expected:  "select datetime('now')",
		},
		{
			statement: "select now() - interval ? second",
			expected:  "select datetime('now', printf('-%d second', ?))",
		},
		{
			statement: "select now() + interval ? minute",
			expected:  "select datetime('now', printf('+%d minute', ?))",
		},
		{
			statement: "select now() + interval 5 minute",
			expected:  "select datetime('now', '+5 minute')",
		},
		{
			statement: "select some_table.some_column + interval ? minute",
			expected:  "select datetime(some_table.some_column, printf('+%d minute', ?))",
		},
		{
			statement: "AND primary_instance.last_attempted_check <= primary_instance.last_seen + interval ? minute",
			expected:  "AND primary_instance.last_attempted_check <= datetime(primary_instance.last_seen, printf('+%d minute', ?))",
		},
		{
			statement: "select concat(primary_instance.port, '') as port",
			expected:  "select (primary_instance.port || '') as port",
		},
		{
			statement: "select concat( 'abc' , 'def') as s",
			expected:  "select ('abc' || 'def') as s",
		},
		{
			statement: "select concat( 'abc' , 'def', last.col) as s",
			expected:  "select ('abc' || 'def' || last.col) as s",
		},
		// The remaining statements are expected to come back unchanged.
		{
			statement: "select concat(myself.only) as s",
			expected:  "select concat(myself.only) as s",
		},
		{
			statement: "select concat(1, '2', 3, '4') as s",
			expected:  "select concat(1, '2', 3, '4') as s",
		},
		{
			statement: "select group_concat( 'abc' , 'def') as s",
			expected:  "select group_concat( 'abc' , 'def') as s",
		},
	}

	for _, test := range tests {
		t.Run(test.statement, func(t *testing.T) {
			result := ToSqlite3Dialect(test.statement, false)
			// Expected value first, actual second, so a failure is labelled
			// correctly (same order as the assert.Equal call in
			// TestToSqlite3Dialect).
			require.Equal(t, test.expected, result)
		})
	}
}
Expand Down Expand Up @@ -307,27 +308,45 @@ func TestToSqlite3Dialect(t *testing.T) {

for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
result := ToSqlite3Dialect(test.input)
result := ToSqlite3Dialect(test.input, true)
assert.Equal(t, test.expected, result)
})
}
}

func BenchmarkToSqlite3Dialect_Insert(b *testing.B) {
statement := `INSERT ignore INTO database_instance
(alias, hostname, port, last_checked, last_attempted_check, last_check_partial_success, server_id, server_uuid,
version, major_version, version_comment, binlog_server, read_only, binlog_format,
binlog_row_image, log_bin, log_replica_updates, binary_log_file, binary_log_pos, source_host, source_port, replica_net_timeout, heartbeat_interval,
replica_sql_running, replica_io_running, replication_sql_thread_state, replication_io_thread_state, has_replication_filters, supports_oracle_gtid, oracle_gtid, source_uuid, ancestry_uuid, executed_gtid_set, gtid_mode, gtid_purged, gtid_errant, mariadb_gtid, pseudo_gtid,
source_log_file, read_source_log_pos, relay_source_log_file, exec_source_log_pos, relay_log_file, relay_log_pos, last_sql_error, last_io_error, replication_lag_seconds, replica_lag_seconds, sql_delay, data_center, region, physical_environment, replication_depth, is_co_primary, has_replication_credentials, allow_tls, semi_sync_enforced, semi_sync_primary_enabled, semi_sync_primary_timeout, semi_sync_primary_wait_for_replica_count, semi_sync_replica_enabled, semi_sync_primary_status, semi_sync_primary_clients, semi_sync_replica_status, last_discovery_latency, last_seen)
func buildToSqlite3Dialect_Insert(instances int) string {
var rows []string
for i := 0; i < instances; i++ {
rows = append(rows, `(?, ?, ?, NOW(), NOW(), 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NOW())`)
}

return fmt.Sprintf(`INSERT ignore INTO database_instance
(alias, hostname, port, last_checked, last_attempted_check, last_check_partial_success, server_id, server_uuid,
version, major_version, version_comment, binlog_server, read_only, binlog_format,
binlog_row_image, log_bin, log_replica_updates, binary_log_file, binary_log_pos, source_host, source_port, replica_net_timeout, heartbeat_interval,
replica_sql_running, replica_io_running, replication_sql_thread_state, replication_io_thread_state, has_replication_filters, supports_oracle_gtid, oracle_gtid, source_uuid, ancestry_uuid, executed_gtid_set, gtid_mode, gtid_purged, gtid_errant, mariadb_gtid, pseudo_gtid,
source_log_file, read_source_log_pos, relay_source_log_file, exec_source_log_pos, relay_log_file, relay_log_pos, last_sql_error, last_io_error, replication_lag_seconds, replica_lag_seconds, sql_delay, data_center, region, physical_environment, replication_depth, is_co_primary, has_replication_credentials, allow_tls, semi_sync_enforced, semi_sync_primary_enabled, semi_sync_primary_timeout, semi_sync_primary_wait_for_replica_count, semi_sync_replica_enabled, semi_sync_primary_status, semi_sync_primary_clients, semi_sync_replica_status, last_discovery_latency, last_seen)
VALUES
(?, ?, ?, NOW(), NOW(), 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NOW())
%s
ON DUPLICATE KEY UPDATE
alias=VALUES(alias), hostname=VALUES(hostname), port=VALUES(port), last_checked=VALUES(last_checked), last_attempted_check=VALUES(last_attempted_check), last_check_partial_success=VALUES(last_check_partial_success), server_id=VALUES(server_id), server_uuid=VALUES(server_uuid), version=VALUES(version), major_version=VALUES(major_version), version_comment=VALUES(version_comment), binlog_server=VALUES(binlog_server), read_only=VALUES(read_only), binlog_format=VALUES(binlog_format), binlog_row_image=VALUES(binlog_row_image), log_bin=VALUES(log_bin), log_replica_updates=VALUES(log_replica_updates), binary_log_file=VALUES(binary_log_file), binary_log_pos=VALUES(binary_log_pos), source_host=VALUES(source_host), source_port=VALUES(source_port), replica_net_timeout=VALUES(replica_net_timeout), heartbeat_interval=VALUES(heartbeat_interval), replica_sql_running=VALUES(replica_sql_running), replica_io_running=VALUES(replica_io_running), replication_sql_thread_state=VALUES(replication_sql_thread_state), replication_io_thread_state=VALUES(replication_io_thread_state), has_replication_filters=VALUES(has_replication_filters), supports_oracle_gtid=VALUES(supports_oracle_gtid), oracle_gtid=VALUES(oracle_gtid), source_uuid=VALUES(source_uuid), ancestry_uuid=VALUES(ancestry_uuid), executed_gtid_set=VALUES(executed_gtid_set), gtid_mode=VALUES(gtid_mode), gtid_purged=VALUES(gtid_purged), gtid_errant=VALUES(gtid_errant), mariadb_gtid=VALUES(mariadb_gtid), pseudo_gtid=VALUES(pseudo_gtid), source_log_file=VALUES(source_log_file), read_source_log_pos=VALUES(read_source_log_pos), relay_source_log_file=VALUES(relay_source_log_file), exec_source_log_pos=VALUES(exec_source_log_pos), relay_log_file=VALUES(relay_log_file), relay_log_pos=VALUES(relay_log_pos), last_sql_error=VALUES(last_sql_error), last_io_error=VALUES(last_io_error), replication_lag_seconds=VALUES(replication_lag_seconds), replica_lag_seconds=VALUES(replica_lag_seconds), sql_delay=VALUES(sql_delay), 
data_center=VALUES(data_center), region=VALUES(region), physical_environment=VALUES(physical_environment), replication_depth=VALUES(replication_depth), is_co_primary=VALUES(is_co_primary), has_replication_credentials=VALUES(has_replication_credentials), allow_tls=VALUES(allow_tls),
semi_sync_enforced=VALUES(semi_sync_enforced), semi_sync_primary_enabled=VALUES(semi_sync_primary_enabled), semi_sync_primary_timeout=VALUES(semi_sync_primary_timeout), semi_sync_primary_wait_for_replica_count=VALUES(semi_sync_primary_wait_for_replica_count), semi_sync_replica_enabled=VALUES(semi_sync_replica_enabled), semi_sync_primary_status=VALUES(semi_sync_primary_status), semi_sync_primary_clients=VALUES(semi_sync_primary_clients), semi_sync_replica_status=VALUES(semi_sync_replica_status),
last_discovery_latency=VALUES(last_discovery_latency), last_seen=VALUES(last_seen)
`
alias=VALUES(alias), hostname=VALUES(hostname), port=VALUES(port), last_checked=VALUES(last_checked), last_attempted_check=VALUES(last_attempted_check), last_check_partial_success=VALUES(last_check_partial_success), server_id=VALUES(server_id), server_uuid=VALUES(server_uuid), version=VALUES(version), major_version=VALUES(major_version), version_comment=VALUES(version_comment), binlog_server=VALUES(binlog_server), read_only=VALUES(read_only), binlog_format=VALUES(binlog_format), binlog_row_image=VALUES(binlog_row_image), log_bin=VALUES(log_bin), log_replica_updates=VALUES(log_replica_updates), binary_log_file=VALUES(binary_log_file), binary_log_pos=VALUES(binary_log_pos), source_host=VALUES(source_host), source_port=VALUES(source_port), replica_net_timeout=VALUES(replica_net_timeout), heartbeat_interval=VALUES(heartbeat_interval), replica_sql_running=VALUES(replica_sql_running), replica_io_running=VALUES(replica_io_running), replication_sql_thread_state=VALUES(replication_sql_thread_state), replication_io_thread_state=VALUES(replication_io_thread_state), has_replication_filters=VALUES(has_replication_filters), supports_oracle_gtid=VALUES(supports_oracle_gtid), oracle_gtid=VALUES(oracle_gtid), source_uuid=VALUES(source_uuid), ancestry_uuid=VALUES(ancestry_uuid), executed_gtid_set=VALUES(executed_gtid_set), gtid_mode=VALUES(gtid_mode), gtid_purged=VALUES(gtid_purged), gtid_errant=VALUES(gtid_errant), mariadb_gtid=VALUES(mariadb_gtid), pseudo_gtid=VALUES(pseudo_gtid), source_log_file=VALUES(source_log_file), read_source_log_pos=VALUES(read_source_log_pos), relay_source_log_file=VALUES(relay_source_log_file), exec_source_log_pos=VALUES(exec_source_log_pos), relay_log_file=VALUES(relay_log_file), relay_log_pos=VALUES(relay_log_pos), last_sql_error=VALUES(last_sql_error), last_io_error=VALUES(last_io_error), replication_lag_seconds=VALUES(replication_lag_seconds), replica_lag_seconds=VALUES(replica_lag_seconds), sql_delay=VALUES(sql_delay), 
data_center=VALUES(data_center), region=VALUES(region), physical_environment=VALUES(physical_environment), replication_depth=VALUES(replication_depth), is_co_primary=VALUES(is_co_primary), has_replication_credentials=VALUES(has_replication_credentials), allow_tls=VALUES(allow_tls),
semi_sync_enforced=VALUES(semi_sync_enforced), semi_sync_primary_enabled=VALUES(semi_sync_primary_enabled), semi_sync_primary_timeout=VALUES(semi_sync_primary_timeout), semi_sync_primary_wait_for_replica_count=VALUES(semi_sync_primary_wait_for_replica_count), semi_sync_replica_enabled=VALUES(semi_sync_replica_enabled), semi_sync_primary_status=VALUES(semi_sync_primary_status), semi_sync_primary_clients=VALUES(semi_sync_primary_clients), semi_sync_replica_status=VALUES(semi_sync_replica_status),
last_discovery_latency=VALUES(last_discovery_latency), last_seen=VALUES(last_seen)
`, strings.Join(rows, "\n\t\t\t\t"))
}

func BenchmarkToSqlite3Dialect_Insert1000(b *testing.B) {
for i := 0; i < b.N; i++ {
ToSqlite3Dialect(statement)
b.StopTimer()
statement := buildToSqlite3Dialect_Insert(1000)
b.StartTimer()
ToSqlite3Dialect(statement, true)
}
}

// BenchmarkToSqlite3Dialect_Select measures ToSqlite3Dialect on a short select
// statement with the DML/DDL checks disabled (potentiallyDMLOrDDL=false).
func BenchmarkToSqlite3Dialect_Select(b *testing.B) {
	const query = "select now() - interval ? second"
	for remaining := b.N; remaining > 0; remaining-- {
		ToSqlite3Dialect(query, false)
	}
}
10 changes: 7 additions & 3 deletions go/vt/vtorc/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ func OpenVTOrc() (db *sql.DB, err error) {
}

// translateStatement converts a statement that may be DML or DDL to the
// sqlite3 dialect. It passes potentiallyDMLOrDDL=true so that
// sqlutils.ToSqlite3Dialect runs its statement-type checks before falling
// back to the general conversions.
func translateStatement(statement string) string {
	return sqlutils.ToSqlite3Dialect(statement, true /* potentiallyDMLOrDDL */)
}

// translateQueryStatement converts a query to the sqlite3 dialect. It passes
// potentiallyDMLOrDDL=false to sqlutils.ToSqlite3Dialect; it is the
// translation used by QueryVTOrcRowsMap and QueryVTOrc.
func translateQueryStatement(statement string) string {
return sqlutils.ToSqlite3Dialect(statement, false /* potentiallyDMLOrDDL */)
}

// registerVTOrcDeployment updates the vtorc_db_deployments table upon successful deployment
Expand Down Expand Up @@ -162,7 +166,7 @@ func ExecVTOrc(query string, args ...any) (sql.Result, error) {

// QueryVTOrcRowsMap
func QueryVTOrcRowsMap(query string, onRow func(sqlutils.RowMap) error) error {
query = translateStatement(query)
query = translateQueryStatement(query)
db, err := OpenVTOrc()
if err != nil {
return err
Expand All @@ -173,7 +177,7 @@ func QueryVTOrcRowsMap(query string, onRow func(sqlutils.RowMap) error) error {

// QueryVTOrc
func QueryVTOrc(query string, argsArray []any, onRow func(sqlutils.RowMap) error) error {
query = translateStatement(query)
query = translateQueryStatement(query)
db, err := OpenVTOrc()
if err != nil {
return err
Expand Down

0 comments on commit a768745

Please sign in to comment.