Skip to content
This repository has been archived by the owner on Sep 15, 2022. It is now read-only.

Commit

Permalink
Fixed passing headers through gRPC proxies.
Browse files Browse the repository at this point in the history
Signed-off-by: Bartek Plotka <bwplotka@gmail.com>
  • Loading branch information
bwplotka committed Dec 22, 2017
1 parent ef13c8a commit 132fb9b
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 37 deletions.
5 changes: 2 additions & 3 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Gopkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

[[constraint]]
name = "github.com/mwitkow/grpc-proxy"
revision = "67591eb23c48346a480470e462289835d96f70da"

[[constraint]]
name = "github.com/oklog/oklog"
Expand Down
16 changes: 16 additions & 0 deletions pkg/grpc/metadata/metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package grpc_metadata

import (
"context"

"google.golang.org/grpc/metadata"
)

func CloneIncomingToOutgoing(ctx context.Context) context.Context {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
md = metadata.Pairs()
}
// Copy the inbound metadata explicitly.
return metadata.NewOutgoingContext(ctx, md.Copy())
}
13 changes: 7 additions & 6 deletions pkg/kedge/grpc/director/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
"github.com/improbable-eng/kedge/pkg/grpc/metadata"
"github.com/improbable-eng/kedge/pkg/kedge/grpc/backendpool"
"github.com/improbable-eng/kedge/pkg/kedge/grpc/director/router"
"github.com/mwitkow/grpc-proxy/proxy"
Expand All @@ -17,17 +18,17 @@ import (

// New builds a StreamDirector based off a backend pool and a router.
func New(pool backendpool.Pool, router router.Router) proxy.StreamDirector {
return func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error) {
return func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) {
beName, err := router.Route(ctx, fullMethodName)
if err != nil {
return nil, err
return ctx, nil, err
}

ctx = grpc_metadata.CloneIncomingToOutgoing(ctx)

grpc_ctxtags.Extract(ctx).Set("grpc.proxy.backend", beName)
cc, err := pool.Conn(beName)
if err != nil {
return nil, err
}
return cc, nil
return ctx, cc, err
}
}

Expand Down
60 changes: 40 additions & 20 deletions pkg/kedge/grpc/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/fortytw2/leaktest"
"github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/go-grpc-middleware/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
"github.com/improbable-eng/go-srvlb/srv"
"github.com/improbable-eng/kedge/pkg/kedge/grpc/backendpool"
"github.com/improbable-eng/kedge/pkg/kedge/grpc/client"
Expand Down Expand Up @@ -93,9 +94,10 @@ var routeConfigs = []*pb_route.Route{
}

type unknownResponse struct {
Addr string `protobuf:"bytes,1,opt,name=addr,json=value"`
Method string `protobuf:"bytes,2,opt,name=method"`
Backend string `protobuf:"bytes,3,opt,name=backend"`
Addr string `protobuf:"bytes,1,opt,name=addr,json=value"`
Method string `protobuf:"bytes,2,opt,name=method"`
Backend string `protobuf:"bytes,3,opt,name=backend"`
AuthorizationToken string `protobuf:"bytes,4,opt,name=auth"`
}

func (m *unknownResponse) Reset() { *m = unknownResponse{} }
Expand All @@ -108,7 +110,8 @@ func unknownPingbackHandler(backendName string, serverAddr string) grpc.StreamHa
if !ok {
return fmt.Errorf("handler should have access to transport info")
}
return stream.SendMsg(&unknownResponse{Method: tr.Method(), Addr: serverAddr, Backend: backendName})
md := metautils.ExtractIncoming(stream.Context())
return stream.SendMsg(&unknownResponse{Method: tr.Method(), Addr: serverAddr, Backend: backendName, AuthorizationToken: md.Get("authorization")})
}
}

