forked from canonical/fosite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
client_authentication_jwks_strategy.go
135 lines (111 loc) · 4.65 KB
/
client_authentication_jwks_strategy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package fosite
import (
"context"
"encoding/json"
"time"
"github.com/dgraph-io/ristretto"
"github.com/hashicorp/go-retryablehttp"
"github.com/ory/x/errorsx"
"github.com/go-jose/go-jose/v3"
)
const (
	// defaultJWKSFetcherStrategyCachePrefix namespaces the cache keys written by
	// DefaultJWKSFetcherStrategy; every key is this prefix followed by the JWKS location.
	defaultJWKSFetcherStrategyCachePrefix = "github.com/ory/fosite.DefaultJWKSFetcherStrategy:"
)
// JWKSFetcherStrategy is a strategy which pulls (optionally caches) JSON Web Key Sets from a location,
// typically a client's jwks_uri.
type JWKSFetcherStrategy interface {
	// Resolve returns the JSON Web Key Set found at location, or an error if something went wrong.
	//
	// If ignoreCache is true, the strategy must fetch the key set from the remote location. If ignoreCache is
	// false, the strategy may serve the key set from a cache instead.
	Resolve(ctx context.Context, location string, ignoreCache bool) (*jose.JSONWebKeySet, error)
}
// DefaultJWKSFetcherStrategy is a default implementation of the JWKSFetcherStrategy interface.
//
// It fetches key sets over HTTP and caches them per location. Instances are meant to be built with
// NewDefaultJWKSFetcherStrategy, which fills in the client, cache and TTL.
type DefaultJWKSFetcherStrategy struct {
	// client performs the HTTP GET in Resolve unless clientSourceFunc is set.
	client *retryablehttp.Client

	// cache stores fetched key sets under defaultJWKSFetcherStrategyCachePrefix + location.
	cache *ristretto.Cache[string, *jose.JSONWebKeySet]

	// ttl is the lifetime passed to the cache for every fetched key set.
	ttl time.Duration

	// clientSourceFunc, when non-nil, is called with the request context to obtain the HTTP client,
	// taking precedence over client.
	clientSourceFunc func(ctx context.Context) *retryablehttp.Client
}
// NewDefaultJWKSFetcherStrategy returns a new instance of the DefaultJWKSFetcherStrategy.
//
// Options are applied in order. Afterwards, any dependency that is still unset falls back to its default:
// a retryablehttp client with default settings, and an in-memory ristretto cache with a MaxCost of 10000
// where each key set costs 1. The cache TTL defaults to one hour unless an option overrides it.
//
// Defaults are created only when needed. Supplying JWKSFetcherWithCache therefore no longer builds a second
// cache that is immediately discarded and never closed. Likewise, an option that sets the cache or client to
// nil no longer leaves the strategy unusable.
func NewDefaultJWKSFetcherStrategy(opts ...func(*DefaultJWKSFetcherStrategy)) JWKSFetcherStrategy {
	s := &DefaultJWKSFetcherStrategy{
		ttl: time.Hour,
	}
	for _, o := range opts {
		o(s)
	}

	if s.client == nil {
		s.client = retryablehttp.NewClient()
	}

	if s.cache == nil {
		dc, err := ristretto.NewCache(&ristretto.Config[string, *jose.JSONWebKeySet]{
			NumCounters: 10000 * 10,
			MaxCost:     10000,
			BufferItems: 64,
			Metrics:     false,
			Cost: func(value *jose.JSONWebKeySet) int64 {
				return 1
			},
		})
		if err != nil {
			// The configuration above is fixed at compile time, so a failure here is a programming error
			// rather than a runtime condition callers could handle.
			panic(err)
		}
		s.cache = dc
	}

	return s
}
// JKWKSFetcherWithDefaultTTL returns an option that sets how long fetched key sets stay in the cache
// (NewDefaultJWKSFetcherStrategy uses one hour otherwise).
//
// Note: the "JKWKS" spelling differs from the other JWKSFetcher* options. It is kept as-is because
// renaming an exported function would break existing callers.
func JKWKSFetcherWithDefaultTTL(ttl time.Duration) func(*DefaultJWKSFetcherStrategy) {
	setTTL := func(strategy *DefaultJWKSFetcherStrategy) {
		strategy.ttl = ttl
	}
	return setTTL
}
// JWKSFetcherWithCache returns an option that makes the strategy store and look up key sets in cache
// instead of the cache NewDefaultJWKSFetcherStrategy would otherwise use.
func JWKSFetcherWithCache(cache *ristretto.Cache[string, *jose.JSONWebKeySet]) func(*DefaultJWKSFetcherStrategy) {
	useCache := func(strategy *DefaultJWKSFetcherStrategy) {
		strategy.cache = cache
	}
	return useCache
}
// JWKSFetcherWithHTTPClient returns an option that sets the HTTP client used to fetch key sets.
// A client source installed with JWKSFetcherWithHTTPClientSource takes precedence over this client.
func JWKSFetcherWithHTTPClient(client *retryablehttp.Client) func(*DefaultJWKSFetcherStrategy) {
	useClient := func(strategy *DefaultJWKSFetcherStrategy) {
		strategy.client = client
	}
	return useClient
}
// JWKSFetcherWithHTTPClientSource returns an option that installs a function which Resolve calls with the
// request context, whenever it fetches remotely, to choose the HTTP client. When installed, it takes
// precedence over the client set with JWKSFetcherWithHTTPClient.
func JWKSFetcherWithHTTPClientSource(clientSourceFunc func(ctx context.Context) *retryablehttp.Client) func(*DefaultJWKSFetcherStrategy) {
	useSource := func(strategy *DefaultJWKSFetcherStrategy) {
		strategy.clientSourceFunc = clientSourceFunc
	}
	return useSource
}
// Resolve returns the JSON Web Key Set found at location, or an error if it could not be fetched or decoded.
//
// If ignoreCache is false and a key set for location is cached, the cached set is returned without any
// network access. Otherwise the set is fetched with an HTTP GET bound to ctx, decoded from JSON, stored in the
// cache for s.ttl, and returned. Passing ignoreCache=true therefore forces a refresh of the cached entry.
//
// The HTTP client comes from clientSourceFunc when that is set and returns a non-nil client, and from the
// configured client otherwise. Responses with a status code outside 200-399 are rejected. All failures are
// reported as ErrServerError with a hint naming location.
func (s *DefaultJWKSFetcherStrategy) Resolve(ctx context.Context, location string, ignoreCache bool) (*jose.JSONWebKeySet, error) {
	cacheKey := defaultJWKSFetcherStrategyCachePrefix + location

	// Only consult the cache when the caller allows it; a forced refresh never needs the lookup.
	if !ignoreCache {
		if key, ok := s.cache.Get(cacheKey); ok {
			return key, nil
		}
	}

	req, err := retryablehttp.NewRequest("GET", location, nil)
	if err != nil {
		return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to create HTTP 'GET' request to fetch JSON Web Keys from location '%s'.", location).WithWrap(err).WithDebug(err.Error()))
	}

	hc := s.client
	if s.clientSourceFunc != nil {
		// Fall back to the configured client instead of dereferencing nil when the source yields none.
		if sourced := s.clientSourceFunc(ctx); sourced != nil {
			hc = sourced
		}
	}

	response, err := hc.Do(req.WithContext(ctx))
	if err != nil {
		return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to fetch JSON Web Keys from location '%s'. Check for typos or other network issues.", location).WithWrap(err).WithDebug(err.Error()))
	}
	defer response.Body.Close()

	if response.StatusCode < 200 || response.StatusCode >= 400 {
		return nil, errorsx.WithStack(ErrServerError.WithHintf("Expected successful status code in range of 200 - 399 from location '%s' but received code %d.", location, response.StatusCode))
	}

	var set jose.JSONWebKeySet
	if err := json.NewDecoder(response.Body).Decode(&set); err != nil {
		return nil, errorsx.WithStack(ErrServerError.WithHintf("Unable to decode JSON Web Keys from location '%s'. Please check for typos and if the URL returns valid JSON.", location).WithWrap(err).WithDebug(err.Error()))
	}

	// The result of SetWithTTL is deliberately ignored. Caching is best-effort: a write that does not land only
	// means the next call fetches the key set again.
	_ = s.cache.SetWithTTL(cacheKey, &set, 1, s.ttl)
	return &set, nil
}
// WaitForCache delegates to the Wait method of the underlying ristretto cache. Per the ristretto
// documentation, Wait blocks until buffered writes (such as those issued by Resolve) have been applied.
func (s *DefaultJWKSFetcherStrategy) WaitForCache() {
	cache := s.cache
	cache.Wait()
}