Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3333 Fix default auth source for auth specified via ClientOptions [master] #1798

Merged
merged 2 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 7 additions & 41 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,34 +215,13 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) {
}

if args.Auth != nil {
var oidcMachineCallback auth.OIDCCallback
if args.Auth.OIDCMachineCallback != nil {
oidcMachineCallback = func(ctx context.Context, oargs *driver.OIDCArgs) (*driver.OIDCCredential, error) {
cred, err := args.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(oargs))
return (*driver.OIDCCredential)(cred), err
}
}

var oidcHumanCallback auth.OIDCCallback
if args.Auth.OIDCHumanCallback != nil {
oidcHumanCallback = func(ctx context.Context, oargs *driver.OIDCArgs) (*driver.OIDCCredential, error) {
cred, err := args.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(oargs))
return (*driver.OIDCCredential)(cred), err
}
}

// Create an authenticator for the client
client.authenticator, err = auth.CreateAuthenticator(args.Auth.AuthMechanism, &auth.Cred{
Source: args.Auth.AuthSource,
Username: args.Auth.Username,
Password: args.Auth.Password,
PasswordSet: args.Auth.PasswordSet,
Props: args.Auth.AuthMechanismProperties,
OIDCMachineCallback: oidcMachineCallback,
OIDCHumanCallback: oidcHumanCallback,
}, args.HTTPClient)
client.authenticator, err = auth.CreateAuthenticator(
args.Auth.AuthMechanism,
topology.ConvertCreds(args.Auth),
args.HTTPClient,
)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating authenticator: %w", err)
}
}

Expand Down Expand Up @@ -274,20 +253,7 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) {
return client, nil
}

// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent
// public type *options.OIDCArgs.
func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs {
if args == nil {
return nil
}
return &options.OIDCArgs{
Version: args.Version,
IDPInfo: (*options.IDPInfo)(args.IDPInfo),
RefreshToken: args.RefreshToken,
}
}

// connect initializes the Client by starting background monitoring goroutines.
// Connect initializes the Client by starting background monitoring goroutines.
// If the Client was created using the NewClient function, this method must be called before a Client can be used.
//
// Connect starts background goroutines to monitor the state of the deployment and does not do any I/O in the main
Expand Down
76 changes: 0 additions & 76 deletions mongo/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"errors"
"math"
"os"
"reflect"
"testing"
"time"

Expand All @@ -20,13 +19,11 @@ import (
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/integtest"
"go.mongodb.org/mongo-driver/v2/internal/mongoutil"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/tag"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology"
Expand Down Expand Up @@ -519,76 +516,3 @@ func TestClient(t *testing.T) {
assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error())
})
}

// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs
// into an options.OIDCArgs.
func TestConvertOIDCArgs(t *testing.T) {
refreshToken := "test refresh token"

testCases := []struct {
desc string
args *driver.OIDCArgs
}{
{
desc: "populated args",
args: &driver.OIDCArgs{
Version: 9,
IDPInfo: &driver.IDPInfo{
Issuer: "test issuer",
ClientID: "test client ID",
RequestScopes: []string{"test scope 1", "test scope 2"},
},
RefreshToken: &refreshToken,
},
},
{
desc: "nil",
args: nil,
},
{
desc: "nil IDPInfo and RefreshToken",
args: &driver.OIDCArgs{
Version: 9,
IDPInfo: nil,
RefreshToken: nil,
},
},
}

for _, tc := range testCases {
tc := tc // Capture range variable.

t.Run(tc.desc, func(t *testing.T) {
t.Parallel()

got := convertOIDCArgs(tc.args)

if tc.args == nil {
assert.Nil(t, got, "expected nil when input is nil")
return
}

require.Equal(t,
3,
reflect.ValueOf(*tc.args).NumField(),
"expected the driver.OIDCArgs struct to have exactly 3 fields")
require.Equal(t,
3,
reflect.ValueOf(*got).NumField(),
"expected the options.OIDCArgs struct to have exactly 3 fields")

assert.Equal(t,
tc.args.Version,
got.Version,
"expected Version field to be equal")
assert.EqualValues(t,
tc.args.IDPInfo,
got.IDPInfo,
"expected IDPInfo field to be convertible to equal values")
assert.Equal(t,
tc.args.RefreshToken,
got.RefreshToken,
"expected RefreshToken field to be equal")
})
}
}
6 changes: 3 additions & 3 deletions mongo/options/clientoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ type ContextDialer interface {
// The SERVICE_HOST and CANONICALIZE_HOST_NAME properties must not be used at the same time on Linux and Darwin
// systems.
//
// AuthSource: the name of the database to use for authentication. This defaults to "$external" for MONGODB-X509,
// GSSAPI, and PLAIN and "admin" for all other mechanisms. This can also be set through the "authSource" URI option
// (e.g. "authSource=otherDb").
// AuthSource: the name of the database to use for authentication. This defaults to "$external" for MONGODB-AWS,
// MONGODB-OIDC, MONGODB-X509, GSSAPI, and PLAIN. It defaults to "admin" for all other auth mechanisms. This can
// also be set through the "authSource" URI option (e.g. "authSource=otherDb").
//
// Username: the username for authentication. This can also be set through the URI as a username:password pair before
// the first @ character. For example, a URI for user "user", password "pwd", and host "localhost:27017" would be
Expand Down
2 changes: 2 additions & 0 deletions x/mongo/driver/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)

