Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert GRPC context changes #15780

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/local/vstream_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func main() {
Filter: "select * from customer",
}},
}
conn, err := vtgateconn.Dial("localhost:15991")
conn, err := vtgateconn.Dial(ctx, "localhost:15991")
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/cluster/cluster_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ func (cluster *LocalProcessCluster) ExecOnVTGate(ctx context.Context, addr strin
return nil, err
}

conn, err := vtgateconn.Dial(addr)
conn, err := vtgateconn.Dial(ctx, addr)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/cluster/cluster_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,13 +482,13 @@ func WaitForHealthyShard(vtctldclient *VtctldClientProcess, keyspace, shard stri
}

// DialVTGate returns a VTGate grpc connection.
func DialVTGate(name, addr, username, password string) (*vtgateconn.VTGateConn, error) {
func DialVTGate(ctx context.Context, name, addr, username, password string) (*vtgateconn.VTGateConn, error) {
clientCreds := &grpcclient.StaticAuthClientCreds{Username: username, Password: password}
creds := grpc.WithPerRPCCredentials(clientCreds)
dialerFunc := grpcvtgateconn.Dial(creds)
dialerName := name
vtgateconn.RegisterDialer(dialerName, dialerFunc)
return vtgateconn.DialProtocol(dialerName, addr)
return vtgateconn.DialProtocol(ctx, dialerName, addr)
}

// PrintFiles prints the files that are asked for. If no file is specified, all the files are printed.
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/messaging/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ func VtgateGrpcConn(ctx context.Context, cluster *cluster.LocalProcessCluster) (
stream := new(VTGateStream)
stream.ctx = ctx
stream.host = fmt.Sprintf("%s:%d", cluster.Hostname, cluster.VtgateProcess.GrpcPort)
conn, err := vtgateconn.Dial(stream.host)
conn, err := vtgateconn.Dial(ctx, stream.host)
// init components
stream.respChan = make(chan *sqltypes.Result)
stream.VTGateConn = conn
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/recovery/unshardedrecovery/recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ func TestRecoveryImpl(t *testing.T) {

// Build vtgate grpc connection
grpcAddress := fmt.Sprintf("%s:%d", localCluster.Hostname, localCluster.VtgateGrpcPort)
vtgateConn, err := vtgateconn.Dial(grpcAddress)
vtgateConn, err := vtgateconn.Dial(context.Background(), grpcAddress)
assert.NoError(t, err)
defer vtgateConn.Close()
session := vtgateConn.Session("@replica", nil)
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/tabletgateway/vtgate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func TestStreamingRPCStuck(t *testing.T) {
}

// Connect to vtgate and run a streaming query.
vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "test_user", "")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "test_user", "")
require.NoError(t, err)
stream, err := vtgateConn.Session("", &querypb.ExecuteOptions{}).StreamExecute(ctx, "select * from customer", map[string]*querypb.BindVariable{})
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vreplication/vreplication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ func testVStreamCellFlag(t *testing.T) {

for _, tc := range vstreamTestCases {
t.Run("VStreamCellsFlag/"+tc.cells, func(t *testing.T) {
conn, err := vtgateconn.Dial(fmt.Sprintf("localhost:%d", vc.ClusterConfig.vtgateGrpcPort))
conn, err := vtgateconn.Dial(ctx, fmt.Sprintf("localhost:%d", vc.ClusterConfig.vtgateGrpcPort))
require.NoError(t, err)
defer conn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vreplication/vschema_load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestVSchemaChangesUnderLoad(t *testing.T) {
Filter: "select * from customer",
}},
}
conn, err := vtgateconn.Dial(net.JoinHostPort("localhost", strconv.Itoa(vc.ClusterConfig.vtgateGrpcPort)))
conn, err := vtgateconn.Dial(ctx, net.JoinHostPort("localhost", strconv.Itoa(vc.ClusterConfig.vtgateGrpcPort)))
require.NoError(t, err)
defer conn.Close()

Expand Down
8 changes: 4 additions & 4 deletions go/test/endtoend/vreplication/vstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func testVStreamWithFailover(t *testing.T, failover bool) {
testVStreamFrom(t, vtgate, "product", 2)
})
ctx := context.Background()
vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -259,7 +259,7 @@ func testVStreamStopOnReshardFlag(t *testing.T, stopOnReshard bool, baseTabletID
vc.AddKeyspace(t, []*Cell{defaultCell}, "sharded", "-80,80-", vschemaSharded, schemaSharded, defaultReplicas, defaultRdonly, baseTabletID+200, nil)

ctx := context.Background()
vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -398,7 +398,7 @@ func testVStreamCopyMultiKeyspaceReshard(t *testing.T, baseTabletID int) numEven
require.NoError(t, err)

ctx := context.Background()
vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -550,7 +550,7 @@ func TestMultiVStreamsKeyspaceReshard(t *testing.T) {
defer vtgateConn.Close()
verifyClusterHealth(t, vc)

vstreamConn, err := vtgateconn.Dial(fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
vstreamConn, err := vtgateconn.Dial(ctx, fmt.Sprintf("%s:%d", vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateGrpcPort))
require.NoError(t, err)
defer vstreamConn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtcombo/recreate/recreate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestMain(m *testing.M) {

func TestDropAndRecreateWithSameShards(t *testing.T) {
ctx := context.Background()
conn, err := vtgateconn.Dial(grpcAddress)
conn, err := vtgateconn.Dial(ctx, grpcAddress)
require.Nil(t, err)
defer conn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtcombo/vttest_sample_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestStandalone(t *testing.T) {
require.Contains(t, tmp[0], "vtcombo")

ctx := context.Background()
conn, err := vtgateconn.Dial(grpcAddress)
conn, err := vtgateconn.Dial(ctx, grpcAddress)
require.NoError(t, err)
defer conn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtgate/foreignkey/fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func TestUpdateWithFK(t *testing.T) {

// TestVstreamForFKBinLog tests that dml queries with fks are written with child row first approach in the binary logs.
func TestVstreamForFKBinLog(t *testing.T) {
vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "fk_user", "")
vtgateConn, err := cluster.DialVTGate(context.Background(), t.Name(), vtgateGrpcAddress, "fk_user", "")
require.NoError(t, err)
defer vtgateConn.Close()

Expand Down
10 changes: 5 additions & 5 deletions go/test/endtoend/vtgate/grpc_api/acl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestEffectiveCallerIDWithAccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "some_other_user", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "some_other_user", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand All @@ -48,7 +48,7 @@ func TestEffectiveCallerIDWithNoAccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "another_unrelated_user", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "another_unrelated_user", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand All @@ -66,7 +66,7 @@ func TestAuthenticatedUserWithAccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_with_access", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_with_access", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand All @@ -81,7 +81,7 @@ func TestAuthenticatedUserNoAccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_no_access", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_no_access", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand All @@ -98,7 +98,7 @@ func TestUnauthenticatedUser(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "", "")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "", "")
require.NoError(t, err)
defer vtgateConn.Close()

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtgate/grpc_api/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestTransactionsWithGRPCAPI(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

vtgateConn, err := cluster.DialVTGate(t.Name(), vtgateGrpcAddress, "user_with_access", "test_password")
vtgateConn, err := cluster.DialVTGate(ctx, t.Name(), vtgateGrpcAddress, "user_with_access", "test_password")
require.NoError(t, err)
defer vtgateConn.Close()

Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vtgate/queries/reference/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func TestMain(m *testing.M) {
go func() {
ctx := context.Background()
vtgateAddr := fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateProcess.GrpcPort)
vtgateConn, err := vtgateconn.Dial(vtgateAddr)
vtgateConn, err := vtgateconn.Dial(ctx, vtgateAddr)
if err != nil {
done <- false
return
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestMain(m *testing.M) {

ctx := context.Background()
vtgateAddr := fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateProcess.GrpcPort)
vtgateConn, err := vtgateconn.Dial(vtgateAddr)
vtgateConn, err := vtgateconn.Dial(ctx, vtgateAddr)
if err != nil {
return 1
}
Expand Down
13 changes: 12 additions & 1 deletion go/vt/grpcclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
package grpcclient

import (
"context"
"crypto/tls"
"sync"
"time"
Expand Down Expand Up @@ -96,6 +97,16 @@ func RegisterGRPCDialOptions(grpcDialOptionsFunc func(opts []grpc.DialOption) ([
// failFast is a non-optional parameter because callers are required to specify
// what that should be.
func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return DialContext(context.Background(), target, failFast, opts...)
}

// DialContext creates a grpc connection to the given target. Setup steps are
// covered by the context deadline, and, if WithBlock is specified in the dial
// options, connection establishment steps are covered by the context as well.
//
// failFast is a non-optional parameter because callers are required to specify
// what that should be.
func DialContext(ctx context.Context, target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
msgSize := grpccommon.MaxMessageSize()
newopts := []grpc.DialOption{
grpc.WithDefaultCallOptions(
Expand Down Expand Up @@ -138,7 +149,7 @@ func Dial(target string, failFast FailFast, opts ...grpc.DialOption) (*grpc.Clie

newopts = append(newopts, interceptors()...)

return grpc.Dial(target, newopts...)
return grpc.DialContext(ctx, target, newopts...)
}

func interceptors() []grpc.DialOption {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/grpcoptionaltls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestOptionalTLS(t *testing.T) {
testFunc := func(t *testing.T, dialOpt grpc.DialOption) {
ctx, cancel := context.WithTimeout(testCtx, 5*time.Second)
defer cancel()
conn, err := grpc.NewClient(addr, dialOpt)
conn, err := grpc.DialContext(ctx, addr, dialOpt)
if err != nil {
t.Fatalf("failed to connect to the server %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vitessdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ func (d drv) newConnector(cfg Configuration) (driver.Connector, error) {
}

// Connect implements the database/sql/driver.Connector interface.
func (c *connector) Connect(_ context.Context) (driver.Conn, error) {
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
conn := &conn{
cfg: c.cfg,
convert: c.convert,
}

if err := conn.dial(); err != nil {
if err := conn.dial(ctx); err != nil {
return nil, err
}

Expand Down Expand Up @@ -267,9 +267,9 @@ type conn struct {
session *vtgateconn.VTGateSession
}

func (c *conn) dial() error {
func (c *conn) dial(ctx context.Context) error {
var err error
c.conn, err = vtgateconn.DialProtocol(c.cfg.Protocol, c.cfg.Address)
c.conn, err = vtgateconn.DialProtocol(ctx, c.cfg.Protocol, c.cfg.Address)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtadmin/grpcserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestServer(t *testing.T) {
}
close(readyCh)

conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
conn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
assert.NoError(t, err)

defer conn.Close()
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/endtoend/vstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import (
)

func initialize(ctx context.Context, t *testing.T) (*vtgateconn.VTGateConn, *mysql.Conn, *mysql.Conn, func()) {
gconn, err := vtgateconn.Dial(grpcAddress)
gconn, err := vtgateconn.Dial(ctx, grpcAddress)
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/fakerpcvtgateconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func RegisterFakeVTGateConnDialer() (*FakeVTGateConn, string) {
impl := &FakeVTGateConn{
execMap: make(map[string]*queryResponse),
}
vtgateconn.RegisterDialer(protocol, func(address string) (vtgateconn.Impl, error) {
vtgateconn.RegisterDialer(protocol, func(ctx context.Context, address string) (vtgateconn.Impl, error) {
return impl, nil
})
return impl, protocol
Expand Down
16 changes: 12 additions & 4 deletions go/vt/vtgate/grpcvtgateconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,21 @@ type vtgateConn struct {
c vtgateservicepb.VitessClient
}

func dial(addr string) (vtgateconn.Impl, error) {
return Dial()(addr)
func dial(ctx context.Context, addr string) (vtgateconn.Impl, error) {
return Dial()(ctx, addr)
}

// Dial produces a vtgateconn.DialerFunc with custom options.
func Dial(opts ...grpc.DialOption) vtgateconn.DialerFunc {
return func(address string) (vtgateconn.Impl, error) {
return func(ctx context.Context, address string) (vtgateconn.Impl, error) {
opt, err := grpcclient.SecureDialOption(cert, key, ca, crl, name)
if err != nil {
return nil, err
}

opts = append(opts, opt)

cc, err := grpcclient.Dial(address, grpcclient.FailFast(false), opts...)
cc, err := grpcclient.DialContext(ctx, address, grpcclient.FailFast(false), opts...)
if err != nil {
return nil, err
}
Expand All @@ -99,6 +99,14 @@ func Dial(opts ...grpc.DialOption) vtgateconn.DialerFunc {
}
}

// DialWithOpts allows for custom dial options to be set on a vtgateConn.
//
// Deprecated: the context parameter cannot be used by the returned
// vtgateconn.DialerFunc and thus has no effect. Use Dial instead.
func DialWithOpts(_ context.Context, opts ...grpc.DialOption) vtgateconn.DialerFunc {
return Dial(opts...)
}

func (conn *vtgateConn) Execute(ctx context.Context, session *vtgatepb.Session, query string, bindVars map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
request := &vtgatepb.ExecuteRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Expand Down
Loading
Loading