diff --git a/api/buf.lock b/api/buf.lock index 8c58811..e908553 100644 --- a/api/buf.lock +++ b/api/buf.lock @@ -4,7 +4,4 @@ deps: - remote: buf.build owner: googleapis repository: googleapis - branch: main - commit: c348fb6ca1774f95bdabe20c7b379d50 - digest: b1-liJZ9g7gOLHZca7G3Fu-Z4-9VKgrE3ci0lXsuG8hsVY= - create_time: 2022-03-17T15:04:57.656498Z + commit: d1263fe26f8e430a967dc22a4d0cad18 diff --git a/internal/api/minekube/connect/v1alpha1/watch_service.pb.go b/internal/api/minekube/connect/v1alpha1/watch_service.pb.go index 11cbec1..c996a48 100644 --- a/internal/api/minekube/connect/v1alpha1/watch_service.pb.go +++ b/internal/api/minekube/connect/v1alpha1/watch_service.pb.go @@ -27,11 +27,12 @@ type WatchRequest struct { unknownFields protoimpl.UnknownFields // Sending this message rejects a session proposed by the WatchService. This message should be sent to inform - // the WatchService that the server will not try to make a take the proposed session. The only purpose of - // this message is to provide quicker feedback to the player that he will not be connected with an optional - // localized reason. See https://github.com/grpc/grpc/blob/master/src/proto/grpc/status/status.proto. - // If the session is not rejected the watcher should establish the connection for the proposed session. - // If neither of these actions happen the proposal times out out and the player receives a connection timeout + // the WatchService that the server declines the proposed session. The purpose of + // this message is to provide quicker feedback to the system and provide feedback to the player that the connection + // won't be created. The rejection status can optionally contain a localized reason that should + // be displayed to the player. See https://github.com/grpc/grpc/blob/master/src/proto/grpc/status/status.proto. + // If the session is not rejected the watcher/client should establish the connection for the proposed session. + // If neither of these actions happen the proposal times out and the player receives a connection timeout // error indicating that the endpoint is currently unavailable. SessionRejection *SessionRejection `protobuf:"bytes,1,opt,name=session_rejection,json=sessionRejection,proto3" json:"session_rejection,omitempty"` } diff --git a/ws/client.go b/ws/client.go index 32f4727..0216c95 100644 --- a/ws/client.go +++ b/ws/client.go @@ -3,11 +3,9 @@ package ws import ( "context" "errors" - "fmt" "io" "net/http" - "google.golang.org/grpc/metadata" "nhooyr.io/websocket" "nhooyr.io/websocket/wspb" @@ -30,6 +28,7 @@ type ClientOptions struct { type HandshakeResponse func(ctx context.Context, res *http.Response) (context.Context, error) // Tunnel implements connect.Tunneler and creates a connection over a WebSocket. +// On error a http.Response may be provided by DialErrorResponse. func (o ClientOptions) Tunnel(ctx context.Context) (connect.Tunnel, error) { ctx, ws, err := o.dial(ctx) if err != nil { @@ -51,6 +50,7 @@ func (o ClientOptions) Tunnel(ctx context.Context) (connect.Tunnel, error) { } // Watch implements connect.Watcher and watches for session proposals. +// On error a http.Response may be provided by DialErrorResponse. func (o ClientOptions) Watch(ctx context.Context, propose connect.ReceiveProposal) error { ctx, ws, err := o.dial(ctx) if err != nil { @@ -90,41 +90,5 @@ func (o ClientOptions) Watch(ctx context.Context, propose connect.ReceiveProposa return err } -func (o *ClientOptions) dial(ctx context.Context) (context.Context, *websocket.Conn, error) { - if o.URL == "" { - return nil, nil, errors.New("missing websocket url") - } - - // Add metadata to websocket handshake request header - md, _ := metadata.FromOutgoingContext(ctx) - if o.DialOptions.HTTPHeader == nil { - o.DialOptions.HTTPHeader = http.Header(md) - } else { - header := metadata.Join(metadata.MD(o.DialOptions.HTTPHeader), md) - o.DialOptions.HTTPHeader = http.Header(header) - } - - // Dial service - if o.DialContext == nil { - o.DialContext = ctx - } - ws, res, err := websocket.Dial(o.DialContext, o.URL, &o.DialOptions) - if err != nil { - return nil, nil, fmt.Errorf("error handshaking with websocket server: %w", err) - } - - // Callback for handshake response - if o.Handshake != nil { - ctx, err = o.Handshake(ctx, res) - if err != nil { - _ = ws.Close(websocket.StatusNormalClosure, fmt.Sprintf( - "handshake response rejected: %v", err)) - return nil, nil, err - } - } - - return ctx, ws, nil -} - var _ connect.Tunneler = (*ClientOptions)(nil) var _ connect.Watcher = (*ClientOptions)(nil) diff --git a/ws/dial.go b/ws/dial.go new file mode 100644 index 0000000..4234cad --- /dev/null +++ b/ws/dial.go @@ -0,0 +1,77 @@ +package ws + +import ( + "context" + "errors" + "fmt" + "net/http" + + "google.golang.org/grpc/metadata" + "nhooyr.io/websocket" +) + +// DialErrorResponse returns the HTTP response from the WebSocket handshake error, if any. +func DialErrorResponse(err error) (*http.Response, bool) { + var e *dialErr + if errors.As(err, &e) { + return e.res, true + } + return nil, false +} + +func (o *ClientOptions) dial(ctx context.Context) (context.Context, *websocket.Conn, error) { + if o.URL == "" { + return nil, nil, errors.New("missing websocket url") + } + + header := metadata.Join( + metadata.MD(o.DialOptions.HTTPHeader), + mdFromContext(o.DialContext), + mdFromContext(ctx), + ) + if o.DialContext != nil { + ctx = o.DialContext + } + + // Dial service + ws, res, err := websocket.Dial(ctx, o.URL, &websocket.DialOptions{ + HTTPClient: o.DialOptions.HTTPClient, + HTTPHeader: http.Header(header), + Subprotocols: o.DialOptions.Subprotocols, + CompressionMode: o.DialOptions.CompressionMode, + CompressionThreshold: o.DialOptions.CompressionThreshold, + }) + if err != nil { + if res != nil { + err = &dialErr{error: err, res: res} + } + return nil, nil, fmt.Errorf("error handshaking with websocket server: %w", err) + } + + // Callback for handshake response + if o.Handshake != nil { + ctx, err = o.Handshake(ctx, res) + if err != nil { + _ = ws.Close(websocket.StatusNormalClosure, fmt.Sprintf( + "handshake response rejected: %v", err)) + return nil, nil, err + } + } + + return ctx, ws, nil +} + +type dialErr struct { + error + res *http.Response +} + +func (e *dialErr) Error() string { + return fmt.Sprintf("%s (%d): %v", e.res.Status, e.res.StatusCode, e.error) +} +func (e *dialErr) Unwrap() error { return e.error } + +func mdFromContext(ctx context.Context) metadata.MD { + md, _ := metadata.FromOutgoingContext(ctx) + return md +}