From 0eafd3f2cdf27063899a2ac99928bd57233587a4 Mon Sep 17 00:00:00 2001 From: Yaroslav Zborovsky Date: Wed, 7 Feb 2024 14:26:27 +0200 Subject: [PATCH] Fix issue with id_token unmarshalling --- providers/apple/session.go | 43 ++++++++++++++++++++++++++---- providers/apple/session_test.go | 46 ++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 6 deletions(-) diff --git a/providers/apple/session.go b/providers/apple/session.go index 8f03b5158..6d239c341 100644 --- a/providers/apple/session.go +++ b/providers/apple/session.go @@ -48,10 +48,10 @@ func (s Session) Marshal() string { type IDTokenClaims struct { jwt.StandardClaims - AccessTokenHash string `json:"at_hash"` - AuthTime int `json:"auth_time"` - Email string `json:"email"` - IsPrivateEmail bool `json:"is_private_email,string"` + AccessTokenHash string `json:"at_hash"` + AuthTime int `json:"auth_time"` + Email string `json:"email"` + IsPrivateEmail BoolString `json:"is_private_email"` } func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { @@ -123,7 +123,7 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, s.ID = ID{ Sub: idToken.Claims.(*IDTokenClaims).Subject, Email: idToken.Claims.(*IDTokenClaims).Email, - IsPrivateEmail: idToken.Claims.(*IDTokenClaims).IsPrivateEmail, + IsPrivateEmail: idToken.Claims.(*IDTokenClaims).IsPrivateEmail.Value(), } } @@ -133,3 +133,36 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, func (s Session) String() string { return s.Marshal() } + +// BoolString is a type that can be unmarshalled from a JSON field that can be either a boolean or a string. +// It is used to unmarshal some fields in the Apple ID token that can be sent as either boolean or string. +// See https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_rest_api/authenticating_users_with_sign_in_with_apple#3383773 +type BoolString struct { + BoolValue bool + StringValue string + IsValidBool bool +} + +func (bs *BoolString) UnmarshalJSON(data []byte) error { + var b bool + if err := json.Unmarshal(data, &b); err == nil { + bs.BoolValue = b + bs.IsValidBool = true + return nil + } + + var s string + if err := json.Unmarshal(data, &s); err == nil { + bs.StringValue = s + return nil + } + + return errors.New("json field can be either boolean or string") +} + +func (bs *BoolString) Value() bool { + if bs.IsValidBool { + return bs.BoolValue + } + return bs.StringValue == "true" +} diff --git a/providers/apple/session_test.go b/providers/apple/session_test.go index 7d0aa437e..4516dcdd8 100644 --- a/providers/apple/session_test.go +++ b/providers/apple/session_test.go @@ -1,10 +1,12 @@ package apple import ( + "encoding/json" "testing" - "github.com/markbates/goth" "github.com/stretchr/testify/assert" + + "github.com/markbates/goth" ) func Test_Implements_Session(t *testing.T) { @@ -45,3 +47,45 @@ func Test_String(t *testing.T) { a.Equal(s.String(), s.Marshal()) } + +func TestIDTokenClaimsUnmarshal(t *testing.T) { + t.Parallel() + a := assert.New(t) + + cases := []struct { + name string + idToken string + expectedClaims IDTokenClaims + }{ + { + name: "'is_private_email' claim is a string", + idToken: `{"AuthURL":"","AccessToken":"","RefreshToken":"","ExpiresAt":"0001-01-01T00:00:00Z","sub":"","email":"test-email@privaterelay.appleid.com","is_private_email":"true"}`, + expectedClaims: IDTokenClaims{ + Email: "test-email@privaterelay.appleid.com", + IsPrivateEmail: BoolString{ + StringValue: "true", + }, + }, + }, + { + name: "'is_private_email' claim is a boolean", + idToken: `{"AuthURL":"","AccessToken":"","RefreshToken":"","ExpiresAt":"0001-01-01T00:00:00Z","sub":"","email":"test-email@privaterelay.appleid.com","is_private_email":true}`, + expectedClaims: IDTokenClaims{ + Email: "test-email@privaterelay.appleid.com", + IsPrivateEmail: BoolString{ + BoolValue: true, + IsValidBool: true, + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + idTokenClaims := IDTokenClaims{} + err := json.Unmarshal([]byte(c.idToken), &idTokenClaims) + a.NoError(err) + a.Equal(idTokenClaims, c.expectedClaims) + }) + } +}