Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: renew the session token when the token expires
Browse files Browse the repository at this point in the history
MqllR committed Dec 6, 2023
1 parent c1c7ee3 commit 6706169
Showing 2 changed files with 32 additions and 12 deletions.
28 changes: 22 additions & 6 deletions storage/s3.go
Original file line number Diff line number Diff line change
@@ -1235,9 +1235,9 @@ func (sc *SessionCache) newSession(ctx context.Context, opts Options) (*session.
WithLogger(sdkLogger{})
}

awsCfg.Retryer = newCustomRetryer(opts.MaxRetries)
awsCfg.Retryer = newCustomRetryer(sc, opts.MaxRetries)

useSharedConfig := session.SharedConfigEnable
useSharedConfig := session.SharedConfigDisable
{
// Reverse of what the SDK does: if AWS_SDK_LOAD_CONFIG is 0 (or a
// falsy value) disable shared configs
@@ -1276,7 +1276,7 @@ func (sc *SessionCache) newSession(ctx context.Context, opts Options) (*session.
return sess, nil
}

func (sc *SessionCache) clear() {
func (sc *SessionCache) Clear() {
sc.Lock()
defer sc.Unlock()
sc.sessions = map[Options]*session.Session{}
@@ -1324,10 +1324,12 @@ func setSessionRegion(ctx context.Context, sess *session.Session, bucket string)
// error codes. Such as, retry for S3 InternalError code.
type customRetryer struct {
client.DefaultRetryer
sc *SessionCache
}

func newCustomRetryer(maxRetries int) *customRetryer {
func newCustomRetryer(sc *SessionCache, maxRetries int) *customRetryer {
return &customRetryer{
sc: sc,
DefaultRetryer: client.DefaultRetryer{
NumMaxRetries: maxRetries,
},
@@ -1337,13 +1339,27 @@ func newCustomRetryer(maxRetries int) *customRetryer {
// ShouldRetry overrides SDK's built in DefaultRetryer, adding custom retry
// logics that are not included in the SDK.
func (c *customRetryer) ShouldRetry(req *request.Request) bool {
shouldRetry := errHasCode(req.Error, "InternalError") || errHasCode(req.Error, "RequestTimeTooSkewed") || errHasCode(req.Error, "SlowDown") || strings.Contains(req.Error.Error(), "connection reset") || strings.Contains(req.Error.Error(), "connection timed out")
log.Error(log.ErrorMessage{
Command: "retrier",
Err: req.Error.Error(),
})

shouldRetry := errHasCode(req.Error, "InternalError") || errHasCode(req.Error, "RequestTimeTooSkewed") || errHasCode(req.Error, "SlowDown") || strings.Contains(req.Error.Error(), "connection reset") || strings.Contains(req.Error.Error(), "connection timed out") || errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException")

if errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException") {
log.Debug(log.DebugMessage{
Err: "Clearing the token",
})

c.sc.Clear()
}

if !shouldRetry {
shouldRetry = c.DefaultRetryer.ShouldRetry(req)
}

// Errors related to tokens
if errHasCode(req.Error, "ExpiredToken") || errHasCode(req.Error, "ExpiredTokenException") || errHasCode(req.Error, "InvalidToken") {
if errHasCode(req.Error, "InvalidToken") {
return false
}

16 changes: 10 additions & 6 deletions storage/s3_test.go
Original file line number Diff line number Diff line change
@@ -97,7 +97,7 @@ func TestNewSessionPathStyle(t *testing.T) {
}

func TestNewSessionWithRegionSetViaEnv(t *testing.T) {
globalSessionCache.clear()
globalSessionCache.Clear()

const expectedRegion = "us-west-2"

@@ -116,7 +116,7 @@ func TestNewSessionWithRegionSetViaEnv(t *testing.T) {
}

func TestNewSessionWithNoSignRequest(t *testing.T) {
globalSessionCache.clear()
globalSessionCache.Clear()

sess, err := globalSessionCache.newSession(context.Background(), Options{
NoSignRequest: true,
@@ -190,7 +190,7 @@ aws_secret_access_key = p2_profile_access_key`
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
globalSessionCache.clear()
globalSessionCache.Clear()
sess, err := globalSessionCache.newSession(context.Background(), Options{
Profile: tc.profileName,
CredentialFile: tc.fileName,
@@ -538,8 +538,12 @@ func TestS3Retry(t *testing.T) {
for _, tc := range testcases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
sessionCache := &SessionCache{
sessions: map[Options]*session.Session{},
}

sess := unit.Session
sess.Config.Retryer = newCustomRetryer(expectedRetry)
sess.Config.Retryer = newCustomRetryer(sessionCache, expectedRetry)

mockAPI := s3.New(sess)
mockS3 := &S3{
@@ -1041,7 +1045,7 @@ func TestSessionRegionDetection(t *testing.T) {
opts.bucket = tc.bucket
}

globalSessionCache.clear()
globalSessionCache.Clear()

sess, err := globalSessionCache.newSession(context.Background(), opts)
if err != nil {
@@ -1241,7 +1245,7 @@ func TestAWSLogLevel(t *testing.T) {

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
globalSessionCache.clear()
globalSessionCache.Clear()
sess, err := globalSessionCache.newSession(context.Background(), Options{
LogLevel: log.LevelFromString(tc.level),
})

0 comments on commit 6706169

Please sign in to comment.