const sourceExternal = "$external"

// AuthenticatorFactory constructs an authenticator.
type AuthenticatorFactory func(*Cred, *http.Client) (Authenticator, error)

Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/auth/gssapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
const GSSAPI = "GSSAPI"

func newGSSAPIAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("GSSAPI source must be empty or $external", nil)
}

Expand Down Expand Up @@ -57,7 +57,7 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig)
if err != nil {
return newAuthError("error creating gssapi", err)
}
return ConductSaslConversation(ctx, cfg, "$external", client)
return ConductSaslConversation(ctx, cfg, sourceExternal, client)
}

// Reauth reauthenticates the connection.
Expand Down
7 changes: 2 additions & 5 deletions x/mongo/driver/auth/mongodbaws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@ import (
const MongoDBAWS = "MONGODB-AWS"

func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil)
}
if httpClient == nil {
return nil, errors.New("httpClient must not be nil")
}
return &MongoDBAWSAuthenticator{
source: cred.Source,
credentials: &credproviders.StaticProvider{
Value: credentials.Value{
ProviderName: cred.Source,
AccessKeyID: cred.Username,
SecretAccessKey: cred.Password,
SessionToken: cred.Props["AWS_SESSION_TOKEN"],
Expand All @@ -43,7 +41,6 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica

// MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection.
type MongoDBAWSAuthenticator struct {
source string
credentials *credproviders.StaticProvider
httpClient *http.Client
}
Expand All @@ -56,7 +53,7 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConf
credentials: providers.Cred,
},
}
err := ConductSaslConversation(ctx, cfg, a.source, adapter)
err := ConductSaslConversation(ctx, cfg, sourceExternal, adapter)
if err != nil {
return newAuthError("sasl conversation error", err)
}
Expand Down
6 changes: 5 additions & 1 deletion x/mongo/driver/auth/mongodbcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ import (
const MONGODBCR = "MONGODB-CR"

func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
return &MongoDBCRAuthenticator{
DB: cred.Source,
DB: source,
Username: cred.Username,
Password: cred.Password,
}, nil
Expand Down
13 changes: 8 additions & 5 deletions x/mongo/driver/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) {
}

func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != sourceExternal {
return nil, newAuthError("MONGODB-OIDC source must be empty or $external", nil)
}
if cred.Password != "" {
return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC)
}
Expand Down Expand Up @@ -446,7 +449,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) e
oa.mu.Unlock()

if cachedAccessToken != "" {
err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{
err = ConductSaslConversation(ctx, cfg, sourceExternal, &oidcOneStep{
userName: oa.userName,
accessToken: cachedAccessToken,
})
Expand Down Expand Up @@ -506,7 +509,7 @@ func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *driver.AuthCo
return ConductSaslConversation(
subCtx,
cfg,
"$external",
sourceExternal,
&oidcOneStep{accessToken: accessToken},
)
}
Expand All @@ -515,7 +518,7 @@ func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *driver.AuthCo
conn: cfg.Connection,
oa: oa,
}
return ConductSaslConversation(subCtx, cfg, "$external", ots)
return ConductSaslConversation(subCtx, cfg, sourceExternal, ots)
}

func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *driver.AuthConfig, machineCallback OIDCCallback) error {
Expand All @@ -536,7 +539,7 @@ func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *driver.Auth
return ConductSaslConversation(
ctx,
cfg,
"$external",
sourceExternal,
&oidcOneStep{accessToken: accessToken},
)
}
Expand All @@ -550,5 +553,5 @@ func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConvers
return nil, nil // Skip speculative auth.
}

return newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", true), nil
return newSaslConversation(&oidcOneStep{accessToken: accessToken}, sourceExternal, true), nil
}
17 changes: 16 additions & 1 deletion x/mongo/driver/auth/plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ import (
const PLAIN = "PLAIN"

func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
// TODO(GODRIVER-3317): The PLAIN specification says about auth source:
//
// "MUST be specified. Defaults to the database name if supplied on the
// connection string or $external."
//
// We should actually pass through the auth source, not always pass
// $external. If it's empty, we should default to $external.
//
// For example:
//
// source := cred.Source
// if source == "" {
// source = "$external"
// }
//
return &PlainAuthenticator{
Username: cred.Username,
Password: cred.Password,
Expand All @@ -31,7 +46,7 @@ type PlainAuthenticator struct {

// Auth authenticates the connection.
func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{
return ConductSaslConversation(ctx, cfg, sourceExternal, &plainSaslClient{
username: a.Username,
password: a.Password,
})
Expand Down
3 changes: 1 addition & 2 deletions x/mongo/driver/auth/plain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ package auth_test

import (
"context"
"encoding/base64"
"strings"
"testing"

"encoding/base64"

"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
Expand Down
12 changes: 10 additions & 2 deletions x/mongo/driver/auth/scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ var (
)

func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
if err != nil {
Expand All @@ -46,12 +50,16 @@ func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA1,
source: cred.Source,
source: source,
client: client,
}, nil
}

func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
if err != nil {
return nil, newAuthError("error SASLprepping password", err)
Expand All @@ -63,7 +71,7 @@ func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, err
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA256,
source: cred.Source,
source: source,
client: client,
}, nil
}
Expand Down
Loading
Loading