Skip to content

Commit

Permalink
Add DialErrorResponse func
Browse files Browse the repository at this point in the history
  • Loading branch information
robinbraemer committed Oct 22, 2022
1 parent d9a5d99 commit c7715bd
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 47 deletions.
5 changes: 1 addition & 4 deletions api/buf.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,4 @@ deps:
- remote: buf.build
owner: googleapis
repository: googleapis
branch: main
commit: c348fb6ca1774f95bdabe20c7b379d50
digest: b1-liJZ9g7gOLHZca7G3Fu-Z4-9VKgrE3ci0lXsuG8hsVY=
create_time: 2022-03-17T15:04:57.656498Z
commit: d1263fe26f8e430a967dc22a4d0cad18
11 changes: 6 additions & 5 deletions internal/api/minekube/connect/v1alpha1/watch_service.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 2 additions & 38 deletions ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ package ws
import (
"context"
"errors"
"fmt"
"io"
"net/http"

"google.golang.org/grpc/metadata"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wspb"

Expand All @@ -30,6 +28,7 @@ type ClientOptions struct {
type HandshakeResponse func(ctx context.Context, res *http.Response) (context.Context, error)

// Tunnel implements connect.Tunneler and creates a connection over a WebSocket.
// On error a http.Response may be provided by DialErrorResponse.
func (o ClientOptions) Tunnel(ctx context.Context) (connect.Tunnel, error) {
ctx, ws, err := o.dial(ctx)
if err != nil {
Expand All @@ -51,6 +50,7 @@ func (o ClientOptions) Tunnel(ctx context.Context) (connect.Tunnel, error) {
}

// Watch implements connect.Watcher and watches for session proposals.
// On error a http.Response may be provided by DialErrorResponse.
func (o ClientOptions) Watch(ctx context.Context, propose connect.ReceiveProposal) error {
ctx, ws, err := o.dial(ctx)
if err != nil {
Expand Down Expand Up @@ -90,41 +90,5 @@ func (o ClientOptions) Watch(ctx context.Context, propose connect.ReceiveProposa
return err
}

func (o *ClientOptions) dial(ctx context.Context) (context.Context, *websocket.Conn, error) {
if o.URL == "" {
return nil, nil, errors.New("missing websocket url")
}

// Add metadata to websocket handshake request header
md, _ := metadata.FromOutgoingContext(ctx)
if o.DialOptions.HTTPHeader == nil {
o.DialOptions.HTTPHeader = http.Header(md)
} else {
header := metadata.Join(metadata.MD(o.DialOptions.HTTPHeader), md)
o.DialOptions.HTTPHeader = http.Header(header)
}

// Dial service
if o.DialContext == nil {
o.DialContext = ctx
}
ws, res, err := websocket.Dial(o.DialContext, o.URL, &o.DialOptions)
if err != nil {
return nil, nil, fmt.Errorf("error handshaking with websocket server: %w", err)
}

// Callback for handshake response
if o.Handshake != nil {
ctx, err = o.Handshake(ctx, res)
if err != nil {
_ = ws.Close(websocket.StatusNormalClosure, fmt.Sprintf(
"handshake response rejected: %v", err))
return nil, nil, err
}
}

return ctx, ws, nil
}

var _ connect.Tunneler = (*ClientOptions)(nil)
var _ connect.Watcher = (*ClientOptions)(nil)
77 changes: 77 additions & 0 deletions ws/dial.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package ws

import (
"context"
"errors"
"fmt"
"net/http"

"google.golang.org/grpc/metadata"
"nhooyr.io/websocket"
)

// DialErrorResponse returns the HTTP response from the WebSocket handshake error, if any.
func DialErrorResponse(err error) (*http.Response, bool) {
var e *dialErr
if errors.As(err, &e) {
return e.res, true
}
return nil, false
}

func (o *ClientOptions) dial(ctx context.Context) (context.Context, *websocket.Conn, error) {
if o.URL == "" {
return nil, nil, errors.New("missing websocket url")
}

header := metadata.Join(
metadata.MD(o.DialOptions.HTTPHeader),
mdFromContext(o.DialContext),
mdFromContext(ctx),
)
if o.DialContext != nil {
ctx = o.DialContext
}

// Dial service
ws, res, err := websocket.Dial(ctx, o.URL, &websocket.DialOptions{
HTTPClient: o.DialOptions.HTTPClient,
HTTPHeader: http.Header(header),
Subprotocols: o.DialOptions.Subprotocols,
CompressionMode: o.DialOptions.CompressionMode,
CompressionThreshold: o.DialOptions.CompressionThreshold,
})
if err != nil {
if res != nil {
err = &dialErr{error: err, res: res}
}
return nil, nil, fmt.Errorf("error handshaking with websocket server: %w", err)
}

// Callback for handshake response
if o.Handshake != nil {
ctx, err = o.Handshake(ctx, res)
if err != nil {
_ = ws.Close(websocket.StatusNormalClosure, fmt.Sprintf(
"handshake response rejected: %v", err))
return nil, nil, err
}
}

return ctx, ws, nil
}

type dialErr struct {
error
res *http.Response
}

func (e *dialErr) Error() string {
return fmt.Sprintf("%s (%d): %v", e.res.Status, e.res.StatusCode, e.error)
}
func (e *dialErr) Unwrap() error { return e.error }

func mdFromContext(ctx context.Context) metadata.MD {
md, _ := metadata.FromOutgoingContext(ctx)
return md
}

0 comments on commit c7715bd

Please sign in to comment.