Expand Down Expand Up @@ -213,6 +216,21 @@ func (s *BackendPoolIntegrationTestSuite) Lookup(domainName string) ([]*srv.Targ

const testToken = "test-token"

type tokenCreds struct {
token string
header string
}

func (c *tokenCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{
c.header: c.token,
}, nil
}

func (c *tokenCreds) RequireTransportSecurity() bool {
return false
}

func (s *BackendPoolIntegrationTestSuite) SetupSuite() {
var err error
s.proxyListener, err = net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -245,7 +263,7 @@ func (s *BackendPoolIntegrationTestSuite) SetupSuite() {
s.kedgeMapper = kedge_map.Single(proxyUrl)
s.proxyConn, err = grpc.Dial(fmt.Sprintf("localhost:%s", proxyPort),
grpc.WithTransportCredentials(credentials.NewTLS(s.tlsConfigForTest())),
grpc.WithPerRPCCredentials(&tokenCreds{token: testToken}),
grpc.WithPerRPCCredentials(&tokenCreds{token: "bearer " + testToken, header: "proxy-authorization"}),
grpc.WithBlock(),
)
require.NoError(s.T(), err, "dialing the proxy on a conn *must not* fail")
Expand All @@ -272,20 +290,6 @@ func (s *BackendPoolIntegrationTestSuite) SimpleCtx() context.Context {
return ctx
}

type tokenCreds struct {
token string
}

func (c *tokenCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{
"proxy-authorization": fmt.Sprintf("bearer %s", c.token),
}, nil
}

func (c *tokenCreds) RequireTransportSecurity() bool {
return false
}

func (s *BackendPoolIntegrationTestSuite) TestCallToNonSecureBackend() {
resp := &unknownResponse{}
err := grpc.Invoke(s.SimpleCtx(), "/hand_rolled.non_secure.SomeService/Method", &unknownResponse{}, resp, s.proxyConn)
Expand All @@ -308,12 +312,28 @@ func (s *BackendPoolIntegrationTestSuite) TestClientDialSecureToNonSecureBackend
context.TODO(),
"secure.ext.test.local",
s.tlsConfigForTest(),
s.kedgeMapper, grpc.WithPerRPCCredentials(&tokenCreds{token: testToken}),
s.kedgeMapper, grpc.WithPerRPCCredentials(&tokenCreds{token: "bearer " + testToken, header: "proxy-authorization"}),
)
require.NoError(s.T(), err, "dialing through kedge must succeed")
defer cc.Close()
resp := s.invokeUnknownHandlerPingbackAndAssert("/hand_rolled.common.NonSpecificService/Method", cc)
assert.Equal(s.T(), "secure_localbackends", resp.Backend)
}

func (s *BackendPoolIntegrationTestSuite) TestClientDialSecureToNonSecureBackend_BackendAuth() {
cc, err := kedge_grpc.DialThroughKedge(
context.TODO(),
"secure.ext.test.local",
s.tlsConfigForTest(),
s.kedgeMapper,
grpc.WithPerRPCCredentials(&tokenCreds{token: "bearer " + testToken, header: "proxy-authorization"}),
grpc.WithPerRPCCredentials(&tokenCreds{token: "bearer test-backend-token", header: "authorization"}),
)
require.NoError(s.T(), err, "dialing through kedge must succeed")
defer cc.Close()
resp := s.invokeUnknownHandlerPingbackAndAssert("/hand_rolled.common.NonSpecificService/Method", cc)
assert.Equal(s.T(), "secure_localbackends", resp.Backend)
assert.Equal(s.T(), "bearer test-backend-token", resp.AuthorizationToken)
}

func (s *BackendPoolIntegrationTestSuite) invokeUnknownHandlerPingbackAndAssert(fullMethod string, conn *grpc.ClientConn) *unknownResponse {
Expand Down
57 changes: 54 additions & 3 deletions pkg/winch/grpc/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ func (s *WinchIntegrationSuite) SetupSuite() {
},
},
},
{
Type: &pb.Route_Direct{
Direct: &pb.DirectRoute{
Key: "resource1-authtest.ext.example.com",
Url: "https://" + moveToLocalhost(s.localSecureKedges.listeners[0].Addr().String()),
},
},
},
{
ProxyAuth: "proxy-access1",
Type: &pb.Route_Direct{
Expand Down Expand Up @@ -221,14 +229,18 @@ func (s *WinchIntegrationSuite) SimpleCtx() context.Context {
}

// dialThroughWinch creates plain, insecure, local connection to winch.
func (s *WinchIntegrationSuite) dialThroughWinch(targetAuthority string) (*grpc.ClientConn, error) {
func (s *WinchIntegrationSuite) dialThroughWinch(targetAuthority string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
proxyPort := s.winchListenerPlain.Addr().String()[strings.LastIndex(s.winchListenerPlain.Addr().String(), ":")+1:]

return grpc.Dial(
fmt.Sprintf("localhost:%s", proxyPort),
defaultOpts := []grpc.DialOption{
grpc.WithBlock(),
grpc.WithInsecure(),
grpc.WithAuthority(targetAuthority),
}
opts = append(defaultOpts, opts...)
return grpc.Dial(
fmt.Sprintf("localhost:%s", proxyPort),
opts...,
)
}

Expand Down Expand Up @@ -263,6 +275,45 @@ func (s *WinchIntegrationSuite) TestCallKedgeThroughWinch_DirectRoute_ValidAuth(
assert.Equal(s.T(), "bearer test-token", resp.BackendAuth)
}

type tokenCreds struct {
token string
header string
}

func (c *tokenCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{
c.header: c.token,
}, nil
}

func (c *tokenCreds) RequireTransportSecurity() bool {
return false
}

func (s *WinchIntegrationSuite) TestCallKedgeThroughWinch_DirectRoute_WinchPassesValidAuth() {
cc, err := s.dialThroughWinch("resource1-authtest.ext.example.com")
require.NoError(s.T(), err, "dialing the winch *must not* fail")
defer func() {
cc.Close()
time.Sleep(10 * time.Millisecond)
}()

resp := &unknownResponse{}
err = grpc.Invoke(
s.SimpleCtx(),
"/test.SomeService/Method",
&unknownResponse{},
resp,
cc,
grpc.PerRPCCredentials(&tokenCreds{token: "bearer test-token", header: "authorization"}),
)
require.NoError(s.T(), err, "no error on simple call")
assert.Equal(s.T(), "/test.SomeService/Method", resp.Method)
assert.Equal(s.T(), "0", resp.Backend)
assert.Equal(s.T(), "", resp.ProxyAuth)
assert.Equal(s.T(), "bearer test-token", resp.BackendAuth)
}

func (s *WinchIntegrationSuite) TestCallKedgeThroughWinch_DirectRoute2_ProxyAuth() {
cc, err := s.dialThroughWinch("resource2.ext.example.com")
require.NoError(s.T(), err, "dialing the winch *must not* fail")
Expand Down
14 changes: 9 additions & 5 deletions pkg/winch/grpc/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/google/uuid"
"github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
"github.com/improbable-eng/kedge/pkg/grpc/metadata"
"github.com/improbable-eng/kedge/pkg/http/header"
"github.com/improbable-eng/kedge/pkg/map"
"github.com/improbable-eng/kedge/pkg/tokenauth"
Expand All @@ -25,19 +26,19 @@ import (

// New builds a StreamDirector based off a backend pool and a router.
func New(mapper kedge_map.Mapper, config *tls.Config, debugMode bool) proxy.StreamDirector {
return func(ctx context.Context, fullMethodName string) (*grpc.ClientConn, error) {
return func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) {
md := metautils.ExtractIncoming(ctx)
targetAuthority := md.Get(":authority")
if targetAuthority == "" {
return nil, errors.New("No :authority header. Cannot find the host")
return ctx, nil, errors.New("No :authority header. Cannot find the host")
}

route, err := mapper.Map(targetAuthority, "")
if err != nil {
if err == kedge_map.ErrNotKedgeDestination {
return nil, status.Error(codes.Unimplemented, err.Error())
return ctx, nil, status.Error(codes.Unimplemented, err.Error())
}
return nil, err
return ctx, nil, err
}

tags := grpc_ctxtags.Extract(ctx)
Expand All @@ -50,6 +51,8 @@ func New(mapper kedge_map.Mapper, config *tls.Config, debugMode bool) proxy.Stre
tags.Set(header.RequestKedgeForceInfoLogs, os.ExpandEnv("winch-$USER"))
}

ctx = grpc_metadata.CloneIncomingToOutgoing(ctx)

transportCreds := credentials.NewTLS(config)
// Make sure authority is ok.
transportCreds = &proxiedTlsCredentials{TransportCredentials: transportCreds, authorityNameOverride: targetAuthority}
Expand All @@ -68,11 +71,12 @@ func New(mapper kedge_map.Mapper, config *tls.Config, debugMode bool) proxy.Stre
dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(newTokenCredentials(route.BackendAuth, "authorization")))
}

return grpc.DialContext(
conn, err := grpc.DialContext(
ctx,
net.JoinHostPort(route.URL.Hostname(), route.URL.Port()),
dialOpts...,
)
return ctx, conn, err
}
}

Expand Down

0 comments on commit 132fb9b

Please sign in to comment.