From f7c229c675377435e17e4861b3e72b4e77819e15 Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 09:05:17 -0600 Subject: [PATCH 1/5] [release-18.0] Flaky test fixes (#16940) (#16957) Signed-off-by: Manan Gupta Co-authored-by: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com> --- go/mysql/server_flaky_test.go | 20 +++++- go/vt/mysqlctl/fakemysqldaemon.go | 25 ++++--- go/vt/vtctl/grpcvtctldserver/server_test.go | 2 +- .../vttablet/tabletmanager/framework_test.go | 4 +- go/vt/wrangler/testlib/backup_test.go | 66 +++++++++---------- .../testlib/emergency_reparent_shard_test.go | 26 ++++---- .../testlib/planned_reparent_shard_test.go | 26 ++++---- go/vt/wrangler/testlib/reparent_utils_test.go | 8 +-- go/vt/wrangler/traffic_switcher_env_test.go | 8 +-- 9 files changed, 106 insertions(+), 79 deletions(-) diff --git a/go/mysql/server_flaky_test.go b/go/mysql/server_flaky_test.go index 509fccaa47a..b71b43a52c1 100644 --- a/go/mysql/server_flaky_test.go +++ b/go/mysql/server_flaky_test.go @@ -1429,7 +1429,7 @@ func TestListenerShutdown(t *testing.T) { l.Shutdown() - assert.EqualValues(t, 1, connRefuse.Get(), "connRefuse") + waitForConnRefuse(t, 1) err = conn.Ping() require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)") @@ -1441,6 +1441,24 @@ func TestListenerShutdown(t *testing.T) { require.Equal(t, "Server shutdown in progress", sqlErr.Message) } +func waitForConnRefuse(t *testing.T, valWanted int64) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + tick := time.NewTicker(100 * time.Millisecond) + defer tick.Stop() + + for { + select { + case <-ctx.Done(): + require.FailNow(t, "connRefuse did not reach %v", valWanted) + case <-tick.C: + if connRefuse.Get() == valWanted { + return + } + } + } +} + func TestParseConnAttrs(t *testing.T) { expected := map[string]string{ "_client_version": "8.0.11", diff --git a/go/vt/mysqlctl/fakemysqldaemon.go b/go/vt/mysqlctl/fakemysqldaemon.go index 39ecca84156..91400d55e77 100644 --- a/go/vt/mysqlctl/fakemysqldaemon.go +++ b/go/vt/mysqlctl/fakemysqldaemon.go @@ -277,13 +277,6 @@ func (fmd *FakeMysqlDaemon) GetServerUUID(ctx context.Context) (string, error) { return "000000", nil } -// CurrentPrimaryPositionLocked is thread-safe -func (fmd *FakeMysqlDaemon) CurrentPrimaryPositionLocked(pos replication.Position) { - fmd.mu.Lock() - defer fmd.mu.Unlock() - fmd.CurrentPrimaryPosition = pos -} - // ReplicationStatus is part of the MysqlDaemon interface func (fmd *FakeMysqlDaemon) ReplicationStatus() (replication.ReplicationStatus, error) { if fmd.ReplicationStatusError != nil { @@ -307,6 +300,8 @@ func (fmd *FakeMysqlDaemon) ReplicationStatus() (replication.ReplicationStatus, // PrimaryStatus is part of the MysqlDaemon interface func (fmd *FakeMysqlDaemon) PrimaryStatus(ctx context.Context) (replication.PrimaryStatus, error) { + fmd.mu.Lock() + defer fmd.mu.Unlock() if fmd.PrimaryStatusError != nil { return replication.PrimaryStatus{}, fmd.PrimaryStatusError } @@ -372,7 +367,21 @@ func (fmd *FakeMysqlDaemon) GetPreviousGTIDs(ctx context.Context, binlog string) // PrimaryPosition is part of the MysqlDaemon interface func (fmd *FakeMysqlDaemon) PrimaryPosition() (replication.Position, error) { - return fmd.CurrentPrimaryPosition, nil + return fmd.GetPrimaryPositionLocked(), nil +} + +// GetPrimaryPositionLocked gets the primary position while holding the lock. +func (fmd *FakeMysqlDaemon) GetPrimaryPositionLocked() replication.Position { + fmd.mu.Lock() + defer fmd.mu.Unlock() + return fmd.CurrentPrimaryPosition +} + +// SetPrimaryPositionLocked is thread-safe. +func (fmd *FakeMysqlDaemon) SetPrimaryPositionLocked(pos replication.Position) { + fmd.mu.Lock() + defer fmd.mu.Unlock() + fmd.CurrentPrimaryPosition = pos } // IsReadOnly is part of the MysqlDaemon interface diff --git a/go/vt/vtctl/grpcvtctldserver/server_test.go b/go/vt/vtctl/grpcvtctldserver/server_test.go index 9410f769c85..46e39053648 100644 --- a/go/vt/vtctl/grpcvtctldserver/server_test.go +++ b/go/vt/vtctl/grpcvtctldserver/server_test.go @@ -12030,7 +12030,7 @@ func TestTabletExternallyReparented(t *testing.T) { defer func() { topofactory.SetError(nil) - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*10) + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() resp, err := vtctld.GetTablets(ctx, &vtctldatapb.GetTabletsRequest{}) diff --git a/go/vt/vttablet/tabletmanager/framework_test.go b/go/vt/vttablet/tabletmanager/framework_test.go index 4734ab9ee96..24935912b42 100644 --- a/go/vt/vttablet/tabletmanager/framework_test.go +++ b/go/vt/vttablet/tabletmanager/framework_test.go @@ -111,8 +111,8 @@ func newTestEnv(t *testing.T, ctx context.Context, sourceKeyspace string, source tmclienttest.SetProtocol(fmt.Sprintf("go.vt.vttablet.tabletmanager.framework_test_%s", t.Name()), tenv.protoName) tenv.mysqld = mysqlctl.NewFakeMysqlDaemon(fakesqldb.New(t)) - var err error - tenv.mysqld.CurrentPrimaryPosition, err = replication.ParsePosition(gtidFlavor, gtidPosition) + curPosition, err := replication.ParsePosition(gtidFlavor, gtidPosition) + tenv.mysqld.SetPrimaryPositionLocked(curPosition) require.NoError(t, err) return tenv diff --git a/go/vt/wrangler/testlib/backup_test.go b/go/vt/wrangler/testlib/backup_test.go index 787e4ce1946..751da7fe936 100644 --- a/go/vt/wrangler/testlib/backup_test.go +++ b/go/vt/wrangler/testlib/backup_test.go @@ -148,7 +148,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { primary := NewFakeTablet(t, wr, "cell1", 0, topodatapb.TabletType_PRIMARY, db) primary.FakeMysqlDaemon.ReadOnly = false primary.FakeMysqlDaemon.Replicating = false - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -156,7 +156,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { Sequence: 457, }, }, - } + }) // start primary so that replica can fetch primary position from it primary.StartActionLoop(t, wr) @@ -168,7 +168,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { sourceTablet.FakeMysqlDaemon.ReadOnly = true sourceTablet.FakeMysqlDaemon.Replicating = true sourceTablet.FakeMysqlDaemon.SetReplicationSourceInputs = []string{fmt.Sprintf("%s:%d", primary.Tablet.MysqlHostname, primary.Tablet.MysqlPort)} - sourceTablet.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + sourceTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -176,7 +176,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { Sequence: 457, }, }, - } + }) sourceTablet.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ // These 3 statements come from tablet startup "STOP SLAVE", @@ -221,7 +221,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { destTablet := NewFakeTablet(t, wr, "cell1", 2, topodatapb.TabletType_REPLICA, db) destTablet.FakeMysqlDaemon.ReadOnly = true destTablet.FakeMysqlDaemon.Replicating = true - destTablet.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + destTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -229,7 +229,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { Sequence: 457, }, }, - } + }) destTablet.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ // These 3 statements come from tablet startup "STOP SLAVE", @@ -247,7 +247,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { "RESET MASTER": {}, "SET GLOBAL gtid_purged": {}, } - destTablet.FakeMysqlDaemon.SetReplicationPositionPos = sourceTablet.FakeMysqlDaemon.CurrentPrimaryPosition + destTablet.FakeMysqlDaemon.SetReplicationPositionPos = sourceTablet.FakeMysqlDaemon.GetPrimaryPositionLocked() destTablet.FakeMysqlDaemon.SetReplicationSourceInputs = append(destTablet.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(primary.Tablet)) destTablet.StartActionLoop(t, wr) @@ -299,7 +299,7 @@ func testBackupRestore(t *testing.T, cDetails *compressionDetails) error { "START SLAVE", } - primary.FakeMysqlDaemon.SetReplicationPositionPos = primary.FakeMysqlDaemon.CurrentPrimaryPosition + primary.FakeMysqlDaemon.SetReplicationPositionPos = primary.FakeMysqlDaemon.GetPrimaryPositionLocked() // restore primary from latest backup require.NoError(t, primary.TM.RestoreData(ctx, logutil.NewConsoleLogger(), 0 /* waitForBackupInterval */, false /* deleteBeforeRestore */, time.Time{} /* restoreFromBackupTs */, time.Time{} /* restoreToTimestamp */, ""), @@ -386,7 +386,7 @@ func TestBackupRestoreLagged(t *testing.T) { primary := NewFakeTablet(t, wr, "cell1", 0, topodatapb.TabletType_PRIMARY, db) primary.FakeMysqlDaemon.ReadOnly = false primary.FakeMysqlDaemon.Replicating = false - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -394,7 +394,7 @@ func TestBackupRestoreLagged(t *testing.T) { Sequence: 457, }, }, - } + }) // start primary so that replica can fetch primary position from it primary.StartActionLoop(t, wr) @@ -405,7 +405,7 @@ func TestBackupRestoreLagged(t *testing.T) { sourceTablet := NewFakeTablet(t, wr, "cell1", 1, topodatapb.TabletType_REPLICA, db) sourceTablet.FakeMysqlDaemon.ReadOnly = true sourceTablet.FakeMysqlDaemon.Replicating = true - sourceTablet.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + sourceTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -413,7 +413,7 @@ func TestBackupRestoreLagged(t *testing.T) { Sequence: 456, }, }, - } + }) sourceTablet.FakeMysqlDaemon.SetReplicationSourceInputs = []string{fmt.Sprintf("%s:%d", primary.Tablet.MysqlHostname, primary.Tablet.MysqlPort)} sourceTablet.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ // These 3 statements come from tablet startup @@ -447,7 +447,7 @@ func TestBackupRestoreLagged(t *testing.T) { timer := time.NewTicker(1 * time.Second) <-timer.C - sourceTablet.FakeMysqlDaemon.CurrentPrimaryPositionLocked(replication.Position{ + sourceTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -466,7 +466,7 @@ func TestBackupRestoreLagged(t *testing.T) { require.NoError(t, sourceTablet.FakeMysqlDaemon.CheckSuperQueryList()) assert.True(t, sourceTablet.FakeMysqlDaemon.Replicating) assert.True(t, sourceTablet.FakeMysqlDaemon.Running) - assert.Equal(t, primary.FakeMysqlDaemon.CurrentPrimaryPosition, sourceTablet.FakeMysqlDaemon.CurrentPrimaryPosition) + assert.Equal(t, primary.FakeMysqlDaemon.GetPrimaryPositionLocked(), sourceTablet.FakeMysqlDaemon.GetPrimaryPositionLocked()) case <-timer2.C: require.FailNow(t, "Backup timed out") } @@ -475,7 +475,7 @@ func TestBackupRestoreLagged(t *testing.T) { destTablet := NewFakeTablet(t, wr, "cell1", 2, topodatapb.TabletType_REPLICA, db) destTablet.FakeMysqlDaemon.ReadOnly = true destTablet.FakeMysqlDaemon.Replicating = true - destTablet.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + destTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -483,7 +483,7 @@ func TestBackupRestoreLagged(t *testing.T) { Sequence: 456, }, }, - } + }) destTablet.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ // These 3 statements come from tablet startup "STOP SLAVE", @@ -501,7 +501,7 @@ func TestBackupRestoreLagged(t *testing.T) { "RESET MASTER": {}, "SET GLOBAL gtid_purged": {}, } - destTablet.FakeMysqlDaemon.SetReplicationPositionPos = destTablet.FakeMysqlDaemon.CurrentPrimaryPosition + destTablet.FakeMysqlDaemon.SetReplicationPositionPos = destTablet.FakeMysqlDaemon.GetPrimaryPositionLocked() destTablet.FakeMysqlDaemon.SetReplicationSourceInputs = append(destTablet.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(primary.Tablet)) destTablet.StartActionLoop(t, wr) @@ -524,7 +524,7 @@ func TestBackupRestoreLagged(t *testing.T) { timer = time.NewTicker(1 * time.Second) <-timer.C - destTablet.FakeMysqlDaemon.CurrentPrimaryPositionLocked(replication.Position{ + destTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -542,7 +542,7 @@ func TestBackupRestoreLagged(t *testing.T) { require.NoError(t, destTablet.FakeMysqlDaemon.CheckSuperQueryList(), "destTablet.FakeMysqlDaemon.CheckSuperQueryList failed") assert.True(t, destTablet.FakeMysqlDaemon.Replicating) assert.True(t, destTablet.FakeMysqlDaemon.Running) - assert.Equal(t, primary.FakeMysqlDaemon.CurrentPrimaryPosition, destTablet.FakeMysqlDaemon.CurrentPrimaryPosition) + assert.Equal(t, primary.FakeMysqlDaemon.GetPrimaryPositionLocked(), destTablet.FakeMysqlDaemon.GetPrimaryPositionLocked()) case <-timer2.C: require.FailNow(t, "Restore timed out") } @@ -605,7 +605,7 @@ func TestRestoreUnreachablePrimary(t *testing.T) { primary := NewFakeTablet(t, wr, "cell1", 0, topodatapb.TabletType_PRIMARY, db) primary.FakeMysqlDaemon.ReadOnly = false primary.FakeMysqlDaemon.Replicating = false - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -613,7 +613,7 @@ func TestRestoreUnreachablePrimary(t *testing.T) { Sequence: 457, }, }, - } + }) // start primary so that replica can fetch primary position from it primary.StartActionLoop(t, wr) @@ -623,7 +623,7 @@ func TestRestoreUnreachablePrimary(t *testing.T) { sourceTablet := NewFakeTablet(t, wr, "cell1", 1, topodatapb.TabletType_REPLICA, db) sourceTablet.FakeMysqlDaemon.ReadOnly = true sourceTablet.FakeMysqlDaemon.Replicating = true - sourceTablet.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + sourceTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -631,7 +631,7 @@ func TestRestoreUnreachablePrimary(t *testing.T) { Sequence: 457, }, }, - } + }) sourceTablet.FakeMysqlDaemon.SetReplicationSourceInputs = []string{fmt.Sprintf("%s:%d", primary.Tablet.MysqlHostname, primary.Tablet.MysqlPort)} sourceTablet.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ // These 3 statements come from tablet startup @@ -665,7 +665,7 @@ func TestRestoreUnreachablePrimary(t *testing.T) { destTablet := NewFakeTablet(t, wr, "cell1", 2, topodatapb.TabletType_REPLICA, db) destTablet.FakeMysqlDaemon.ReadOnly = true destTablet.FakeMysqlDaemon.Replicating = true - destTablet.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + destTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -673,7 +673,7 @@ func TestRestoreUnreachablePrimary(t *testing.T) { Sequence: 457, }, }, - } + }) destTablet.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ // These 3 statements come from tablet startup "STOP SLAVE", @@ -691,7 +691,7 @@ func TestRestoreUnreachablePrimary(t *testing.T) { "RESET MASTER": {}, "SET GLOBAL gtid_purged": {}, } - destTablet.FakeMysqlDaemon.SetReplicationPositionPos = sourceTablet.FakeMysqlDaemon.CurrentPrimaryPosition + destTablet.FakeMysqlDaemon.SetReplicationPositionPos = sourceTablet.FakeMysqlDaemon.GetPrimaryPositionLocked() destTablet.FakeMysqlDaemon.SetReplicationSourceInputs = append(destTablet.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(primary.Tablet)) destTablet.StartActionLoop(t, wr) @@ -780,7 +780,7 @@ func TestDisableActiveReparents(t *testing.T) { primary := NewFakeTablet(t, wr, "cell1", 0, topodatapb.TabletType_PRIMARY, db) primary.FakeMysqlDaemon.ReadOnly = false primary.FakeMysqlDaemon.Replicating = false - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -788,7 +788,7 @@ func TestDisableActiveReparents(t *testing.T) { Sequence: 457, }, }, - } + }) // start primary so that replica can fetch primary position from it primary.StartActionLoop(t, wr) @@ -799,7 +799,7 @@ func TestDisableActiveReparents(t *testing.T) { sourceTablet := NewFakeTablet(t, wr, "cell1", 1, topodatapb.TabletType_REPLICA, db) sourceTablet.FakeMysqlDaemon.ReadOnly = true sourceTablet.FakeMysqlDaemon.Replicating = true - sourceTablet.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + sourceTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -807,7 +807,7 @@ func TestDisableActiveReparents(t *testing.T) { Sequence: 457, }, }, - } + }) sourceTablet.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "STOP SLAVE", } @@ -832,7 +832,7 @@ func TestDisableActiveReparents(t *testing.T) { destTablet := NewFakeTablet(t, wr, "cell1", 2, topodatapb.TabletType_REPLICA, db) destTablet.FakeMysqlDaemon.ReadOnly = true destTablet.FakeMysqlDaemon.Replicating = true - destTablet.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + destTablet.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -840,7 +840,7 @@ func TestDisableActiveReparents(t *testing.T) { Sequence: 457, }, }, - } + }) destTablet.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "STOP SLAVE", "RESET SLAVE ALL", @@ -851,7 +851,7 @@ func TestDisableActiveReparents(t *testing.T) { "RESET MASTER": {}, "SET GLOBAL gtid_purged": {}, } - destTablet.FakeMysqlDaemon.SetReplicationPositionPos = sourceTablet.FakeMysqlDaemon.CurrentPrimaryPosition + destTablet.FakeMysqlDaemon.SetReplicationPositionPos = sourceTablet.FakeMysqlDaemon.GetPrimaryPositionLocked() destTablet.FakeMysqlDaemon.SetReplicationSourceInputs = append(destTablet.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(primary.Tablet)) destTablet.StartActionLoop(t, wr) diff --git a/go/vt/wrangler/testlib/emergency_reparent_shard_test.go b/go/vt/wrangler/testlib/emergency_reparent_shard_test.go index 99cc1839186..a901fce7856 100644 --- a/go/vt/wrangler/testlib/emergency_reparent_shard_test.go +++ b/go/vt/wrangler/testlib/emergency_reparent_shard_test.go @@ -62,7 +62,7 @@ func TestEmergencyReparentShard(t *testing.T) { reparenttestutil.SetKeyspaceDurability(context.Background(), t, ts, "test_keyspace", "semi_sync") oldPrimary.FakeMysqlDaemon.Replicating = false - oldPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + oldPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -70,7 +70,7 @@ func TestEmergencyReparentShard(t *testing.T) { Sequence: 456, }, }, - } + }) currentPrimaryFilePosition, _ := replication.ParseFilePosGTIDSet("mariadb-bin.000010:456") oldPrimary.FakeMysqlDaemon.CurrentSourceFilePosition = replication.Position{ GTIDSet: currentPrimaryFilePosition, @@ -79,7 +79,7 @@ func TestEmergencyReparentShard(t *testing.T) { // new primary newPrimary.FakeMysqlDaemon.ReadOnly = true newPrimary.FakeMysqlDaemon.Replicating = true - newPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + newPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -87,7 +87,7 @@ func TestEmergencyReparentShard(t *testing.T) { Sequence: 456, }, }, - } + }) newPrimaryRelayLogPos, _ := replication.ParseFilePosGTIDSet("relay-bin.000004:456") newPrimary.FakeMysqlDaemon.CurrentSourceFilePosition = replication.Position{ GTIDSet: newPrimaryRelayLogPos, @@ -122,7 +122,7 @@ func TestEmergencyReparentShard(t *testing.T) { // good replica 1 is replicating goodReplica1.FakeMysqlDaemon.ReadOnly = true goodReplica1.FakeMysqlDaemon.Replicating = true - goodReplica1.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + goodReplica1.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -130,7 +130,7 @@ func TestEmergencyReparentShard(t *testing.T) { Sequence: 455, }, }, - } + }) goodReplica1RelayLogPos, _ := replication.ParseFilePosGTIDSet("relay-bin.000004:455") goodReplica1.FakeMysqlDaemon.CurrentSourceFilePosition = replication.Position{ GTIDSet: goodReplica1RelayLogPos, @@ -153,7 +153,7 @@ func TestEmergencyReparentShard(t *testing.T) { // good replica 2 is not replicating goodReplica2.FakeMysqlDaemon.ReadOnly = true goodReplica2.FakeMysqlDaemon.Replicating = false - goodReplica2.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + goodReplica2.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -161,7 +161,7 @@ func TestEmergencyReparentShard(t *testing.T) { Sequence: 454, }, }, - } + }) goodReplica2RelayLogPos, _ := replication.ParseFilePosGTIDSet("relay-bin.000004:454") goodReplica2.FakeMysqlDaemon.CurrentSourceFilePosition = replication.Position{ GTIDSet: goodReplica2RelayLogPos, @@ -216,7 +216,7 @@ func TestEmergencyReparentShardPrimaryElectNotBest(t *testing.T) { newPrimary.FakeMysqlDaemon.Replicating = true // It has transactions in its relay log, but not as many as // moreAdvancedReplica - newPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + newPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -224,7 +224,7 @@ func TestEmergencyReparentShardPrimaryElectNotBest(t *testing.T) { Sequence: 456, }, }, - } + }) newPrimaryRelayLogPos, _ := replication.ParseFilePosGTIDSet("relay-bin.000004:456") newPrimary.FakeMysqlDaemon.CurrentSourceFilePosition = replication.Position{ GTIDSet: newPrimaryRelayLogPos, @@ -249,7 +249,7 @@ func TestEmergencyReparentShardPrimaryElectNotBest(t *testing.T) { // more advanced replica moreAdvancedReplica.FakeMysqlDaemon.Replicating = true // relay log position is more advanced than desired new primary - moreAdvancedReplica.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + moreAdvancedReplica.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 2: replication.MariadbGTID{ Domain: 2, @@ -257,14 +257,14 @@ func TestEmergencyReparentShardPrimaryElectNotBest(t *testing.T) { Sequence: 457, }, }, - } + }) moreAdvancedReplicaLogPos, _ := replication.ParseFilePosGTIDSet("relay-bin.000004:457") moreAdvancedReplica.FakeMysqlDaemon.CurrentSourceFilePosition = replication.Position{ GTIDSet: moreAdvancedReplicaLogPos, } moreAdvancedReplica.FakeMysqlDaemon.SetReplicationSourceInputs = append(moreAdvancedReplica.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(newPrimary.Tablet), topoproto.MysqlAddr(oldPrimary.Tablet)) moreAdvancedReplica.FakeMysqlDaemon.WaitPrimaryPositions = append(moreAdvancedReplica.FakeMysqlDaemon.WaitPrimaryPositions, moreAdvancedReplica.FakeMysqlDaemon.CurrentSourceFilePosition) - newPrimary.FakeMysqlDaemon.WaitPrimaryPositions = append(newPrimary.FakeMysqlDaemon.WaitPrimaryPositions, moreAdvancedReplica.FakeMysqlDaemon.CurrentPrimaryPosition) + newPrimary.FakeMysqlDaemon.WaitPrimaryPositions = append(newPrimary.FakeMysqlDaemon.WaitPrimaryPositions, moreAdvancedReplica.FakeMysqlDaemon.GetPrimaryPositionLocked()) moreAdvancedReplica.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ // These 3 statements come from tablet startup "STOP SLAVE", diff --git a/go/vt/wrangler/testlib/planned_reparent_shard_test.go b/go/vt/wrangler/testlib/planned_reparent_shard_test.go index 0125e69cac0..64d85906e07 100644 --- a/go/vt/wrangler/testlib/planned_reparent_shard_test.go +++ b/go/vt/wrangler/testlib/planned_reparent_shard_test.go @@ -95,7 +95,7 @@ func TestPlannedReparentShardNoPrimaryProvided(t *testing.T) { oldPrimary.FakeMysqlDaemon.ReadOnly = false oldPrimary.FakeMysqlDaemon.Replicating = false oldPrimary.FakeMysqlDaemon.ReplicationStatusError = mysql.ErrNotReplica - oldPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0] + oldPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0]) oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs = append(oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(newPrimary.Tablet)) oldPrimary.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "FAKE SET MASTER", @@ -212,7 +212,7 @@ func TestPlannedReparentShardNoError(t *testing.T) { oldPrimary.FakeMysqlDaemon.ReadOnly = false oldPrimary.FakeMysqlDaemon.Replicating = false oldPrimary.FakeMysqlDaemon.ReplicationStatusError = mysql.ErrNotReplica - oldPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0] + oldPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0]) oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs = append(oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(newPrimary.Tablet)) oldPrimary.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "FAKE SET MASTER", @@ -433,7 +433,7 @@ func TestPlannedReparentShardWaitForPositionFail(t *testing.T) { oldPrimary.FakeMysqlDaemon.ReadOnly = false oldPrimary.FakeMysqlDaemon.Replicating = false // set to incorrect value to make promote fail on WaitForReplicationPos - oldPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = newPrimary.FakeMysqlDaemon.PromoteResult + oldPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(newPrimary.FakeMysqlDaemon.PromoteResult) oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs = append(oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(newPrimary.Tablet)) oldPrimary.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "FAKE SET MASTER", @@ -541,7 +541,7 @@ func TestPlannedReparentShardWaitForPositionTimeout(t *testing.T) { // old primary oldPrimary.FakeMysqlDaemon.ReadOnly = false oldPrimary.FakeMysqlDaemon.Replicating = false - oldPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0] + oldPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0]) oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs = append(oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(newPrimary.Tablet)) oldPrimary.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "FAKE SET MASTER", @@ -615,7 +615,7 @@ func TestPlannedReparentShardRelayLogError(t *testing.T) { primary.FakeMysqlDaemon.ReadOnly = false primary.FakeMysqlDaemon.Replicating = false primary.FakeMysqlDaemon.ReplicationStatusError = mysql.ErrNotReplica - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 7: replication.MariadbGTID{ Domain: 7, @@ -623,7 +623,7 @@ func TestPlannedReparentShardRelayLogError(t *testing.T) { Sequence: 990, }, }, - } + }) primary.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "SUBINSERT INTO _vt.reparent_journal (time_created_ns, action_name, primary_alias, replication_position) VALUES", } @@ -696,7 +696,7 @@ func TestPlannedReparentShardRelayLogErrorStartReplication(t *testing.T) { primary.FakeMysqlDaemon.ReadOnly = false primary.FakeMysqlDaemon.Replicating = false primary.FakeMysqlDaemon.ReplicationStatusError = mysql.ErrNotReplica - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 7: replication.MariadbGTID{ Domain: 7, @@ -704,7 +704,7 @@ func TestPlannedReparentShardRelayLogErrorStartReplication(t *testing.T) { Sequence: 990, }, }, - } + }) primary.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "SUBINSERT INTO _vt.reparent_journal (time_created_ns, action_name, primary_alias, replication_position) VALUES", } @@ -814,7 +814,7 @@ func TestPlannedReparentShardPromoteReplicaFail(t *testing.T) { oldPrimary.FakeMysqlDaemon.ReadOnly = false oldPrimary.FakeMysqlDaemon.Replicating = false oldPrimary.FakeMysqlDaemon.ReplicationStatusError = mysql.ErrNotReplica - oldPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0] + oldPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0]) oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs = append(oldPrimary.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(newPrimary.Tablet)) oldPrimary.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "FAKE SET MASTER", @@ -822,7 +822,7 @@ func TestPlannedReparentShardPromoteReplicaFail(t *testing.T) { // We call a SetReplicationSource explicitly "FAKE SET MASTER", "START SLAVE", - // extra SetReplicationSource call due to retry + // extra SetReplicationSource call due to retry) "FAKE SET MASTER", "START SLAVE", } @@ -884,7 +884,7 @@ func TestPlannedReparentShardPromoteReplicaFail(t *testing.T) { // retrying should work newPrimary.FakeMysqlDaemon.PromoteError = nil - newPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0] + newPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(newPrimary.FakeMysqlDaemon.WaitPrimaryPositions[0]) // run PlannedReparentShard err = vp.Run([]string{"PlannedReparentShard", "--wait_replicas_timeout", "10s", "--keyspace_shard", newPrimary.Tablet.Keyspace + "/" + newPrimary.Tablet.Shard, "--new_primary", topoproto.TabletAliasString(newPrimary.Tablet.Alias)}) @@ -921,7 +921,7 @@ func TestPlannedReparentShardSamePrimary(t *testing.T) { oldPrimary.FakeMysqlDaemon.ReadOnly = true oldPrimary.FakeMysqlDaemon.Replicating = false oldPrimary.FakeMysqlDaemon.ReplicationStatusError = mysql.ErrNotReplica - oldPrimary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + oldPrimary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 7: replication.MariadbGTID{ Domain: 7, @@ -929,7 +929,7 @@ func TestPlannedReparentShardSamePrimary(t *testing.T) { Sequence: 990, }, }, - } + }) oldPrimary.FakeMysqlDaemon.ExpectedExecuteSuperQueryList = []string{ "SUBINSERT INTO _vt.reparent_journal (time_created_ns, action_name, primary_alias, replication_position) VALUES", } diff --git a/go/vt/wrangler/testlib/reparent_utils_test.go b/go/vt/wrangler/testlib/reparent_utils_test.go index 0d1d84e89f5..faff354de4d 100644 --- a/go/vt/wrangler/testlib/reparent_utils_test.go +++ b/go/vt/wrangler/testlib/reparent_utils_test.go @@ -69,7 +69,7 @@ func TestShardReplicationStatuses(t *testing.T) { } // primary action loop (to initialize host and port) - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 5: replication.MariadbGTID{ Domain: 5, @@ -77,12 +77,12 @@ func TestShardReplicationStatuses(t *testing.T) { Sequence: 892, }, }, - } + }) primary.StartActionLoop(t, wr) defer primary.StopActionLoop(t) // replica loop - replica.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + replica.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 5: replication.MariadbGTID{ Domain: 5, @@ -90,7 +90,7 @@ func TestShardReplicationStatuses(t *testing.T) { Sequence: 890, }, }, - } + }) replica.FakeMysqlDaemon.CurrentSourceHost = primary.Tablet.MysqlHostname replica.FakeMysqlDaemon.CurrentSourcePort = primary.Tablet.MysqlPort replica.FakeMysqlDaemon.SetReplicationSourceInputs = append(replica.FakeMysqlDaemon.SetReplicationSourceInputs, topoproto.MysqlAddr(primary.Tablet)) diff --git a/go/vt/wrangler/traffic_switcher_env_test.go b/go/vt/wrangler/traffic_switcher_env_test.go index 73fcb76eb65..905edef9e97 100644 --- a/go/vt/wrangler/traffic_switcher_env_test.go +++ b/go/vt/wrangler/traffic_switcher_env_test.go @@ -754,7 +754,7 @@ func (tme *testMigraterEnv) createDBClients(ctx context.Context, t *testing.T) { func (tme *testMigraterEnv) setPrimaryPositions() { for _, primary := range tme.sourcePrimaries { - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 5: replication.MariadbGTID{ Domain: 5, @@ -762,10 +762,10 @@ func (tme *testMigraterEnv) setPrimaryPositions() { Sequence: 892, }, }, - } + }) } for _, primary := range tme.targetPrimaries { - primary.FakeMysqlDaemon.CurrentPrimaryPosition = replication.Position{ + primary.FakeMysqlDaemon.SetPrimaryPositionLocked(replication.Position{ GTIDSet: replication.MariadbGTIDSet{ 5: replication.MariadbGTID{ Domain: 5, @@ -773,7 +773,7 @@ func (tme *testMigraterEnv) setPrimaryPositions() { Sequence: 893, }, }, - } + }) } } From b7adf46d635c2fced62018bb4f26cf4608f413dd Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 11:06:48 +0200 Subject: [PATCH 2/5] [release-18.0] fixes bugs around expression precedence and LIKE (#16934 & #16649) (#16944) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Andres Taylor Signed-off-by: Manan Gupta Co-authored-by: Andrés Taylor Co-authored-by: Manan Gupta Co-authored-by: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com> --- .../expressions/expressions.test | 30 ++++++++++++++++ go/vt/sqlparser/analyzer.go | 2 +- go/vt/sqlparser/ast_funcs.go | 4 --- go/vt/sqlparser/constants.go | 26 ++++++-------- go/vt/sqlparser/normalizer_test.go | 8 +++++ go/vt/sqlparser/parse_test.go | 8 +++-- go/vt/sqlparser/precedence.go | 8 ++--- go/vt/sqlparser/precedence_test.go | 3 ++ go/vt/sqlparser/sql.go | 4 +-- go/vt/sqlparser/sql.y | 4 +-- go/vt/sqlparser/tracked_buffer_test.go | 2 +- go/vt/vtgate/evalengine/expr.go | 2 +- go/vt/vtgate/evalengine/expr_compare.go | 34 +++++++++---------- go/vt/vtgate/evalengine/testcases/cases.go | 16 +++++---- go/vt/vtgate/evalengine/translate.go | 6 +--- go/vt/vtgate/executor_test.go | 2 +- .../planbuilder/testdata/select_cases.json | 4 +-- go/vt/vtgate/semantics/early_rewriter_test.go | 3 ++ 18 files changed, 98 insertions(+), 68 deletions(-) create mode 100644 go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test diff --git a/go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test b/go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test new file mode 100644 index 00000000000..60c1e641463 --- /dev/null +++ b/go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test @@ -0,0 +1,30 @@ +# This file contains queries that test expressions in Vitess. +# We've found a number of bugs around precedences that we want to test. +CREATE TABLE t0 +( + c1 BIT, + INDEX idx_c1 (c1) +); + +INSERT INTO t0(c1) +VALUES (''); + + +SELECT * +FROM t0; + +SELECT ((t0.c1 = 'a')) +FROM t0; + +SELECT * +FROM t0 +WHERE ((t0.c1 = 'a')); + + +SELECT (1 LIKE ('a' IS NULL)); +SELECT (NOT (1 LIKE ('a' IS NULL))); + +SELECT (~ (1 || 0)) IS NULL; + +SELECT 1 +WHERE (~ (1 || 0)) IS NULL; diff --git a/go/vt/sqlparser/analyzer.go b/go/vt/sqlparser/analyzer.go index f3b1fc10340..905d4deaa08 100644 --- a/go/vt/sqlparser/analyzer.go +++ b/go/vt/sqlparser/analyzer.go @@ -137,7 +137,7 @@ func ASTToStatementType(stmt Statement) StatementType { // CanNormalize takes Statement and returns if the statement can be normalized. func CanNormalize(stmt Statement) bool { switch stmt.(type) { - case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream: // TODO: we could merge this logic into ASTrewriter + case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream, *VExplainStmt: // TODO: we could merge this logic into ASTrewriter return true } return false diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 07c3421988b..290c57f9d97 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -1435,10 +1435,6 @@ func (op BinaryExprOperator) ToString() string { return ShiftLeftStr case ShiftRightOp: return ShiftRightStr - case JSONExtractOp: - return JSONExtractOpStr - case JSONUnquoteExtractOp: - return JSONUnquoteExtractOpStr default: return "Unknown BinaryExprOperator" } diff --git a/go/vt/sqlparser/constants.go b/go/vt/sqlparser/constants.go index 1be3124aa24..bce0cb32cba 100644 --- a/go/vt/sqlparser/constants.go +++ b/go/vt/sqlparser/constants.go @@ -155,19 +155,17 @@ const ( IsNotFalseStr = "is not false" // BinaryExpr.Operator - BitAndStr = "&" - BitOrStr = "|" - BitXorStr = "^" - PlusStr = "+" - MinusStr = "-" - MultStr = "*" - DivStr = "/" - IntDivStr = "div" - ModStr = "%" - ShiftLeftStr = "<<" - ShiftRightStr = ">>" - JSONExtractOpStr = "->" - JSONUnquoteExtractOpStr = "->>" + BitAndStr = "&" + BitOrStr = "|" + BitXorStr = "^" + PlusStr = "+" + MinusStr = "-" + MultStr = "*" + DivStr = "/" + IntDivStr = "div" + ModStr = "%" + ShiftLeftStr = "<<" + ShiftRightStr = ">>" // UnaryExpr.Operator UPlusStr = "+" @@ -718,8 +716,6 @@ const ( ModOp ShiftLeftOp ShiftRightOp - JSONExtractOp - JSONUnquoteExtractOp ) // Constant for Enum Type - UnaryExprOperator diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index a65913b6a57..3097e0161bc 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -405,6 +405,13 @@ func TestNormalize(t *testing.T) { "bv2": sqltypes.Int64BindVariable(2), "bv3": sqltypes.TestBindVariable([]any{1, 2}), }, + }, { + in: "SELECT 1 WHERE (~ (1||0)) IS NULL", + outstmt: "select :bv1 /* INT64 */ from dual where ~(:bv1 /* INT64 */ or :bv2 /* INT64 */) is null", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + "bv2": sqltypes.Int64BindVariable(0), + }, }} for _, tc := range testcases { t.Run(tc.in, func(t *testing.T) { @@ -491,6 +498,7 @@ func TestNormalizeOneCasae(t *testing.T) { err = Normalize(tree, NewReservedVars("vtg", known), bv) require.NoError(t, err) normalizerOutput := String(tree) + require.EqualValues(t, testOne.output, normalizerOutput) if normalizerOutput == "otheradmin" || normalizerOutput == "otherread" { return } diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go index aa007febe9a..fa189def53c 100644 --- a/go/vt/sqlparser/parse_test.go +++ b/go/vt/sqlparser/parse_test.go @@ -994,9 +994,11 @@ var ( }, { input: "select /* u~ */ 1 from t where a = ~b", }, { - input: "select /* -> */ a.b -> 'ab' from t", + input: "select /* -> */ a.b -> 'ab' from t", + output: "select /* -> */ json_extract(a.b, 'ab') from t", }, { - input: "select /* -> */ a.b ->> 'ab' from t", + input: "select /* -> */ a.b ->> 'ab' from t", + output: "select /* -> */ json_unquote(json_extract(a.b, 'ab')) from t", }, { input: "select /* empty function */ 1 from t where a = b()", }, { @@ -5675,7 +5677,7 @@ partition by range (YEAR(purchased)) subpartition by hash (TO_DAYS(purchased)) }, { input: "create table t (id int, info JSON, INDEX zips((CAST(info->'$.field' AS unsigned ARRAY))))", - output: "create table t (\n\tid int,\n\tinfo JSON,\n\tINDEX zips ((cast(info -> '$.field' as unsigned array)))\n)", + output: "create table t (\n\tid int,\n\tinfo JSON,\n\tINDEX zips ((cast(json_extract(info, '$.field') as unsigned array)))\n)", }, } for _, test := range createTableQueries { diff --git a/go/vt/sqlparser/precedence.go b/go/vt/sqlparser/precedence.go index ec590b23f95..1b5576f65b1 100644 --- a/go/vt/sqlparser/precedence.go +++ b/go/vt/sqlparser/precedence.go @@ -38,7 +38,6 @@ const ( P14 P15 P16 - P17 ) // precedenceFor returns the precedence of an expression. @@ -58,10 +57,7 @@ func precedenceFor(in Expr) Precendence { case *BetweenExpr: return P12 case *ComparisonExpr: - switch node.Operator { - case EqualOp, NotEqualOp, GreaterThanOp, GreaterEqualOp, LessThanOp, LessEqualOp, LikeOp, InOp, RegexpOp, NullSafeEqualOp: - return P11 - } + return P11 case *IsExpr: return P11 case *BinaryExpr: @@ -83,7 +79,7 @@ func precedenceFor(in Expr) Precendence { switch node.Operator { case UPlusOp, UMinusOp: return P4 - case BangOp: + default: return P3 } } diff --git a/go/vt/sqlparser/precedence_test.go b/go/vt/sqlparser/precedence_test.go index a6cbffee351..3905212ff78 100644 --- a/go/vt/sqlparser/precedence_test.go +++ b/go/vt/sqlparser/precedence_test.go @@ -156,6 +156,9 @@ func TestParens(t *testing.T) { {in: "(10 - 2) - 1", expected: "10 - 2 - 1"}, {in: "10 - (2 - 1)", expected: "10 - (2 - 1)"}, {in: "0 <=> (1 and 0)", expected: "0 <=> (1 and 0)"}, + {in: "(~ (1||0)) IS NULL", expected: "~(1 or 0) is null"}, + {in: "1 not like ('a' is null)", expected: "1 not like ('a' is null)"}, + {in: ":vtg1 not like (:vtg2 is null)", expected: ":vtg1 not like (:vtg2 is null)"}, } for _, tc := range tests { diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 51d3c77cd92..9ea4d0bc595 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -17840,7 +17840,7 @@ yydefault: var yyLOCAL Expr //line sql.y:5487 { - yyLOCAL = &BinaryExpr{Left: yyDollar[1].exprUnion(), Operator: JSONExtractOp, Right: yyDollar[3].exprUnion()} + yyLOCAL = &JSONExtractExpr{JSONDoc: yyDollar[1].exprUnion(), PathList: []Expr{yyDollar[3].exprUnion()}} } yyVAL.union = yyLOCAL case 1063: @@ -17848,7 +17848,7 @@ yydefault: var yyLOCAL Expr //line sql.y:5491 { - yyLOCAL = &BinaryExpr{Left: yyDollar[1].exprUnion(), Operator: JSONUnquoteExtractOp, Right: yyDollar[3].exprUnion()} + yyLOCAL = &JSONUnquoteExpr{JSONValue: &JSONExtractExpr{JSONDoc: yyDollar[1].exprUnion(), PathList: []Expr{yyDollar[3].exprUnion()}}} } yyVAL.union = yyLOCAL case 1064: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index a3f7a2e1a82..9cdd14cfdcf 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -5485,11 +5485,11 @@ function_call_keyword } | column_name_or_offset JSON_EXTRACT_OP text_literal_or_arg { - $$ = &BinaryExpr{Left: $1, Operator: JSONExtractOp, Right: $3} + $$ = &JSONExtractExpr{JSONDoc: $1, PathList: []Expr{$3}} } | column_name_or_offset JSON_UNQUOTE_EXTRACT_OP text_literal_or_arg { - $$ = &BinaryExpr{Left: $1, Operator: JSONUnquoteExtractOp, Right: $3} + $$ = &JSONUnquoteExpr{JSONValue: &JSONExtractExpr{JSONDoc: $1, PathList: []Expr{$3}}} } column_names_opt_paren: diff --git a/go/vt/sqlparser/tracked_buffer_test.go b/go/vt/sqlparser/tracked_buffer_test.go index 6924bf11911..378e386192f 100644 --- a/go/vt/sqlparser/tracked_buffer_test.go +++ b/go/vt/sqlparser/tracked_buffer_test.go @@ -270,7 +270,7 @@ func TestCanonicalOutput(t *testing.T) { }, { "create table t (id int, info JSON, INDEX zips((CAST(info->'$.field' AS unsigned array))))", - "CREATE TABLE `t` (\n\t`id` int,\n\t`info` JSON,\n\tINDEX `zips` ((CAST(`info` -> '$.field' AS unsigned array)))\n)", + "CREATE TABLE `t` (\n\t`id` int,\n\t`info` JSON,\n\tINDEX `zips` ((CAST(JSON_EXTRACT(`info`, '$.field') AS unsigned array)))\n)", }, { "select 1 from t1 into outfile 'test/t1.txt'", diff --git a/go/vt/vtgate/evalengine/expr.go b/go/vt/vtgate/evalengine/expr.go index dfa8491391e..0be178ee149 100644 --- a/go/vt/vtgate/evalengine/expr.go +++ b/go/vt/vtgate/evalengine/expr.go @@ -56,7 +56,7 @@ func (expr *BinaryExpr) arguments(env *ExpressionEnv) (eval, eval, error) { } right, err := expr.Right.eval(env) if err != nil { - return nil, nil, err + return left, nil, err } return left, right, nil } diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index b723609160c..efc5fdee04a 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -578,13 +578,18 @@ func (l *LikeExpr) matchWildcard(left, right []byte, coll collations.ID) bool { } fullColl := colldata.Lookup(coll) wc := fullColl.Wildcard(right, 0, 0, 0) - return wc.Match(left) + return wc.Match(left) == !l.Negate } func (l *LikeExpr) eval(env *ExpressionEnv) (eval, error) { - left, right, err := l.arguments(env) - if left == nil || right == nil || err != nil { - return nil, err + left, err := l.Left.eval(env) + if err != nil || left == nil { + return left, err + } + + right, err := l.Right.eval(env) + if err != nil || right == nil { + return right, err } var col collations.TypedCollation @@ -593,18 +598,9 @@ func (l *LikeExpr) eval(env *ExpressionEnv) (eval, error) { return nil, err } - var matched bool - switch { - case typeIsTextual(left.SQLType()) && typeIsTextual(right.SQLType()): - matched = l.matchWildcard(left.(*evalBytes).bytes, right.(*evalBytes).bytes, col.Collation) - case typeIsTextual(right.SQLType()): - matched = l.matchWildcard(left.ToRawBytes(), right.(*evalBytes).bytes, col.Collation) - case typeIsTextual(left.SQLType()): - matched = l.matchWildcard(left.(*evalBytes).bytes, right.ToRawBytes(), col.Collation) - default: - matched = l.matchWildcard(left.ToRawBytes(), right.ToRawBytes(), collations.CollationBinaryID) - } - return newEvalBool(matched == !l.Negate), nil + matched := l.matchWildcard(left.ToRawBytes(), right.ToRawBytes(), col.Collation) + + return newEvalBool(matched), nil } // typeof implements the ComparisonOp interface @@ -620,12 +616,14 @@ func (expr *LikeExpr) compile(c *compiler) (ctype, error) { return ctype{}, err } + skip1 := c.compileNullCheck1(lt) + rt, err := expr.Right.compile(c) if err != nil { return ctype{}, err } - skip := c.compileNullCheck2(lt, rt) + skip2 := c.compileNullCheck1(rt) if !lt.isTextual() { c.asm.Convert_xc(2, sqltypes.VarChar, c.cfg.Collation, nil) @@ -678,6 +676,6 @@ func (expr *LikeExpr) compile(c *compiler) (ctype, error) { }) } - c.asm.jumpDestination(skip) + c.asm.jumpDestination(skip1, skip2) return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | flagNullable}, nil } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index cd52631c00c..d25415ae674 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -1060,24 +1060,26 @@ func CollationOperations(yield Query) { } func LikeComparison(yield Query) { - var left = []string{ + var left = append(inputConversions, `'foobar'`, `'FOOBAR'`, `'1234'`, `1234`, `_utf8mb4 'foobar' COLLATE utf8mb4_0900_as_cs`, - `_utf8mb4 'FOOBAR' COLLATE utf8mb4_0900_as_cs`, - } - var right = append([]string{ + `_utf8mb4 'FOOBAR' COLLATE utf8mb4_0900_as_cs`) + + var right = append(left, + `NULL`, `1`, `0`, `'foo%'`, `'FOO%'`, `'foo_ar'`, `'FOO_AR'`, `'12%'`, `'12_4'`, `_utf8mb4 'foo%' COLLATE utf8mb4_0900_as_cs`, `_utf8mb4 'FOO%' COLLATE utf8mb4_0900_as_cs`, `_utf8mb4 'foo_ar' COLLATE utf8mb4_0900_as_cs`, - `_utf8mb4 'FOO_AR' COLLATE utf8mb4_0900_as_cs`, - }, left...) + `_utf8mb4 'FOO_AR' COLLATE utf8mb4_0900_as_cs`) for _, lhs := range left { for _, rhs := range right { - yield(fmt.Sprintf("%s LIKE %s", lhs, rhs), nil) + for _, op := range []string{"LIKE", "NOT LIKE"} { + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + } } } } diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 37a4037d21e..b379812c949 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -84,7 +84,7 @@ func (ast *astCompiler) translateComparisonExpr2(op sqlparser.ComparisonExprOper Negate: op == sqlparser.NotRegexpOp, }, nil default: - return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, op.ToString()) + return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, op.ToString()) } } @@ -298,10 +298,6 @@ func (ast *astCompiler) translateBinaryExpr(binary *sqlparser.BinaryExpr) (Expr, return &BitwiseExpr{BinaryExpr: binaryExpr, Op: &opBitShl{}}, nil case sqlparser.ShiftRightOp: return &BitwiseExpr{BinaryExpr: binaryExpr, Op: &opBitShr{}}, nil - case sqlparser.JSONExtractOp: - return builtinJSONExtractRewrite(left, right) - case sqlparser.JSONUnquoteExtractOp: - return builtinJSONExtractUnquoteRewrite(left, right) default: return nil, translateExprNotSupported(binary) } diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 0b086021d88..2d3a82f0b9b 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -2290,7 +2290,7 @@ func TestExecutorVExplain(t *testing.T) { result, err = executorExec(ctx, executor, session, "vexplain plan select 42", bindVars) require.NoError(t, err) - expected := `[[VARCHAR("{\n\t\"OperatorType\": \"Projection\",\n\t\"Expressions\": [\n\t\t\"INT64(42) as 42\"\n\t],\n\t\"Inputs\": [\n\t\t{\n\t\t\t\"OperatorType\": \"SingleRow\"\n\t\t}\n\t]\n}")]]` + expected := `[[VARCHAR("{\n\t\"OperatorType\": \"Projection\",\n\t\"Expressions\": [\n\t\t\":vtg1 as :vtg1 /* INT64 */\"\n\t],\n\t\"Inputs\": [\n\t\t{\n\t\t\t\"OperatorType\": \"SingleRow\"\n\t\t}\n\t]\n}")]]` require.Equal(t, expected, fmt.Sprintf("%v", result.Rows)) } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index d8f0d09a64e..016cfdb0fd1 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -2800,8 +2800,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a -> '$[4]', a ->> '$[3]' from `user` where 1 != 1", - "Query": "select a -> '$[4]', a ->> '$[3]' from `user`", + "FieldQuery": "select json_extract(a, '$[4]'), json_unquote(json_extract(a, '$[3]')) from `user` where 1 != 1", + "Query": "select json_extract(a, '$[4]'), json_unquote(json_extract(a, '$[3]')) from `user`", "Table": "`user`" }, "TablesUsed": [ diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index ba5510c5bf5..13823e5351f 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -823,6 +823,9 @@ func TestRewriteNot(t *testing.T) { }, { sql: "select a from t1 where not a > 12", expected: "select a from t1 where a <= 12", + }, { + sql: "select (not (1 like ('a' is null)))", + expected: "select 1 not like ('a' is null) from dual", }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { From ce4e22dd2055a190fee0f23891acbed011f829ef Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Thu, 17 Oct 2024 13:24:02 +0530 Subject: [PATCH 3/5] [release-18.0] fix: route engine to handle column truncation for execute after lookup (#16981) (#16983) Signed-off-by: Harshit Gangal Co-authored-by: Harshit Gangal --- go/vt/vtgate/engine/route.go | 31 ++--- go/vt/vtgate/engine/vindex_lookup_test.go | 135 ++++++++++++++++++++++ 2 files changed, 146 insertions(+), 20 deletions(-) create mode 100644 go/vt/vtgate/engine/vindex_lookup_test.go diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 5604fc33ada..8c7e944efd7 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -166,11 +166,12 @@ func (route *Route) SetTruncateColumnCount(count int) { func (route *Route) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { ctx, cancelFunc := addQueryTimeout(ctx, vcursor, route.QueryTimeout) defer cancelFunc() - qr, err := route.executeInternal(ctx, vcursor, bindVars, wantfields) + rss, bvs, err := route.findRoute(ctx, vcursor, bindVars) if err != nil { return nil, err } - return qr.Truncate(route.TruncateColumnCount), nil + + return route.executeShards(ctx, vcursor, bindVars, wantfields, rss, bvs) } // addQueryTimeout adds a query timeout to the context it receives and returns the modified context along with the cancel function. @@ -188,20 +189,6 @@ const ( IgnoreReserveTxn cxtKey = iota ) -func (route *Route) executeInternal( - ctx context.Context, - vcursor VCursor, - bindVars map[string]*querypb.BindVariable, - wantfields bool, -) (*sqltypes.Result, error) { - rss, bvs, err := route.findRoute(ctx, vcursor, bindVars) - if err != nil { - return nil, err - } - - return route.executeShards(ctx, vcursor, bindVars, wantfields, rss, bvs) -} - func (route *Route) executeShards( ctx context.Context, vcursor VCursor, @@ -255,11 +242,15 @@ func (route *Route) executeShards( } } - if len(route.OrderBy) == 0 { - return result, nil + if len(route.OrderBy) != 0 { + var err error + result, err = route.sort(result) + if err != nil { + return nil, err + } } - return route.sort(result) + return result.Truncate(route.TruncateColumnCount), nil } func filterOutNilErrors(errs []error) []error { @@ -441,7 +432,7 @@ func (route *Route) sort(in *sqltypes.Result) (*sqltypes.Result, error) { return 0 }) - return out.Truncate(route.TruncateColumnCount), err + return out, err } func (route *Route) description() PrimitiveDescription { diff --git a/go/vt/vtgate/engine/vindex_lookup_test.go b/go/vt/vtgate/engine/vindex_lookup_test.go new file mode 100644 index 00000000000..156558bbfdc --- /dev/null +++ b/go/vt/vtgate/engine/vindex_lookup_test.go @@ -0,0 +1,135 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +var ( + vindex, _ = vindexes.CreateVindex("lookup_unique", "", map[string]string{ + "table": "lkp", + "from": "from", + "to": "toc", + "write_only": "true", + }) + ks = &vindexes.Keyspace{Name: "ks", Sharded: true} +) + +func TestVindexLookup(t *testing.T) { + planableVindex, ok := vindex.(vindexes.LookupPlanable) + require.True(t, ok, "not a lookup vindex") + _, args := planableVindex.Query() + + fp := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields("id|keyspace_id", "int64|varbinary"), + "1|\x10"), + }, + } + route := NewRoute(ByDestination, ks, "dummy_select", "dummy_select_field") + vdxLookup := &VindexLookup{ + Opcode: EqualUnique, + Keyspace: ks, + Vindex: planableVindex, + Arguments: args, + Values: []evalengine.Expr{evalengine.NewLiteralInt(1)}, + Lookup: fp, + SendTo: route, + } + + vc := &loggingVCursor{results: []*sqltypes.Result{defaultSelectResult}} + + result, err := vdxLookup.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + require.NoError(t, err) + fp.ExpectLog(t, []string{`Execute from: type:TUPLE values:{type:INT64 value:"1"} false`}) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyspaceID(10)`, + `ExecuteMultiShard ks.-20: dummy_select {} false false`, + }) + expectResult(t, "vindexlookup execute", result, defaultSelectResult) + + fp.rewind() + vc.Rewind() + result, err = wrapStreamExecute(vdxLookup, vc, map[string]*querypb.BindVariable{}, false) + require.NoError(t, err) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyspaceID(10)`, + `StreamExecuteMulti dummy_select ks.-20: {} `, + }) + expectResult(t, "vindexlookup stream execute", result, defaultSelectResult) +} + +func TestVindexLookupTruncate(t *testing.T) { + planableVindex, ok := vindex.(vindexes.LookupPlanable) + require.True(t, ok, "not a lookup vindex") + _, args := planableVindex.Query() + + fp := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields("id|keyspace_id", "int64|varbinary"), + "1|\x10"), + }, + } + route := NewRoute(ByDestination, ks, "dummy_select", "dummy_select_field") + route.TruncateColumnCount = 1 + vdxLookup := &VindexLookup{ + Opcode: EqualUnique, + Keyspace: ks, + Vindex: planableVindex, + Arguments: args, + Values: []evalengine.Expr{evalengine.NewLiteralInt(1)}, + Lookup: fp, + SendTo: route, + } + + vc := &loggingVCursor{results: []*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("name|morecol", "varchar|int64"), + "foo|1", "bar|2", "baz|3"), + }} + + wantRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("name", "varchar"), + "foo", "bar", "baz") + result, err := vdxLookup.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + require.NoError(t, err) + fp.ExpectLog(t, []string{`Execute from: type:TUPLE values:{type:INT64 value:"1"} false`}) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyspaceID(10)`, + `ExecuteMultiShard ks.-20: dummy_select {} false false`, + }) + expectResult(t, "vindexlookup execute", result, wantRes) + + fp.rewind() + vc.Rewind() + result, err = wrapStreamExecute(vdxLookup, vc, map[string]*querypb.BindVariable{}, false) + require.NoError(t, err) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyspaceID(10)`, + `StreamExecuteMulti dummy_select ks.-20: {} `, + }) + expectResult(t, "vindexlookup stream execute", result, wantRes) +} From afb135d912521f249539ef9cf426bfeb00c00ec2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Thu, 17 Oct 2024 10:06:44 +0200 Subject: [PATCH 4/5] [release-18.0] Planner Bug: Joins inside derived table (#14974) (#16962) Signed-off-by: Andres Taylor --- .../vtgate/queries/subquery/subquery_test.go | 13 ++ go/vt/vtgate/executor_select_test.go | 30 ++-- .../planbuilder/operators/SQL_builder.go | 17 +- .../planbuilder/operators/aggregator.go | 8 + .../planbuilder/operators/apply_join.go | 50 ++++-- .../planbuilder/operators/expressions.go | 1 + .../planbuilder/operators/join_merging.go | 10 +- .../planbuilder/operators/offset_planning.go | 6 + .../planbuilder/operators/operator_funcs.go | 103 ----------- .../planbuilder/operators/projection.go | 12 +- .../planbuilder/operators/query_planning.go | 163 +++++++++++++----- go/vt/vtgate/planbuilder/operators/route.go | 62 ++++--- .../planbuilder/operators/route_planning.go | 10 +- .../planbuilder/testdata/aggr_cases.json | 86 +++++---- .../planbuilder/testdata/from_cases.json | 10 +- .../testdata/large_union_cases.json | 8 +- .../testdata/postprocess_cases.json | 2 +- .../planbuilder/testdata/reference_cases.json | 14 +- .../planbuilder/testdata/select_cases.json | 8 +- .../planbuilder/testdata/tpcc_cases.json | 4 +- .../planbuilder/testdata/tpch_cases.json | 52 +++--- .../planbuilder/testdata/union_cases.json | 64 +++---- 22 files changed, 407 insertions(+), 326 deletions(-) delete mode 100644 go/vt/vtgate/planbuilder/operators/operator_funcs.go diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index 0e21fa43cc1..a03948f8f40 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -173,3 +173,16 @@ func TestSubqueryInAggregation(t *testing.T) { // This fails as the planner adds `weight_string` method which make the query fail on MySQL. // mcmp.Exec(`SELECT max((select min(id2) from t1 where t1.id1 = t.id1)) FROM t1 t`) } + +// TestSubqueryInDerivedTable tests that subqueries and derived tables +// are handled correctly when there are joins inside the derived table +func TestSubqueryInDerivedTable(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 18, "vtgate") + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("INSERT INTO t1 (id1, id2) VALUES (1, 100), (2, 200), (3, 300), (4, 400), (5, 500);") + mcmp.Exec("INSERT INTO t2 (id3, id4) VALUES (10, 1), (20, 2), (30, 3), (40, 4), (50, 99)") + mcmp.Exec(`select t.a from (select t1.id2, t2.id3, (select id2 from t1 order by id2 limit 1) as a from t1 join t2 on t1.id1 = t2.id4) t`) + mcmp.Exec(`SELECT COUNT(*) FROM (SELECT DISTINCT t1.id1 FROM t1 JOIN t2 ON t1.id1 = t2.id4) dt`) +} diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index d92ca58ff3a..0130843cbb1 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -2821,7 +2821,7 @@ func TestCrossShardSubquery(t *testing.T) { utils.MustMatch(t, wantQueries, sbc1.Queries) wantQueries = []*querypb.BoundQuery{{ - Sql: "select 1 from (select u2.id from `user` as u2 where u2.id = :u1_col) as t", + Sql: "select 1 from (select u2.id as id from `user` as u2 where u2.id = :u1_col) as t", BindVariables: map[string]*querypb.BindVariable{"u1_col": sqltypes.Int32BindVariable(3)}, }} utils.MustMatch(t, wantQueries, sbc2.Queries) @@ -2896,7 +2896,7 @@ func TestCrossShardSubqueryStream(t *testing.T) { }} utils.MustMatch(t, wantQueries, sbc1.Queries) wantQueries = []*querypb.BoundQuery{{ - Sql: "select 1 from (select u2.id from `user` as u2 where u2.id = :u1_col) as t", + Sql: "select 1 from (select u2.id as id from `user` as u2 where u2.id = :u1_col) as t", BindVariables: map[string]*querypb.BindVariable{"u1_col": sqltypes.Int32BindVariable(3)}, }} utils.MustMatch(t, wantQueries, sbc2.Queries) @@ -2937,7 +2937,7 @@ func TestCrossShardSubqueryGetFields(t *testing.T) { Sql: "select t.id1, t.`u1.col` from (select u1.id as id1, u1.col as `u1.col` from `user` as u1 where 1 != 1) as t where 1 != 1", BindVariables: map[string]*querypb.BindVariable{}, }, { - Sql: "select 1 from (select u2.id from `user` as u2 where 1 != 1) as t where 1 != 1", + Sql: "select 1 from (select u2.id as id from `user` as u2 where 1 != 1) as t where 1 != 1", BindVariables: map[string]*querypb.BindVariable{ "u1_col": sqltypes.NullBindVariable, }, @@ -3847,14 +3847,14 @@ func TestSelectAggregationNoData(t *testing.T) { { sql: `select count(*) from (select col1, col2 from user limit 2) x`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64")), - expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"count(*)" type:INT64]`, expRow: `[[INT64(0)]]`, }, { sql: `select col2, count(*) from (select col1, col2 from user limit 2) x group by col2`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary")), - expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`, expRow: `[]`, }, @@ -3939,70 +3939,70 @@ func TestSelectAggregationData(t *testing.T) { { sql: `select count(*) from (select col1, col2 from user limit 2) x`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64"), "100|200|1", "200|300|1"), - expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"count(*)" type:INT64]`, expRow: `[[INT64(2)]]`, }, { sql: `select col2, count(*) from (select col1, col2 from user limit 9) x group by col2`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary"), "100|3|1|NULL", "200|2|1|NULL"), - expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`, expRow: `[[INT64(2) INT64(4)] [INT64(3) INT64(5)]]`, }, { sql: `select count(col1) from (select id, col1 from user limit 2) x`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1", "int64|varchar"), "1|a", "2|b"), - expSandboxQ: "select id, col1 from (select id, col1 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.id, x.col1 from (select id, col1 from `user`) as x limit :__upper_limit", expField: `[name:"count(col1)" type:INT64]`, expRow: `[[INT64(2)]]`, }, { sql: `select count(col1), col2 from (select col2, col1 from user limit 9) x group by col2`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|col1|weight_string(col2)", "int64|varchar|varbinary"), "3|a|NULL", "2|b|NULL"), - expSandboxQ: "select col2, col1, weight_string(col2) from (select col2, col1 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col2, x.col1, weight_string(x.col2) from (select col2, col1 from `user`) as x limit :__upper_limit", expField: `[name:"count(col1)" type:INT64 name:"col2" type:INT64]`, expRow: `[[INT64(4) INT64(2)] [INT64(5) INT64(3)]]`, }, { sql: `select col1, count(col2) from (select col1, col2 from user limit 9) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|1|a", "b|null|b"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`, expRow: `[[VARCHAR("a") INT64(5)] [VARCHAR("b") INT64(0)]]`, }, { sql: `select col1, count(col2) from (select col1, col2 from user limit 32) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "null|1|null", "null|null|null", "a|1|a", "b|null|b"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`, expRow: `[[NULL INT64(8)] [VARCHAR("a") INT64(8)] [VARCHAR("b") INT64(0)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|3|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:DECIMAL]`, expRow: `[[VARCHAR("a") DECIMAL(12)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|2|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") FLOAT64(8)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|x|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") FLOAT64(0)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|null|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") NULL]]`, }, diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 9a55ddef1de..cd56fed05b2 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -21,6 +21,8 @@ import ( "slices" "sort" + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" @@ -562,6 +564,13 @@ func buildProjection(op *Projection, qb *queryBuilder) error { } func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) error { + predicates := slice.Map(op.JoinPredicates, func(jc JoinColumn) sqlparser.Expr { + // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done + qb.ctx.SkipPredicates[jc.RHSExpr] = nil + + return jc.Original.Expr + }) + pred := sqlparser.AndExpressions(predicates...) err := buildQuery(op.LHS, qb) if err != nil { return err @@ -569,8 +578,8 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) error { // If we are going to add the predicate used in join here // We should not add the predicate's copy of when it was split into // two parts. To avoid this, we use the SkipPredicates map. - for _, expr := range qb.ctx.JoinPredicates[op.Predicate] { - qb.ctx.SkipPredicates[expr] = nil + for _, pred := range op.JoinPredicates { + qb.ctx.SkipPredicates[pred.RHSExpr] = nil } qbR := &queryBuilder{ctx: qb.ctx} err = buildQuery(op.RHS, qbR) @@ -578,9 +587,9 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) error { return err } if op.LeftJoin { - qb.joinOuterWith(qbR, op.Predicate) + qb.joinOuterWith(qbR, pred) } else { - qb.joinInnerWith(qbR, op.Predicate) + qb.joinInnerWith(qbR, pred) } return nil } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 9e38048d957..2073dcd6119 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -124,6 +124,14 @@ func (a *Aggregator) isDerived() bool { return a.DT != nil } +func (a *Aggregator) derivedName() string { + if a.DT == nil { + return "" + } + + return a.DT.Alias +} + func (a *Aggregator) FindCol(ctx *plancontext.PlanningContext, in sqlparser.Expr, underRoute bool) (int, error) { if underRoute && a.isDerived() { // We don't want to use columns on this operator if it's a derived table under a route. diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 5e48fb4d5e3..279357d98ed 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -38,9 +38,6 @@ type ( // LeftJoin will be true in the case of an outer join LeftJoin bool - // Before offset planning - Predicate sqlparser.Expr - // JoinColumns keeps track of what AST expression is represented in the Columns array JoinColumns []JoinColumn @@ -86,14 +83,19 @@ type ( } ) -func NewApplyJoin(lhs, rhs ops.Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin { - return &ApplyJoin{ - LHS: lhs, - RHS: rhs, - Vars: map[string]int{}, - Predicate: predicate, - LeftJoin: leftOuterJoin, +func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, predicate sqlparser.Expr, leftOuterJoin bool) (*ApplyJoin, error) { + aj := &ApplyJoin{ + LHS: lhs, + RHS: rhs, + Vars: map[string]int{}, + LeftJoin: leftOuterJoin, + } + err := aj.AddJoinPredicate(ctx, predicate) + if err != nil { + return nil, err } + + return aj, nil } // Clone implements the Operator interface @@ -105,7 +107,6 @@ func (aj *ApplyJoin) Clone(inputs []ops.Operator) ops.Operator { kopy.JoinColumns = slices.Clone(aj.JoinColumns) kopy.JoinPredicates = slices.Clone(aj.JoinPredicates) kopy.Vars = maps.Clone(aj.Vars) - kopy.Predicate = sqlparser.CloneExpr(aj.Predicate) kopy.ExtraLHSVars = slices.Clone(aj.ExtraLHSVars) return &kopy } @@ -149,8 +150,9 @@ func (aj *ApplyJoin) IsInner() bool { } func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) error { - aj.Predicate = ctx.SemTable.AndExpressions(expr, aj.Predicate) - + if expr == nil { + return nil + } col, err := BreakExpressionInLHSandRHS(ctx, expr, TableID(aj.LHS)) if err != nil { return err @@ -312,11 +314,15 @@ func (aj *ApplyJoin) addOffset(offset int) { } func (aj *ApplyJoin) ShortDescription() string { - pred := sqlparser.String(aj.Predicate) - columns := slice.Map(aj.JoinColumns, func(from JoinColumn) string { - return sqlparser.String(from.Original) - }) - firstPart := fmt.Sprintf("on %s columns: %s", pred, strings.Join(columns, ", ")) + fn := func(cols []JoinColumn) string { + out := slice.Map(cols, func(jc JoinColumn) string { + return jc.String() + }) + return strings.Join(out, ", ") + } + + firstPart := fmt.Sprintf("on %s columns: %s", fn(aj.JoinPredicates), fn(aj.JoinColumns)) + if len(aj.ExtraLHSVars) == 0 { return firstPart } @@ -419,6 +425,14 @@ func (jc JoinColumn) IsMixedLeftAndRight() bool { return len(jc.LHSExprs) > 0 && jc.RHSExpr != nil } +func (jc JoinColumn) String() string { + rhs := sqlparser.String(jc.RHSExpr) + lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string { + return sqlparser.String(e.Expr) + }) + return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original)) +} + func (bve BindVarExpr) String() string { if bve.Name == "" { return sqlparser.String(bve.Expr) diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 7ab27e787e8..4c03490317e 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -29,6 +29,7 @@ func BreakExpressionInLHSandRHS( expr sqlparser.Expr, lhs semantics.TableSet, ) (col JoinColumn, err error) { + col.Original = aeWrap(expr) rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { nodeExpr, ok := cursor.Node().(sqlparser.Expr) if !ok || !fetchByOffset(nodeExpr) { diff --git a/go/vt/vtgate/planbuilder/operators/join_merging.go b/go/vt/vtgate/planbuilder/operators/join_merging.go index c43b7b4c87e..a8f401211e6 100644 --- a/go/vt/vtgate/planbuilder/operators/join_merging.go +++ b/go/vt/vtgate/planbuilder/operators/join_merging.go @@ -235,13 +235,17 @@ func mergeShardedRouting(r1 *ShardedRouting, r2 *ShardedRouting) *ShardedRouting return tr } -func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) *ApplyJoin { - return NewApplyJoin(op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin) +func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) (*ApplyJoin, error) { + return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin) } func (jm *joinMerger) merge(ctx *plancontext.PlanningContext, op1, op2 *Route, r Routing) (*Route, error) { + join, err := jm.getApplyJoin(ctx, op1, op2) + if err != nil { + return nil, err + } return &Route{ - Source: jm.getApplyJoin(ctx, op1, op2), + Source: join, MergedWith: []*Route{op2}, Routing: r, }, nil diff --git a/go/vt/vtgate/planbuilder/operators/offset_planning.go b/go/vt/vtgate/planbuilder/operators/offset_planning.go index 502df5e9c82..db9bd8d3631 100644 --- a/go/vt/vtgate/planbuilder/operators/offset_planning.go +++ b/go/vt/vtgate/planbuilder/operators/offset_planning.go @@ -30,6 +30,7 @@ import ( // planOffsets will walk the tree top down, adding offset information to columns in the tree for use in further optimization, func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) { type offsettable interface { + ops.Operator planOffsets(ctx *plancontext.PlanningContext) error } @@ -40,6 +41,11 @@ func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Opera return nil, nil, vterrors.VT13001(fmt.Sprintf("should not see %T here", in)) case offsettable: err = op.planOffsets(ctx) + if rewrite.DebugOperatorTree { + fmt.Println("Planned offsets for:") + fmt.Println(ops.ToTree(op)) + } + } if err != nil { return nil, nil, err diff --git a/go/vt/vtgate/planbuilder/operators/operator_funcs.go b/go/vt/vtgate/planbuilder/operators/operator_funcs.go deleted file mode 100644 index 7f7aaff29c5..00000000000 --- a/go/vt/vtgate/planbuilder/operators/operator_funcs.go +++ /dev/null @@ -1,103 +0,0 @@ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package operators - -import ( - "fmt" - - "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" - "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" -) - -// RemovePredicate is used when we turn a predicate into a plan operator, -// and the predicate needs to be removed as an AST construct -func RemovePredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op ops.Operator) (ops.Operator, error) { - switch op := op.(type) { - case *Route: - newSrc, err := RemovePredicate(ctx, expr, op.Source) - if err != nil { - return nil, err - } - op.Source = newSrc - return op, err - case *ApplyJoin: - isRemoved := false - deps := ctx.SemTable.RecursiveDeps(expr) - if deps.IsSolvedBy(TableID(op.LHS)) { - newSrc, err := RemovePredicate(ctx, expr, op.LHS) - if err != nil { - return nil, err - } - op.LHS = newSrc - isRemoved = true - } - - if deps.IsSolvedBy(TableID(op.RHS)) { - newSrc, err := RemovePredicate(ctx, expr, op.RHS) - if err != nil { - return nil, err - } - op.RHS = newSrc - isRemoved = true - } - - var keep []sqlparser.Expr - for _, e := range sqlparser.SplitAndExpression(nil, op.Predicate) { - if ctx.SemTable.EqualsExprWithDeps(expr, e) { - isRemoved = true - } else { - keep = append(keep, e) - } - } - - if !isRemoved { - return nil, vterrors.VT12001(fmt.Sprintf("remove '%s' predicate on cross-shard join query", sqlparser.String(expr))) - } - - op.Predicate = ctx.SemTable.AndExpressions(keep...) - return op, nil - case *Filter: - idx := -1 - for i, predicate := range op.Predicates { - if ctx.SemTable.EqualsExprWithDeps(predicate, expr) { - idx = i - } - } - if idx == -1 { - // the predicate is not here. let's remove it from our source - newSrc, err := RemovePredicate(ctx, expr, op.Source) - if err != nil { - return nil, err - } - op.Source = newSrc - return op, nil - } - if len(op.Predicates) == 1 { - // no predicates left on this operator, so we just remove it - return op.Source, nil - } - - // remove the predicate from this filter - op.Predicates = append(op.Predicates[:idx], op.Predicates[idx+1:]...) - return op, nil - - default: - return nil, vterrors.VT13001("this should not happen - tried to remove predicate from the operator table") - } -} diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 686950ba56d..e83dbbecdc3 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -149,7 +149,10 @@ func (sp StarProjections) GetSelectExprs() sqlparser.SelectExprs { func (ap AliasedProjections) GetColumns() ([]*sqlparser.AliasedExpr, error) { return slice.Map(ap, func(from *ProjExpr) *sqlparser.AliasedExpr { - return aeWrap(from.ColExpr) + return &sqlparser.AliasedExpr{ + As: from.Original.As, + Expr: from.ColExpr, + } }), nil } @@ -247,6 +250,13 @@ func (p *Projection) isDerived() bool { return p.DT != nil } +func (p *Projection) derivedName() string { + if p.DT == nil { + return "" + } + return p.DT.Alias +} + func (p *Projection) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) (int, error) { ap, err := p.GetAliasedProjections() if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 0af4270d71b..b531c23b877 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -19,6 +19,8 @@ package operators import ( "fmt" "io" + "slices" + "strconv" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" @@ -30,8 +32,9 @@ import ( type ( projector struct { columns []*ProjExpr - columnAliases sqlparser.Columns + columnAliases []string explicitColumnAliases bool + tableName sqlparser.TableName } ) @@ -277,13 +280,50 @@ func pushProjectionInVindex( return src, rewrite.NewTree("push projection into vindex", p), nil } -func (p *projector) add(pe *ProjExpr, col *sqlparser.IdentifierCI) { +func (p *projector) add(pe *ProjExpr, alias string) { p.columns = append(p.columns, pe) - if col != nil { - p.columnAliases = append(p.columnAliases, *col) + if alias != "" && slices.Index(p.columnAliases, alias) > -1 { + panic("alias already used") } } +// get finds or adds an expression in the projector, returning its SQL representation with the appropriate alias +func (p *projector) get(ctx *plancontext.PlanningContext, expr sqlparser.Expr) sqlparser.Expr { + for _, column := range p.columns { + if ctx.SemTable.EqualsExprWithDeps(expr, column.ColExpr) { + alias := p.claimUnusedAlias(column.Original) + out := sqlparser.NewColName(alias) + out.Qualifier = p.tableName + + ctx.SemTable.CopySemanticInfo(expr, out) + return out + } + } + + // we could not find the expression, so we add it + alias := sqlparser.UnescapedString(expr) + pe := newProjExpr(sqlparser.NewAliasedExpr(expr, alias)) + p.columns = append(p.columns, pe) + p.columnAliases = append(p.columnAliases, alias) + + out := sqlparser.NewColName(alias) + out.Qualifier = p.tableName + + ctx.SemTable.CopySemanticInfo(expr, out) + + return out +} + +// claimUnusedAlias generates a unique alias based on the provided expression, ensuring no duplication in the projector +func (p *projector) claimUnusedAlias(ae *sqlparser.AliasedExpr) string { + bare := ae.ColumnName() + alias := bare + for i := int64(0); slices.Index(p.columnAliases, alias) > -1; i++ { + alias = bare + strconv.FormatInt(i, 10) + } + return alias +} + // pushProjectionInApplyJoin pushes down a projection operation into an ApplyJoin operation. // It processes each input column and creates new JoinPredicates for the ApplyJoin operation based on // the input column's expression. It also creates new Projection operators for the left and right @@ -306,18 +346,18 @@ func pushProjectionInApplyJoin( src.JoinColumns = nil for idx, pe := range ap { - var col *sqlparser.IdentifierCI + var alias string if p.DT != nil && idx < len(p.DT.Columns) { - col = &p.DT.Columns[idx] + alias = p.DT.Columns[idx].String() } - err := splitProjectionAcrossJoin(ctx, src, lhs, rhs, pe, col) + err := splitProjectionAcrossJoin(ctx, src, lhs, rhs, pe, alias) if err != nil { return nil, nil, err } } if p.isDerived() { - err := exposeColumnsThroughDerivedTable(ctx, p, src, lhs) + err := exposeColumnsThroughDerivedTable(ctx, p, src, lhs, rhs) if err != nil { return nil, nil, err } @@ -344,7 +384,7 @@ func splitProjectionAcrossJoin( join *ApplyJoin, lhs, rhs *projector, pe *ProjExpr, - colAlias *sqlparser.IdentifierCI, + alias string, ) error { // Check if the current expression can reuse an existing column in the ApplyJoin. @@ -352,7 +392,7 @@ func splitProjectionAcrossJoin( return nil } - col, err := splitUnexploredExpression(ctx, join, lhs, rhs, pe, colAlias) + col, err := splitUnexploredExpression(ctx, join, lhs, rhs, pe, alias) if err != nil { return err } @@ -367,7 +407,7 @@ func splitUnexploredExpression( join *ApplyJoin, lhs, rhs *projector, pe *ProjExpr, - colAlias *sqlparser.IdentifierCI, + alias string, ) (JoinColumn, error) { // Get a JoinColumn for the current expression. col, err := join.getJoinColumnFor(ctx, pe.Original, pe.ColExpr, false) @@ -378,23 +418,22 @@ func splitUnexploredExpression( // Update the left and right child columns and names based on the JoinColumn type. switch { case col.IsPureLeft(): - lhs.add(pe, colAlias) + lhs.add(pe, alias) case col.IsPureRight(): - rhs.add(pe, colAlias) + rhs.add(pe, alias) case col.IsMixedLeftAndRight(): for _, lhsExpr := range col.LHSExprs { - var lhsAlias *sqlparser.IdentifierCI - if colAlias != nil { + lhsAlias := "" + if alias != "" { // we need to add an explicit column alias here. let's try just the ColName as is first - ci := sqlparser.NewIdentifierCI(sqlparser.String(lhsExpr.Expr)) - lhsAlias = &ci + lhsAlias = sqlparser.String(lhsExpr.Expr) } lhs.add(newProjExpr(aeWrap(lhsExpr.Expr)), lhsAlias) } innerPE := newProjExprWithInner(pe.Original, col.RHSExpr) innerPE.ColExpr = col.RHSExpr innerPE.Info = pe.Info - rhs.add(innerPE, colAlias) + rhs.add(innerPE, alias) } return col, nil } @@ -411,7 +450,7 @@ func splitUnexploredExpression( // The function iterates through each join predicate, rewriting the expressions in the predicate's // LHS expressions to include the derived table. This allows the expressions to be accessed outside // the derived table. -func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, lhs *projector) error { +func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, lhs, rhs *projector) error { derivedTbl, err := ctx.SemTable.TableInfoFor(p.DT.TableID) if err != nil { return err @@ -420,34 +459,75 @@ func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Proje if err != nil { return err } - for _, predicate := range src.JoinPredicates { - for idx, bve := range predicate.LHSExprs { - expr := bve.Expr - tbl, err := ctx.SemTable.TableInfoForExpr(expr) - if err != nil { - return err + lhs.tableName = derivedTblName + rhs.tableName = derivedTblName + + lhsIDs := TableID(src.LHS) + rhsIDs := TableID(src.RHS) + rewriteColumnsForJoin(ctx, src.JoinPredicates, lhsIDs, rhsIDs, lhs, rhs, true) + rewriteColumnsForJoin(ctx, src.JoinColumns, lhsIDs, rhsIDs, lhs, rhs, true) + return nil +} + +func rewriteColumnsForJoin( + ctx *plancontext.PlanningContext, + columns []JoinColumn, + lhsIDs, rhsIDs semantics.TableSet, + lhs, rhs *projector, + exposeRHS bool, // we only want to expose the returned columns from the RHS. + // For predicates, we don't need to expose the RHS columns +) { + for colIdx, column := range columns { + for lhsIdx, bve := range column.LHSExprs { + // since this is on the LHSExprs, we know that dependencies are from that side of the join + column.LHSExprs[lhsIdx].Expr = lhs.get(ctx, bve.Expr) + } + if column.IsPureLeft() { + continue + } + + // now we need to go over the predicate and find + var rewriteTo sqlparser.Expr + + pre := func(node, _ sqlparser.SQLNode) bool { + _, isSQ := node.(*sqlparser.Subquery) + if isSQ { + return false } - tblExpr := tbl.GetExpr() - tblName, err := tblExpr.TableName() - if err != nil { - return err + expr, ok := node.(sqlparser.Expr) + if !ok { + return true } + deps := ctx.SemTable.RecursiveDeps(expr) + + switch { + case deps.IsEmpty(): + return true + case deps.IsSolvedBy(lhsIDs): + rewriteTo = lhs.get(ctx, expr) + return false + case deps.IsSolvedBy(rhsIDs): + if exposeRHS { + rewriteTo = rhs.get(ctx, expr) + } + return false + default: + return true + } + } - expr = semantics.RewriteDerivedTableExpression(expr, derivedTbl) - out := prefixColNames(tblName, expr) - - alias := sqlparser.UnescapedString(out) - predicate.LHSExprs[idx].Expr = sqlparser.NewColNameWithQualifier(alias, derivedTblName) - identifierCI := sqlparser.NewIdentifierCI(alias) - projExpr := newProjExprWithInner(&sqlparser.AliasedExpr{Expr: out, As: identifierCI}, out) - var colAlias *sqlparser.IdentifierCI - if lhs.explicitColumnAliases { - colAlias = &identifierCI + post := func(cursor *sqlparser.CopyOnWriteCursor) { + if rewriteTo != nil { + cursor.Replace(rewriteTo) + rewriteTo = nil + return } - lhs.add(projExpr, colAlias) } + newOriginal := sqlparser.CopyOnRewrite(column.Original.Expr, pre, post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) + column.Original.Expr = newOriginal + + columns[colIdx] = column } - return nil } // prefixColNames adds qualifier prefixes to all ColName:s. @@ -478,7 +558,6 @@ func createProjectionWithTheseColumns( proj.Columns = AliasedProjections(p.columns) if dt != nil { kopy := *dt - kopy.Columns = p.columnAliases proj.DT = &kopy } diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index e90948cbe5c..96edbf7f3e4 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -578,7 +578,10 @@ func (r *Route) AddColumn(ctx *plancontext.PlanningContext, reuse bool, gb bool, // if at least one column is not already present, we check if we can easily find a projection // or aggregation in our source that we can add to - derived, op, ok, offsets := addMultipleColumnsToInput(ctx, r.Source, reuse, []bool{gb}, []*sqlparser.AliasedExpr{expr}) + derived, op, ok, offsets, err := addMultipleColumnsToInput(ctx, r.Source, reuse, []bool{gb}, []*sqlparser.AliasedExpr{expr}) + if err != nil { + return 0, err + } r.Source = op if ok { return offsets[0], nil @@ -599,7 +602,7 @@ type selectExpressions interface { ops.Operator addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) (int, error) addColumnsWithoutPushing(ctx *plancontext.PlanningContext, reuse bool, addToGroupBy []bool, exprs []*sqlparser.AliasedExpr) ([]int, error) - isDerived() bool + derivedName() string } // addColumnToInput adds a column to an operator without pushing it down. @@ -615,63 +618,83 @@ func addMultipleColumnsToInput( projection ops.Operator, // if an operator needed to be built, it will be returned here found bool, // whether a matching op was found or not offsets []int, // the offsets the expressions received + err error, ) { switch op := operator.(type) { case *SubQuery: - derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Outer, reuse, addToGroupBy, exprs) + derivedName, src, added, offset, err := addMultipleColumnsToInput(ctx, op.Outer, reuse, addToGroupBy, exprs) + if err != nil { + return "", nil, true, nil, err + } if added { op.Outer = src } - return derivedName, op, added, offset + return derivedName, op, added, offset, nil case *Distinct: - derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + derivedName, src, added, offset, err := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + if err != nil { + return "", nil, true, nil, err + } if added { op.Source = src } - return derivedName, op, added, offset + return derivedName, op, added, offset, nil case *Limit: - derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + derivedName, src, added, offset, err := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + if err != nil { + return "", nil, true, nil, err + } if added { op.Source = src } - return derivedName, op, added, offset + return derivedName, op, added, offset, nil case *Ordering: - derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + derivedName, src, added, offset, err := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + if err != nil { + return "", nil, true, nil, err + } if added { op.Source = src } - return derivedName, op, added, offset + return derivedName, op, added, offset, nil case *LockAndComment: - derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + derivedName, src, added, offset, err := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + if err != nil { + return "", nil, true, nil, err + } if added { op.Source = src } - return derivedName, op, added, offset + return derivedName, op, added, offset, nil case *Horizon: // if the horizon has an alias, then it is a derived table, // we have to add a new projection and can't build on this one - return op.Alias, op, false, nil + return op.Alias, op, false, nil, nil case selectExpressions: - if op.isDerived() { + name := op.derivedName() + if name != "" { // if the only thing we can push to is a derived table, // we have to add a new projection and can't build on this one - return "", op, false, nil + return name, op, false, nil, nil + } + offset, err := op.addColumnsWithoutPushing(ctx, reuse, addToGroupBy, exprs) + if err != nil { + return "", nil, true, nil, err } - offset, _ := op.addColumnsWithoutPushing(ctx, reuse, addToGroupBy, exprs) - return "", op, true, offset + return "", op, true, offset, nil case *Union: tableID := semantics.SingleTableSet(len(ctx.SemTable.Tables)) ctx.SemTable.Tables = append(ctx.SemTable.Tables, nil) unionColumns, err := op.GetColumns(ctx) if err != nil { - return "", op, false, nil + break } proj := &Projection{ Source: op, @@ -682,9 +705,8 @@ func addMultipleColumnsToInput( }, } return addMultipleColumnsToInput(ctx, proj, reuse, addToGroupBy, exprs) - default: - return "", op, false, nil } + return "", operator, false, nil, nil } func (r *Route) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, _ bool) (int, error) { diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index 93a00154dc1..710c186b9c2 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -382,7 +382,10 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPr return nil, nil, vterrors.VT12001("JOIN between derived tables with LIMIT") } - join := NewApplyJoin(Clone(rhs), Clone(lhs), nil, !inner) + join, err := NewApplyJoin(ctx, Clone(rhs), Clone(lhs), nil, !inner) + if err != nil { + return nil, nil, err + } for _, pred := range joinPredicates { err := join.AddJoinPredicate(ctx, pred) if err != nil { @@ -392,7 +395,10 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPr return join, rewrite.NewTree("logical join to applyJoin, switching side because LIMIT", join), nil } - join := NewApplyJoin(Clone(lhs), Clone(rhs), nil, !inner) + join, err := NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, !inner) + if err != nil { + return nil, nil, err + } for _, pred := range joinPredicates { err := join.AddJoinPredicate(ctx, pred) if err != nil { diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 127dd2c4837..499ee753971 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -3442,8 +3442,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select phone, id, city from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", - "Query": "select phone, id, city from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", + "FieldQuery": "select x.phone, x.id, x.city from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.phone, x.id, x.city from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", "Table": "`user`" } ] @@ -3485,8 +3485,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select phone, id, city, 1 from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", - "Query": "select phone, id, city, 1 from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", + "FieldQuery": "select x.phone, x.id, x.city, 1 from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.phone, x.id, x.city, 1 from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", "Table": "`user`" } ] @@ -3590,9 +3590,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, val1, 1, weight_string(val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", + "FieldQuery": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", "OrderBy": "(1|3) ASC", - "Query": "select id, val1, 1, weight_string(val1) from (select id, val1 from `user` where val2 < 4) as x order by x.val1 asc limit :__upper_limit", + "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by x.val1 asc limit :__upper_limit", "Table": "`user`" } ] @@ -5895,8 +5895,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, val2 from (select id, val2 from `user` where 1 != 1) as x where 1 != 1", - "Query": "select id, val2 from (select id, val2 from `user` where val2 is null) as x limit :__upper_limit", + "FieldQuery": "select x.id, x.val2 from (select id, val2 from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.id, x.val2 from (select id, val2 from `user` where val2 is null) as x limit :__upper_limit", "Table": "`user`" } ] @@ -6475,24 +6475,31 @@ "OperatorType": "Aggregate", "Variant": "Scalar", "Aggregates": "count_star(0) AS count(*)", - "ResultColumns": 1, "Inputs": [ { - "OperatorType": "Limit", - "Count": "INT64(25)", - "Offset": "INT64(0)", + "OperatorType": "SimpleProjection", + "Columns": [ + 2 + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1, id, weight_string(id) from (select 1 as one, id from `user` where 1 != 1) as subquery_for_count where 1 != 1", - "OrderBy": "(1|2) DESC", - "Query": "select 1, id, weight_string(id) from (select 1 as one, id from `user` where `user`.is_not_deleted = true) as subquery_for_count order by subquery_for_count.id desc limit :__upper_limit", - "Table": "`user`" + "OperatorType": "Limit", + "Count": "INT64(25)", + "Offset": "INT64(0)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select subquery_for_count.one, subquery_for_count.id, 1, weight_string(subquery_for_count.id) from (select 1 as one, id from `user` where 1 != 1) as subquery_for_count where 1 != 1", + "OrderBy": "(1|3) DESC", + "Query": "select subquery_for_count.one, subquery_for_count.id, 1, weight_string(subquery_for_count.id) from (select 1 as one, id from `user` where `user`.is_not_deleted = true) as subquery_for_count order by subquery_for_count.id desc limit :__upper_limit", + "Table": "`user`" + } + ] } ] } @@ -6518,24 +6525,31 @@ "OperatorType": "Aggregate", "Variant": "Scalar", "Aggregates": "count_star(0) AS count(*)", - "ResultColumns": 1, "Inputs": [ { - "OperatorType": "Limit", - "Count": "INT64(25)", - "Offset": "INT64(0)", + "OperatorType": "SimpleProjection", + "Columns": [ + 2 + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1, id, weight_string(id) from (select 1 as one, id from `user` where 1 != 1) as subquery_for_count where 1 != 1", - "OrderBy": "(1|2) DESC", - "Query": "select 1, id, weight_string(id) from (select 1 as one, id from `user` where `user`.is_not_deleted = true) as subquery_for_count order by subquery_for_count.id desc limit :__upper_limit", - "Table": "`user`" + "OperatorType": "Limit", + "Count": "INT64(25)", + "Offset": "INT64(0)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select subquery_for_count.one, subquery_for_count.id, 1, weight_string(subquery_for_count.id) from (select 1 as one, id from `user` where 1 != 1) as subquery_for_count where 1 != 1", + "OrderBy": "(1|3) DESC", + "Query": "select subquery_for_count.one, subquery_for_count.id, 1, weight_string(subquery_for_count.id) from (select 1 as one, id from `user` where `user`.is_not_deleted = true) as subquery_for_count order by subquery_for_count.id desc limit :__upper_limit", + "Table": "`user`" + } + ] } ] } diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index f977a5ca7a8..9e668bd68a2 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -511,7 +511,7 @@ "Sharded": true }, "FieldQuery": "select 42 from `user` as u, user_extra as ue, music as m where 1 != 1", - "Query": "select 42 from `user` as u, user_extra as ue, music as m where u.id = ue.user_id and m.user_id = u.id and (u.foo or m.foo or ue.foo)", + "Query": "select 42 from `user` as u, user_extra as ue, music as m where u.id = ue.user_id and u.id = :m_user_id and (u.foo or :m_foo or ue.foo) and m.user_id = u.id and (u.foo or m.foo or ue.foo)", "Table": "`user`, music, user_extra" }, "TablesUsed": [ @@ -1598,7 +1598,7 @@ "Sharded": true }, "FieldQuery": "select t.id from (select id from `user` where 1 != 1) as t, user_extra where 1 != 1", - "Query": "select t.id from (select id from `user` where id = 5) as t, user_extra where t.id = user_extra.user_id", + "Query": "select t.id from (select id from `user` where id = 5 and id = :user_extra_user_id) as t, user_extra where t.id = user_extra.user_id", "Table": "`user`, user_extra", "Values": [ "INT64(5)" @@ -1736,7 +1736,7 @@ "Sharded": true }, "FieldQuery": "select t.id from (select id, textcol1 as baz from `user` as route1 where 1 != 1) as t, (select id, textcol1 + textcol1 as baz from `user` where 1 != 1) as s where 1 != 1", - "Query": "select t.id from (select id, textcol1 as baz from `user` as route1 where textcol1 = '3') as t, (select id, textcol1 + textcol1 as baz from `user` where textcol1 + textcol1 = '3') as s where t.id = s.id", + "Query": "select t.id from (select id, textcol1 as baz from `user` as route1 where textcol1 = '3') as t, (select id, textcol1 + textcol1 as baz from `user` where textcol1 + textcol1 = '3' and id = :t_id) as s where t.id = s.id", "Table": "`user`" }, "TablesUsed": [ @@ -2000,8 +2000,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra where user_extra.col = :user_col", + "FieldQuery": "select 1 from (select user_extra.col as `user_extra.col` from user_extra where 1 != 1) as t where 1 != 1", + "Query": "select 1 from (select user_extra.col as `user_extra.col` from user_extra where user_extra.col = :user_col) as t", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/large_union_cases.json b/go/vt/vtgate/planbuilder/testdata/large_union_cases.json index 1ad7b33d589..bcf68e30e57 100644 --- a/go/vt/vtgate/planbuilder/testdata/large_union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/large_union_cases.json @@ -23,8 +23,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select content, user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where 1 != 1) union (select content, user_id from music where 1 != 1)) as dt where 1 != 1", - "Query": "select content, user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where user_id = 1270698330 order by created_at asc, id asc limit 11) union (select content, user_id from music where user_id = 1270698330 order by created_at asc, id asc limit 11)) as dt", + "FieldQuery": "select dt.content, dt.user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where 1 != 1) union (select content, user_id from music where 1 != 1)) as dt where 1 != 1", + "Query": "select dt.content, dt.user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where user_id = 1270698330 order by created_at asc, id asc limit 11) union (select content, user_id from music where user_id = 1270698330 order by created_at asc, id asc limit 11)) as dt", "Table": "music", "Values": [ "INT64(1270698330)" @@ -38,8 +38,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select content, user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where 1 != 1) union (select content, user_id from music where 1 != 1)) as dt where 1 != 1", - "Query": "select content, user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where user_id = 1270699497 order by created_at asc, id asc limit 11) union (select content, user_id from music where user_id = 1270699497 order by created_at asc, id asc limit 11)) as dt", + "FieldQuery": "select dt.content, dt.user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where 1 != 1) union (select content, user_id from music where 1 != 1)) as dt where 1 != 1", + "Query": "select dt.content, dt.user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where user_id = 1270699497 order by created_at asc, id asc limit 11) union (select content, user_id from music where user_id = 1270699497 order by created_at asc, id asc limit 11)) as dt", "Table": "music", "Values": [ "INT64(1270699497)" diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index a36ad580dc4..7f573227e65 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -1052,7 +1052,7 @@ "Sharded": true }, "FieldQuery": "select * from (select user_id from user_extra where 1 != 1) as eu, `user` as u where 1 != 1", - "Query": "select * from (select user_id from user_extra where user_id = 5) as eu, `user` as u where u.id = 5 and u.id = eu.user_id order by eu.user_id asc", + "Query": "select * from (select user_id from user_extra where user_id = 5 and user_id = :u_id) as eu, `user` as u where u.id = 5 and u.id = eu.user_id order by eu.user_id asc", "Table": "`user`, user_extra", "Values": [ "INT64(5)" diff --git a/go/vt/vtgate/planbuilder/testdata/reference_cases.json b/go/vt/vtgate/planbuilder/testdata/reference_cases.json index 53019a48e12..140ce9f7849 100644 --- a/go/vt/vtgate/planbuilder/testdata/reference_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/reference_cases.json @@ -876,19 +876,19 @@ "QueryType": "SELECT", "Original": "select * from user.ref_with_source ref, `user`.`user` u where ref.id = u.ref_id and u.id = 2", "Instructions": { - "FieldQuery": "select * from ref_with_source as ref, `user` as u where 1 != 1", "OperatorType": "Route", "Variant": "EqualUnique", - "Vindex": "user_index", "Keyspace": { "Name": "user", "Sharded": true }, + "FieldQuery": "select * from ref_with_source as ref, `user` as u where 1 != 1", "Query": "select * from ref_with_source as ref, `user` as u where u.id = 2 and ref.id = u.ref_id", "Table": "`user`, ref_with_source", "Values": [ "INT64(2)" - ] + ], + "Vindex": "user_index" }, "TablesUsed": [ "user.ref_with_source", @@ -903,19 +903,19 @@ "QueryType": "SELECT", "Original": "select * from source_of_ref ref, `user`.`user` u where ref.id = u.ref_id and u.id = 2", "Instructions": { - "FieldQuery": "select * from ref_with_source as ref, `user` as u where 1 != 1", "OperatorType": "Route", "Variant": "EqualUnique", - "Vindex": "user_index", "Keyspace": { "Name": "user", "Sharded": true }, + "FieldQuery": "select * from ref_with_source as ref, `user` as u where 1 != 1", "Query": "select * from ref_with_source as ref, `user` as u where u.id = 2 and ref.id = u.ref_id", "Table": "`user`, ref_with_source", "Values": [ "INT64(2)" - ] + ], + "Vindex": "user_index" }, "TablesUsed": [ "user.ref_with_source", @@ -937,7 +937,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where 1 != 1", - "Query": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where rr.bar = sr.bar and u.id = ue.user_id and sr.foo = ue.foo", + "Query": "select 1 from `user` as u, user_extra as ue, ref_with_source as sr, ref as rr where sr.foo = :ue_foo and rr.bar = sr.bar and u.id = ue.user_id and sr.foo = ue.foo", "Table": "`user`, ref, ref_with_source, user_extra" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 016cfdb0fd1..6c255fd9d89 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -4110,7 +4110,7 @@ "Sharded": true }, "FieldQuery": "select music.id from (select id from music where 1 != 1) as other, music where 1 != 1", - "Query": "select music.id from (select id from music where music.user_id = 5) as other, music where other.id = music.id", + "Query": "select music.id from (select id from music where music.user_id = 5 and id = :music_id) as other, music where other.id = music.id", "Table": "music", "Values": [ "INT64(5)" @@ -4136,7 +4136,7 @@ "Sharded": true }, "FieldQuery": "select music.id from (select id from music where 1 != 1) as other, music where 1 != 1", - "Query": "select music.id from (select id from music where music.user_id in ::__vals) as other, music where other.id = music.id", + "Query": "select music.id from (select id from music where music.user_id in ::__vals and id = :music_id) as other, music where other.id = music.id", "Table": "music", "Values": [ "(INT64(5), INT64(6), INT64(7))" @@ -4174,8 +4174,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select user_id from (select user_id from user_extra where 1 != 1) as ue where 1 != 1", - "Query": "select user_id from (select user_id from user_extra) as ue limit :__upper_limit", + "FieldQuery": "select ue.user_id from (select user_id from user_extra where 1 != 1) as ue where 1 != 1", + "Query": "select ue.user_id from (select user_id from user_extra) as ue limit :__upper_limit", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json index 2677deb2cab..fa823b0ae59 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json @@ -13,7 +13,7 @@ "Sharded": true }, "FieldQuery": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where 1 != 1", - "Query": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where c_d_id = 15 and c_id = 10 and w_id = 1 and c_w_id = w_id", + "Query": "select c_discount, c_last, c_credit, w_tax from customer1 as c, warehouse1 as w where c_w_id = :w_id and c_d_id = 15 and c_id = 10 and w_id = 1 and c_w_id = w_id", "Table": "customer1, warehouse1", "Values": [ "INT64(1)" @@ -947,7 +947,7 @@ "Sharded": true }, "FieldQuery": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where 1 != 1 group by o_c_id, o_d_id, o_w_id) as t, orders1 as o where 1 != 1", - "Query": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where o_w_id = 1 and o_id > 2100 and o_id < 11153 group by o_c_id, o_d_id, o_w_id having count(distinct o_id) > 1 limit 1) as t, orders1 as o where t.o_w_id = o.o_w_id and t.o_d_id = o.o_d_id and t.o_c_id = o.o_c_id limit 1", + "Query": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where o_w_id = 1 and o_id > 2100 and o_id < 11153 and o_w_id = :o_o_w_id and o_d_id = :o_o_d_id and o_c_id = :o_o_c_id group by o_c_id, o_d_id, o_w_id having count(distinct o_id) > 1 limit 1) as t, orders1 as o where t.o_w_id = o.o_w_id and t.o_d_id = o.o_d_id and t.o_c_id = o.o_c_id limit 1", "Table": "orders1", "Values": [ "INT64(1)" diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 0a20b7e06da..f1d0ba7dfcc 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -535,9 +535,9 @@ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,L:5,R:2,L:6", + "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,L:4,R:2,L:5", "JoinVars": { - "n1_n_name": 4, + "n1_n_name": 1, "o_custkey": 3 }, "TableName": "lineitem_orders_supplier_nation_customer_nation", @@ -548,18 +548,17 @@ "[COLUMN 0] * [COLUMN 1] as revenue", "[COLUMN 2] as supp_nation", "[COLUMN 3] as l_year", - "[COLUMN 4] as orders.o_custkey", - "[COLUMN 5] as n1.n_name", - "[COLUMN 6] as weight_string(supp_nation)", - "[COLUMN 7] as weight_string(l_year)" + "[COLUMN 4] as o_custkey", + "[COLUMN 5] as weight_string(supp_nation)", + "[COLUMN 6] as weight_string(l_year)" ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,R:1,L:1,L:2,L:3,R:2,L:5", + "JoinColumnIndexes": "L:0,R:0,R:1,L:1,L:2,R:2,L:4", "JoinVars": { - "l_suppkey": 4 + "l_suppkey": 3 }, "TableName": "lineitem_orders_supplier_nation", "Inputs": [ @@ -568,18 +567,17 @@ "Expressions": [ "[COLUMN 0] * [COLUMN 1] as revenue", "[COLUMN 2] as l_year", - "[COLUMN 3] as orders.o_custkey", - "[COLUMN 4] as n1.n_name", - "[COLUMN 5] as lineitem.l_suppkey", - "[COLUMN 6] as weight_string(l_year)" + "[COLUMN 3] as o_custkey", + "[COLUMN 4] as l_suppkey", + "[COLUMN 5] as weight_string(l_year)" ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,L:1,L:2,L:3,L:4,L:6", + "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,L:4", "JoinVars": { - "l_orderkey": 5 + "l_orderkey": 3 }, "TableName": "lineitem_orders", "Inputs": [ @@ -590,9 +588,9 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select sum(volume) as revenue, l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year), shipping.supp_nation, weight_string(shipping.supp_nation), shipping.cust_nation, weight_string(shipping.cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, orders.o_custkey as `orders.o_custkey`, lineitem.l_suppkey as `lineitem.l_suppkey`, lineitem.l_orderkey as `lineitem.l_orderkey` from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year)", - "OrderBy": "(7|8) ASC, (9|10) ASC, (1|6) ASC", - "Query": "select sum(volume) as revenue, l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year), shipping.supp_nation, weight_string(shipping.supp_nation), shipping.cust_nation, weight_string(shipping.cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, orders.o_custkey as `orders.o_custkey`, lineitem.l_suppkey as `lineitem.l_suppkey`, lineitem.l_orderkey as `lineitem.l_orderkey` from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year) order by shipping.supp_nation asc, shipping.cust_nation asc, shipping.l_year asc", + "FieldQuery": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), shipping.supp_nation, weight_string(shipping.supp_nation), shipping.cust_nation, weight_string(shipping.cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year)", + "OrderBy": "(5|6) ASC, (7|8) ASC, (1|4) ASC", + "Query": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), shipping.supp_nation, weight_string(shipping.supp_nation), shipping.cust_nation, weight_string(shipping.cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year) order by shipping.supp_nation asc, shipping.cust_nation asc, shipping.l_year asc", "Table": "lineitem" }, { @@ -602,8 +600,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*) from orders where 1 != 1 group by .0", - "Query": "select count(*) from orders where o_orderkey = :l_orderkey group by .0", + "FieldQuery": "select count(*), shipping.o_custkey from (select o_custkey as o_custkey, o_orderkey as o_orderkey from orders where 1 != 1) as shipping where 1 != 1 group by shipping.o_custkey", + "Query": "select count(*), shipping.o_custkey from (select o_custkey as o_custkey, o_orderkey as o_orderkey from orders where o_orderkey = :l_orderkey) as shipping group by shipping.o_custkey", "Table": "orders", "Values": [ ":l_orderkey" @@ -638,8 +636,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), shipping.`supplier.s_nationkey` from (select supplier.s_nationkey as `supplier.s_nationkey` from supplier where 1 != 1) as shipping where 1 != 1 group by shipping.`supplier.s_nationkey`", - "Query": "select count(*), shipping.`supplier.s_nationkey` from (select supplier.s_nationkey as `supplier.s_nationkey` from supplier where s_suppkey = :l_suppkey) as shipping group by shipping.`supplier.s_nationkey`", + "FieldQuery": "select count(*), shipping.s_nationkey from (select s_suppkey as s_suppkey, s_nationkey as s_nationkey from supplier where 1 != 1) as shipping where 1 != 1 group by shipping.s_nationkey", + "Query": "select count(*), shipping.s_nationkey from (select s_suppkey as s_suppkey, s_nationkey as s_nationkey from supplier where s_suppkey = :l_suppkey) as shipping group by shipping.s_nationkey", "Table": "supplier", "Values": [ ":l_suppkey" @@ -653,8 +651,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), supp_nation, weight_string(supp_nation) from (select n1.n_name as supp_nation from nation as n1 where 1 != 1) as shipping where 1 != 1 group by supp_nation, weight_string(supp_nation)", - "Query": "select count(*), supp_nation, weight_string(supp_nation) from (select n1.n_name as supp_nation from nation as n1 where n1.n_nationkey = :s_nationkey) as shipping group by supp_nation, weight_string(supp_nation)", + "FieldQuery": "select count(*), supp_nation, weight_string(supp_nation), n1.n_nationkey as `n1.n_nationkey`, shipping.supp_nation as `shipping.supp_nation`, shipping.`n1.n_name = 'FRANCE'` as `shipping.n1.n_name = 'FRANCE'`, shipping.`n1.n_name = 'GERMANY'` as `shipping.n1.n_name = 'GERMANY'` from (select n1.n_name as supp_nation, n1.n_name = 'FRANCE' as `n1.n_name = 'FRANCE'`, n1.n_name = 'GERMANY' as `n1.n_name = 'GERMANY'`, n1.n_nationkey as `n1.n_nationkey`, shipping.supp_nation as `shipping.supp_nation`, shipping.`n1.n_name = 'FRANCE'` as `shipping.n1.n_name = 'FRANCE'`, shipping.`n1.n_name = 'GERMANY'` as `shipping.n1.n_name = 'GERMANY'` from nation as n1 where 1 != 1) as shipping where 1 != 1 group by supp_nation, weight_string(supp_nation)", + "Query": "select count(*), supp_nation, weight_string(supp_nation), n1.n_nationkey as `n1.n_nationkey`, shipping.supp_nation as `shipping.supp_nation`, shipping.`n1.n_name = 'FRANCE'` as `shipping.n1.n_name = 'FRANCE'`, shipping.`n1.n_name = 'GERMANY'` as `shipping.n1.n_name = 'GERMANY'` from (select n1.n_name as supp_nation, n1.n_name = 'FRANCE' as `n1.n_name = 'FRANCE'`, n1.n_name = 'GERMANY' as `n1.n_name = 'GERMANY'`, n1.n_nationkey as `n1.n_nationkey`, shipping.supp_nation as `shipping.supp_nation`, shipping.`n1.n_name = 'FRANCE'` as `shipping.n1.n_name = 'FRANCE'`, shipping.`n1.n_name = 'GERMANY'` as `shipping.n1.n_name = 'GERMANY'` from nation as n1 where n1.n_nationkey = :s_nationkey) as shipping group by supp_nation, weight_string(supp_nation)", "Table": "nation", "Values": [ ":s_nationkey" @@ -693,8 +691,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), shipping.`customer.c_nationkey` from (select customer.c_nationkey as `customer.c_nationkey` from customer where 1 != 1) as shipping where 1 != 1 group by shipping.`customer.c_nationkey`", - "Query": "select count(*), shipping.`customer.c_nationkey` from (select customer.c_nationkey as `customer.c_nationkey` from customer where c_custkey = :o_custkey) as shipping group by shipping.`customer.c_nationkey`", + "FieldQuery": "select count(*), shipping.c_nationkey from (select c_custkey as c_custkey, c_nationkey as c_nationkey from customer where 1 != 1) as shipping where 1 != 1 group by shipping.c_nationkey", + "Query": "select count(*), shipping.c_nationkey from (select c_custkey as c_custkey, c_nationkey as c_nationkey from customer where c_custkey = :o_custkey) as shipping group by shipping.c_nationkey", "Table": "customer", "Values": [ ":o_custkey" @@ -708,8 +706,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), cust_nation, weight_string(cust_nation) from (select n2.n_name as cust_nation from nation as n2 where 1 != 1) as shipping where 1 != 1 group by cust_nation, weight_string(cust_nation)", - "Query": "select count(*), cust_nation, weight_string(cust_nation) from (select n2.n_name as cust_nation from nation as n2 where (:n1_n_name = 'FRANCE' and n2.n_name = 'GERMANY' or :n1_n_name = 'GERMANY' and n2.n_name = 'FRANCE') and n2.n_nationkey = :c_nationkey) as shipping group by cust_nation, weight_string(cust_nation)", + "FieldQuery": "select count(*), cust_nation, weight_string(cust_nation), n2.n_nationkey as `n2.n_nationkey`, shipping.cust_nation as `shipping.cust_nation` from (select n2.n_name as cust_nation, n2.n_name = 'GERMANY' as `n2.n_name = 'GERMANY'`, n2.n_name = 'FRANCE' as `n2.n_name = 'FRANCE'`, n2.n_nationkey as `n2.n_nationkey`, shipping.cust_nation as `shipping.cust_nation` from nation as n2 where 1 != 1) as shipping where 1 != 1 group by cust_nation, weight_string(cust_nation)", + "Query": "select count(*), cust_nation, weight_string(cust_nation), n2.n_nationkey as `n2.n_nationkey`, shipping.cust_nation as `shipping.cust_nation` from (select n2.n_name as cust_nation, n2.n_name = 'GERMANY' as `n2.n_name = 'GERMANY'`, n2.n_name = 'FRANCE' as `n2.n_name = 'FRANCE'`, n2.n_nationkey as `n2.n_nationkey`, shipping.cust_nation as `shipping.cust_nation` from nation as n2 where (:n1_n_name = 'FRANCE' and n2.n_name = 'GERMANY' or :n1_n_name = 'GERMANY' and n2.n_name = 'FRANCE') and n2.n_nationkey = :c_nationkey) as shipping group by cust_nation, weight_string(cust_nation)", "Table": "nation", "Values": [ ":c_nationkey" diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index 8b4a2d91cdb..17510c62d57 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -42,8 +42,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select id from music) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from (select id from `user` union select id from music) as dt", "Table": "`user`, music" } ] @@ -384,8 +384,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1 union select 1 from dual where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select id from music union select 1 from dual) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1 union select 1 from dual where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from (select id from `user` union select id from music union select 1 from dual) as dt", "Table": "`user`, dual, music" } ] @@ -503,8 +503,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from ((select id from `user` where 1 != 1) union (select id from `user` where 1 != 1)) as dt where 1 != 1", - "Query": "select id, weight_string(id) from ((select id from `user` order by id desc) union (select id from `user` order by id asc)) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from ((select id from `user` where 1 != 1) union (select id from `user` where 1 != 1)) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from ((select id from `user` order by id desc) union (select id from `user` order by id asc)) as dt", "Table": "`user`" } ] @@ -534,8 +534,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `1`, weight_string(`1`) from (select 1 from dual where 1 != 1 union select null from dual where 1 != 1 union select 1.0 from dual where 1 != 1 union select '1' from dual where 1 != 1 union select 2 from dual where 1 != 1 union select 2.0 from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `1`, weight_string(`1`) from (select 1 from dual union select null from dual union select 1.0 from dual union select '1' from dual union select 2 from dual union select 2.0 from `user`) as dt", + "FieldQuery": "select dt.`1`, weight_string(dt.`1`) from (select 1 from dual where 1 != 1 union select null from dual where 1 != 1 union select 1.0 from dual where 1 != 1 union select '1' from dual where 1 != 1 union select 2 from dual where 1 != 1 union select 2.0 from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`1`, weight_string(dt.`1`) from (select 1 from dual union select null from dual union select 1.0 from dual union select '1' from dual union select 2 from dual union select 2.0 from `user`) as dt", "Table": "`user`, dual" } ] @@ -767,8 +767,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select id + 1 from `user` where 1 != 1 union select user_id from user_extra where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select id + 1 from `user` union select user_id from user_extra) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select id + 1 from `user` where 1 != 1 union select user_id from user_extra where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from (select id from `user` union select id + 1 from `user` union select user_id from user_extra) as dt", "Table": "`user`, user_extra" } ] @@ -804,8 +804,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select id from music) as dt", + "FieldQuery": "select id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1) as dt where 1 != 1", + "Query": "select id, weight_string(dt.id) from (select id from `user` union select id from music) as dt", "Table": "`user`, music" }, { @@ -854,8 +854,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select 3 from dual where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select 3 from dual limit :__upper_limit) as dt", + "FieldQuery": "select id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select 3 from dual where 1 != 1) as dt where 1 != 1", + "Query": "select id, weight_string(dt.id) from (select id from `user` union select 3 from dual limit :__upper_limit) as dt", "Table": "`user`, dual" } ] @@ -913,8 +913,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select col, weight_string(col) from (select col from unsharded where 1 != 1 union select col2 from unsharded where 1 != 1) as dt where 1 != 1", - "Query": "select col, weight_string(col) from (select col from unsharded union select col2 from unsharded) as dt", + "FieldQuery": "select dt.col, weight_string(col) from (select col from unsharded where 1 != 1 union select col2 from unsharded where 1 != 1) as dt where 1 != 1", + "Query": "select dt.col, weight_string(col) from (select col from unsharded union select col2 from unsharded) as dt", "Table": "unsharded" }, { @@ -1071,8 +1071,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select 3 from dual where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select 3 from dual) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select 3 from dual where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from (select id from `user` union select 3 from dual) as dt", "Table": "`user`, dual" } ] @@ -1450,8 +1450,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select bar, baz, toto, weight_string(bar), weight_string(baz), weight_string(toto) from (select bar, baz, toto from music where 1 != 1 union select foo, foo, foo from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select bar, baz, toto, weight_string(bar), weight_string(baz), weight_string(toto) from (select bar, baz, toto from music union select foo, foo, foo from `user`) as dt", + "FieldQuery": "select dt.bar, dt.baz, dt.toto, weight_string(dt.bar), weight_string(dt.baz), weight_string(dt.toto) from (select bar, baz, toto from music where 1 != 1 union select foo, foo, foo from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.bar, dt.baz, dt.toto, weight_string(dt.bar), weight_string(dt.baz), weight_string(dt.toto) from (select bar, baz, toto from music union select foo, foo, foo from `user`) as dt", "Table": "`user`, music" } ] @@ -1484,8 +1484,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select foo, foo, foo, weight_string(foo) from (select foo, foo, foo from `user` where 1 != 1 union select bar, baz, toto from music where 1 != 1) as dt where 1 != 1", - "Query": "select foo, foo, foo, weight_string(foo) from (select foo, foo, foo from `user` union select bar, baz, toto from music) as dt", + "FieldQuery": "select dt.foo, dt.foo, dt.foo, weight_string(dt.foo) from (select foo, foo, foo from `user` where 1 != 1 union select bar, baz, toto from music where 1 != 1) as dt where 1 != 1", + "Query": "select dt.foo, dt.foo, dt.foo, weight_string(dt.foo) from (select foo, foo, foo from `user` union select bar, baz, toto from music) as dt", "Table": "`user`, music" } ] @@ -1546,8 +1546,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, weight_string(col1) from (select col1 from `user` where 1 != 1 union select 3 from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select col1, weight_string(col1) from (select col1 from `user` union select 3 from `user`) as dt", + "FieldQuery": "select dt.col1, weight_string(dt.col1) from (select col1 from `user` where 1 != 1 union select 3 from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.col1, weight_string(dt.col1) from (select col1 from `user` union select 3 from `user`) as dt", "Table": "`user`" } ] @@ -1577,8 +1577,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `3`, weight_string(`3`) from (select 3 from `user` where 1 != 1 union select col1 from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `3`, weight_string(`3`) from (select 3 from `user` union select col1 from `user`) as dt", + "FieldQuery": "select dt.`3`, weight_string(dt.`3`) from (select 3 from `user` where 1 != 1 union select col1 from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`3`, weight_string(dt.`3`) from (select 3 from `user` union select col1 from `user`) as dt", "Table": "`user`" } ] @@ -1608,8 +1608,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `3`, weight_string(`3`) from (select 3 from `user` where 1 != 1 union select now() from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `3`, weight_string(`3`) from (select 3 from `user` union select now() from `user`) as dt", + "FieldQuery": "select dt.`3`, weight_string(dt.`3`) from (select 3 from `user` where 1 != 1 union select now() from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`3`, weight_string(dt.`3`) from (select 3 from `user` union select now() from `user`) as dt", "Table": "`user`" } ] @@ -1639,8 +1639,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `now()`, weight_string(`now()`) from (select now() from `user` where 1 != 1 union select 3 from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `now()`, weight_string(`now()`) from (select now() from `user` union select 3 from `user`) as dt", + "FieldQuery": "select dt.`now()`, weight_string(dt.`now()`) from (select now() from `user` where 1 != 1 union select 3 from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`now()`, weight_string(dt.`now()`) from (select now() from `user` union select 3 from `user`) as dt", "Table": "`user`" } ] @@ -1670,8 +1670,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `now()`, weight_string(`now()`) from (select now() from `user` where 1 != 1 union select id from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `now()`, weight_string(`now()`) from (select now() from `user` union select id from `user`) as dt", + "FieldQuery": "select dt.`now()`, weight_string(dt.`now()`) from (select now() from `user` where 1 != 1 union select id from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`now()`, weight_string(dt.`now()`) from (select now() from `user` union select id from `user`) as dt", "Table": "`user`" } ] From 1ae91e382ddad916b5283fc6355d3572271caac9 Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Thu, 17 Oct 2024 11:33:06 +0200 Subject: [PATCH 5/5] [release-18.0] bugfix: add HAVING columns inside derived tables (#16976) (#16977) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Andres Taylor Co-authored-by: Andrés Taylor --- .../operators/horizon_expanding.go | 7 ++ .../planbuilder/testdata/aggr_cases.json | 75 +++++++++++++++---- 2 files changed, 67 insertions(+), 15 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index a2f0799c547..59892bc13da 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -90,6 +90,12 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel for _, order := range horizon.Query.GetOrderBy() { qp.addDerivedColumn(ctx, order.Expr) } + sel, isSel := horizon.Query.(*sqlparser.Select) + if isSel && sel.Having != nil { + for _, pred := range sqlparser.SplitAndExpression(nil, sel.Having.Expr) { + qp.addDerivedColumn(ctx, pred) + } + } } op, err := createProjectionFromSelect(ctx, horizon) @@ -307,6 +313,7 @@ outer: func createProjectionForComplexAggregation(a *Aggregator, qp *QueryProjection) (ops.Operator, error) { p := newAliasedProjection(a) p.DT = a.DT + a.DT = nil // we don't need the derived table twice for _, expr := range qp.SelectExprs { ae, err := expr.GetAliasedExpr() if err != nil { diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 499ee753971..e2fae908825 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -571,6 +571,35 @@ ] } }, + { + "comment": "using HAVING inside a derived table still produces viable plans", + "query": "select id from (select id from user group by id having (count(user.id) = 2) limit 2 offset 0) subquery_for_limit", + "plan": { + "QueryType": "SELECT", + "Original": "select id from (select id from user group by id having (count(user.id) = 2) limit 2 offset 0) subquery_for_limit", + "Instructions": { + "OperatorType": "Limit", + "Count": "INT64(2)", + "Offset": "INT64(0)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from (select id, count(`user`.id) = 2 from `user` where 1 != 1 group by id) as subquery_for_limit where 1 != 1", + "Query": "select id from (select id, count(`user`.id) = 2 from `user` group by id having count(`user`.id) = 2) as subquery_for_limit limit :__upper_limit", + "Table": "`user`" + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, { "comment": "sum with distinct no unique vindex", "query": "select col1, sum(distinct col2) from user group by col1", @@ -3613,25 +3642,41 @@ "QueryType": "SELECT", "Original": "select * from (select id from user having count(*) = 1) s", "Instructions": { - "OperatorType": "Filter", - "Predicate": "count(*) = 1", - "ResultColumns": 1, + "OperatorType": "SimpleProjection", + "Columns": [ + 0 + ], "Inputs": [ { - "OperatorType": "Aggregate", - "Variant": "Scalar", - "Aggregates": "any_value(0) AS id, sum_count_star(1) AS count(*)", + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] as id", + "[COLUMN 1] = [COLUMN 2] as count(*) = 1" + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id, count(*) from `user` where 1 != 1", - "Query": "select id, count(*) from `user`", - "Table": "`user`" + "OperatorType": "Filter", + "Predicate": "count(*) = 1", + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "any_value(0) AS id, sum_count_star(1) AS count(*), any_value(2)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, count(*), 1 from `user` where 1 != 1", + "Query": "select id, count(*), 1 from `user`", + "Table": "`user`" + } + ] + } + ] } ] }