From aaefcc6413dbb2d22a7595843488637cbaec81cb Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Fri, 9 Aug 2024 16:01:42 +0200 Subject: [PATCH] feat: use random port for OAuth2 callback (#371) --- cmd/cloudx/client/auth.go | 58 ++++++++++----------------------------- 1 file changed, 15 insertions(+), 43 deletions(-) diff --git a/cmd/cloudx/client/auth.go b/cmd/cloudx/client/auth.go index 6642889c..a6b31325 100644 --- a/cmd/cloudx/client/auth.go +++ b/cmd/cloudx/client/auth.go @@ -8,7 +8,6 @@ import ( stderrors "errors" "fmt" "io" - "math/rand/v2" "net" "net/http" "net/url" @@ -17,9 +16,7 @@ import ( "time" "github.com/gofrs/uuid" - "github.com/pkg/errors" "golang.org/x/oauth2" - "golang.org/x/sync/errgroup" cloud "github.com/ory/client-go" "github.com/ory/x/randx" @@ -178,27 +175,17 @@ func (h *CommandHelper) loginOAuth2(ctx context.Context) (*Config, error) { func (h *CommandHelper) oAuth2DanceWithServer(ctx context.Context, client *oauth2.Config) (token *oauth2.Token, err error) { var ( - l net.Listener state = randx.MustString(32, randx.AlphaNum) pkceVerifier = oauth2.GenerateVerifier() - ports = []int{12345, 15793, 17628, 19834, 23730, 27462, 34525, 36209, 42827, 46718, 49763, 51238, 52213, 57923, 59724, 60582, 62125, 65321, 49876, 54321, 59876, 60987, 62345, 63456, 64567, 65123, 65234, 65432, 65500, 65510, 65520, 65530} + serverErr = make(chan error) + serverToken = make(chan *oauth2.Token) ) - rand.Shuffle(len(ports), func(i, j int) { ports[i], ports[j] = ports[j], ports[i] }) - for _, port := range ports { - l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) - if err == nil { - client.RedirectURL = fmt.Sprintf("http://localhost:%d/callback", port) - break - } - } - if l == nil { - return nil, fmt.Errorf("failed to allocate port for OAuth2 callback handler, try again later: last error: %w", err) + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return nil, fmt.Errorf("failed to allocate port for OAuth2 callback handler, try again later: %w", err) } + client.RedirectURL = fmt.Sprintf("http://%s/callback", l.Addr().String()) - var ( - serverErr = make(chan error) - serverToken = make(chan *oauth2.Token) - ) srv := http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // for retries the user has to start from the beginning @@ -242,27 +229,8 @@ func (h *CommandHelper) oAuth2DanceWithServer(ctx context.Context, client *oauth redirectOK(w, r) }), } - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - eg, ctx := errgroup.WithContext(ctx) - eg.Go(func() (err error) { - if err := srv.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("failed to serve OAuth2 callback handler: %w", err) - } - return nil - }) - eg.Go(func() (err error) { - select { - case <-ctx.Done(): - err = ctx.Err() - case token = <-serverToken: - case err = <-serverErr: - } - ctx, cancel := context.WithDeadline(context.WithoutCancel(ctx), time.Now().Add(20*time.Second)) - defer cancel() - return stderrors.Join(err, srv.Shutdown(ctx)) - }) + go func() { _ = srv.Serve(l) }() + defer srv.Close() u := client.AuthCodeURL(state, oauth2.S256ChallengeOption(pkceVerifier), @@ -282,10 +250,14 @@ If no browser opened, visit the below page to continue: `, u) - if err := eg.Wait(); err != nil { - return nil, fmt.Errorf("failed to authenticate, please try again: %w", err) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case token := <-serverToken: + return token, nil + case err := <-serverErr: + return nil, err } - return token, nil } func redirectOK(w http.ResponseWriter, r *http.Request) {