Skip to content

Commit

Permalink
GODRIVER-3215 Fix default auth source for auth specified via ClientOp…
Browse files Browse the repository at this point in the history
…tions [master] (#1798)

Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com>
  • Loading branch information
blink1073 and matthewdale authored Sep 10, 2024
1 parent b0caeba commit 9e7ccb0
Show file tree
Hide file tree
Showing 15 changed files with 213 additions and 189 deletions.
48 changes: 7 additions & 41 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,34 +215,13 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) {
}

if args.Auth != nil {
var oidcMachineCallback auth.OIDCCallback
if args.Auth.OIDCMachineCallback != nil {
oidcMachineCallback = func(ctx context.Context, oargs *driver.OIDCArgs) (*driver.OIDCCredential, error) {
cred, err := args.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(oargs))
return (*driver.OIDCCredential)(cred), err
}
}

var oidcHumanCallback auth.OIDCCallback
if args.Auth.OIDCHumanCallback != nil {
oidcHumanCallback = func(ctx context.Context, oargs *driver.OIDCArgs) (*driver.OIDCCredential, error) {
cred, err := args.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(oargs))
return (*driver.OIDCCredential)(cred), err
}
}

// Create an authenticator for the client
client.authenticator, err = auth.CreateAuthenticator(args.Auth.AuthMechanism, &auth.Cred{
Source: args.Auth.AuthSource,
Username: args.Auth.Username,
Password: args.Auth.Password,
PasswordSet: args.Auth.PasswordSet,
Props: args.Auth.AuthMechanismProperties,
OIDCMachineCallback: oidcMachineCallback,
OIDCHumanCallback: oidcHumanCallback,
}, args.HTTPClient)
client.authenticator, err = auth.CreateAuthenticator(
args.Auth.AuthMechanism,
topology.ConvertCreds(args.Auth),
args.HTTPClient,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating authenticator: %w", err)
}
}

Expand Down Expand Up @@ -274,20 +253,7 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) {
return client, nil
}

// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent
// public type *options.OIDCArgs.
func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs {
if args == nil {
return nil
}
return &options.OIDCArgs{
Version: args.Version,
IDPInfo: (*options.IDPInfo)(args.IDPInfo),
RefreshToken: args.RefreshToken,
}
}

// connect initializes the Client by starting background monitoring goroutines.
// Connect initializes the Client by starting background monitoring goroutines.
// If the Client was created using the NewClient function, this method must be called before a Client can be used.
//
// Connect starts background goroutines to monitor the state of the deployment and does not do any I/O in the main
Expand Down
76 changes: 0 additions & 76 deletions mongo/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"errors"
"math"
"os"
"reflect"
"testing"
"time"

Expand All @@ -20,13 +19,11 @@ import (
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/integtest"
"go.mongodb.org/mongo-driver/v2/internal/mongoutil"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/tag"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology"
Expand Down Expand Up @@ -519,76 +516,3 @@ func TestClient(t *testing.T) {
assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error())
})
}

// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs
// into an options.OIDCArgs.
func TestConvertOIDCArgs(t *testing.T) {
refreshToken := "test refresh token"

testCases := []struct {
desc string
args *driver.OIDCArgs
}{
{
desc: "populated args",
args: &driver.OIDCArgs{
Version: 9,
IDPInfo: &driver.IDPInfo{
Issuer: "test issuer",
ClientID: "test client ID",
RequestScopes: []string{"test scope 1", "test scope 2"},
},
RefreshToken: &refreshToken,
},
},
{
desc: "nil",
args: nil,
},
{
desc: "nil IDPInfo and RefreshToken",
args: &driver.OIDCArgs{
Version: 9,
IDPInfo: nil,
RefreshToken: nil,
},
},
}

for _, tc := range testCases {
tc := tc // Capture range variable.

t.Run(tc.desc, func(t *testing.T) {
t.Parallel()

got := convertOIDCArgs(tc.args)

if tc.args == nil {
assert.Nil(t, got, "expected nil when input is nil")
return
}

require.Equal(t,
3,
reflect.ValueOf(*tc.args).NumField(),
"expected the driver.OIDCArgs struct to have exactly 3 fields")
require.Equal(t,
3,
reflect.ValueOf(*got).NumField(),
"expected the options.OIDCArgs struct to have exactly 3 fields")

assert.Equal(t,
tc.args.Version,
got.Version,
"expected Version field to be equal")
assert.EqualValues(t,
tc.args.IDPInfo,
got.IDPInfo,
"expected IDPInfo field to be convertible to equal values")
assert.Equal(t,
tc.args.RefreshToken,
got.RefreshToken,
"expected RefreshToken field to be equal")
})
}
}
6 changes: 3 additions & 3 deletions mongo/options/clientoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ type ContextDialer interface {
// The SERVICE_HOST and CANONICALIZE_HOST_NAME properties must not be used at the same time on Linux and Darwin
// systems.
//
// AuthSource: the name of the database to use for authentication. This defaults to "$external" for MONGODB-X509,
// GSSAPI, and PLAIN and "admin" for all other mechanisms. This can also be set through the "authSource" URI option
// (e.g. "authSource=otherDb").
// AuthSource: the name of the database to use for authentication. This defaults to "$external" for MONGODB-AWS,
// MONGODB-OIDC, MONGODB-X509, GSSAPI, and PLAIN. It defaults to "admin" for all other auth mechanisms. This can
// also be set through the "authSource" URI option (e.g. "authSource=otherDb").
//
// Username: the username for authentication. This can also be set through the URI as a username:password pair before
// the first @ character. For example, a URI for user "user", password "pwd", and host "localhost:27017" would be
Expand Down
2 changes: 2 additions & 0 deletions x/mongo/driver/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)

const sourceExternal = "$external"

// AuthenticatorFactory constructs an authenticator.
type AuthenticatorFactory func(*Cred, *http.Client) (Authenticator, error)

Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/auth/gssapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
const GSSAPI = "GSSAPI"

func newGSSAPIAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("GSSAPI source must be empty or $external", nil)
}

Expand Down Expand Up @@ -57,7 +57,7 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig)
if err != nil {
return newAuthError("error creating gssapi", err)
}
return ConductSaslConversation(ctx, cfg, "$external", client)
return ConductSaslConversation(ctx, cfg, sourceExternal, client)
}

// Reauth reauthenticates the connection.
Expand Down
7 changes: 2 additions & 5 deletions x/mongo/driver/auth/mongodbaws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@ import (
const MongoDBAWS = "MONGODB-AWS"

func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil)
}
if httpClient == nil {
return nil, errors.New("httpClient must not be nil")
}
return &MongoDBAWSAuthenticator{
source: cred.Source,
credentials: &credproviders.StaticProvider{
Value: credentials.Value{
ProviderName: cred.Source,
AccessKeyID: cred.Username,
SecretAccessKey: cred.Password,
SessionToken: cred.Props["AWS_SESSION_TOKEN"],
Expand All @@ -43,7 +41,6 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica

// MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection.
type MongoDBAWSAuthenticator struct {
source string
credentials *credproviders.StaticProvider
httpClient *http.Client
}
Expand All @@ -56,7 +53,7 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConf
credentials: providers.Cred,
},
}
err := ConductSaslConversation(ctx, cfg, a.source, adapter)
err := ConductSaslConversation(ctx, cfg, sourceExternal, adapter)
if err != nil {
return newAuthError("sasl conversation error", err)
}
Expand Down
6 changes: 5 additions & 1 deletion x/mongo/driver/auth/mongodbcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ import (
const MONGODBCR = "MONGODB-CR"

func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
return &MongoDBCRAuthenticator{
DB: cred.Source,
DB: source,
Username: cred.Username,
Password: cred.Password,
}, nil
Expand Down
13 changes: 8 additions & 5 deletions x/mongo/driver/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) {
}

func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("MONGODB-OIDC source must be empty or $external", nil)
}
if cred.Password != "" {
return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC)
}
Expand Down Expand Up @@ -446,7 +449,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) e
oa.mu.Unlock()

if cachedAccessToken != "" {
err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{
err = ConductSaslConversation(ctx, cfg, sourceExternal, &oidcOneStep{
userName: oa.userName,
accessToken: cachedAccessToken,
})
Expand Down Expand Up @@ -506,7 +509,7 @@ func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *driver.AuthCo
return ConductSaslConversation(
subCtx,
cfg,
"$external",
sourceExternal,
&oidcOneStep{accessToken: accessToken},
)
}
Expand All @@ -515,7 +518,7 @@ func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *driver.AuthCo
conn: cfg.Connection,
oa: oa,
}
return ConductSaslConversation(subCtx, cfg, "$external", ots)
return ConductSaslConversation(subCtx, cfg, sourceExternal, ots)
}

func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *driver.AuthConfig, machineCallback OIDCCallback) error {
Expand All @@ -536,7 +539,7 @@ func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *driver.Auth
return ConductSaslConversation(
ctx,
cfg,
"$external",
sourceExternal,
&oidcOneStep{accessToken: accessToken},
)
}
Expand All @@ -550,5 +553,5 @@ func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConvers
return nil, nil // Skip speculative auth.
}

return newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", true), nil
return newSaslConversation(&oidcOneStep{accessToken: accessToken}, sourceExternal, true), nil
}
17 changes: 16 additions & 1 deletion x/mongo/driver/auth/plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ import (
const PLAIN = "PLAIN"

func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
// TODO(GODRIVER-3317): The PLAIN specification says about auth source:
//
// "MUST be specified. Defaults to the database name if supplied on the
// connection string or $external."
//
// We should actually pass through the auth source, not always pass
// $external. If it's empty, we should default to $external.
//
// For example:
//
// source := cred.Source
// if source == "" {
// source = "$external"
// }
//
return &PlainAuthenticator{
Username: cred.Username,
Password: cred.Password,
Expand All @@ -31,7 +46,7 @@ type PlainAuthenticator struct {

// Auth authenticates the connection.
func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{
return ConductSaslConversation(ctx, cfg, sourceExternal, &plainSaslClient{
username: a.Username,
password: a.Password,
})
Expand Down
3 changes: 1 addition & 2 deletions x/mongo/driver/auth/plain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ package auth_test

import (
"context"
"encoding/base64"
"strings"
"testing"

"encoding/base64"

"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
Expand Down
12 changes: 10 additions & 2 deletions x/mongo/driver/auth/scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ var (
)

func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
if err != nil {
Expand All @@ -46,12 +50,16 @@ func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA1,
source: cred.Source,
source: source,
client: client,
}, nil
}

func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
if err != nil {
return nil, newAuthError("error SASLprepping password", err)
Expand All @@ -63,7 +71,7 @@ func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, err
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA256,
source: cred.Source,
source: source,
client: client,
}, nil
}
Expand Down
Loading

0 comments on commit 9e7ccb0

Please sign in to comment.