-
Notifications
You must be signed in to change notification settings - Fork 19
/
token.go
330 lines (274 loc) · 8.74 KB
/
token.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
package jwt
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
)
// ErrMissing indicates that the token given to `Verify` is empty.
var ErrMissing = errors.New("jwt: token is empty")

// ErrTokenForm indicates that the token does not have the expected compact
// form of exactly three dot-separated segments (header.payload.signature).
var ErrTokenForm = errors.New("jwt: invalid token form")

// ErrTokenAlg indicates that the algorithm declared in the token header does
// not match the expected one.
var ErrTokenAlg = errors.New("jwt: unexpected token algorithm")
// PrivateKey is the key used to sign a token. It is deliberately untyped:
// it is handed unchanged to the Alg's Sign method, so the concrete type it
// must hold depends on the algorithm in use.
type PrivateKey = interface{}

// PublicKey is the key used to verify a token. Like PrivateKey it is
// untyped and is handed unchanged to the Alg's Verify method.
type PublicKey = interface{}
// encodeToken builds the compact serialization header.payload.signature.
//
// The header is either the fixed {"alg":...,"typ":"JWT"} header for alg or,
// when customHeader is non-nil, the marshaled customHeader. In the latter
// case the custom header's own "alg" field is not checked against alg here.
// The payload is base64url-encoded and the signature is computed by alg
// over "header.payload".
func encodeToken(alg Alg, key PrivateKey, payload []byte, customHeader interface{}) ([]byte, error) {
	var header []byte
	switch {
	case customHeader == nil:
		header = createHeader(alg.Name())
	default:
		custom, err := createCustomHeader(customHeader)
		if err != nil {
			return nil, err
		}
		header = custom
	}

	signingInput := joinParts(header, Base64Encode(payload))

	sig, err := createSignature(alg, key, signingInput)
	if err != nil {
		return nil, fmt.Errorf("encodeToken: signature: %w", err)
	}

	// header.payload.signature
	return joinParts(signingInput, sig), nil
}
// decodeToken decodes and verifies the given compact "token".
// It returns the header, payload and signature parts (decoded).
//
// We could omit the "alg" because the token contains it
// BUT, for security reason the algorithm MUST explicitly match
// (even if we perform hash comparison later on).
//
// If the "compareHeaderFunc" is nil then it compares using the `CompareHeader` package-level function variable.
// The header validator may:
//   - supply the algorithm when "alg" is nil (it is only used in that case);
//   - return a non-nil public key, which then overrides "key";
//   - return a non-nil InjectFunc, which is applied to the decoded payload
//     AFTER the signature has been verified (e.g. to decrypt it).
//
// If neither "alg" nor the header validator provides an algorithm,
// ErrTokenAlg is returned rather than panicking on a nil Alg.
//
// Errors: ErrTokenForm when the token is not made of exactly three
// dot-separated segments; otherwise the error from base64 decoding, header
// validation, signature verification or the InjectFunc.
func decodeToken(alg Alg, key PublicKey, token []byte, compareHeaderFunc HeaderValidator) ([]byte, []byte, []byte, error) {
	parts := bytes.Split(token, sep)
	if len(parts) != 3 {
		return nil, nil, nil, ErrTokenForm
	}

	header := parts[0]
	payload := parts[1]
	signature := parts[2]

	headerDecoded, err := Base64Decode(header)
	if err != nil {
		return nil, nil, nil, err
	}

	// validate header equality.
	if compareHeaderFunc == nil {
		compareHeaderFunc = CompareHeader
	}

	// algorithm can be specified hard-coded
	// or extracted per token if a custom header validator given.
	algName := ""
	if alg != nil {
		algName = alg.Name()
	}

	dynamicAlg, pubKey, decrypt, err := compareHeaderFunc(algName, headerDecoded)
	if err != nil {
		return nil, nil, nil, err
	}

	if alg == nil {
		alg = dynamicAlg
	}
	// A custom validator may (incorrectly) succeed without resolving an
	// algorithm; fail closed instead of calling Verify on a nil Alg.
	if alg == nil {
		return nil, nil, nil, ErrTokenAlg
	}

	// Override the key given, which could be a nil if this "pubKey" always expected on success.
	if pubKey != nil {
		key = pubKey
	}

	signatureDecoded, err := Base64Decode(signature)
	if err != nil {
		return nil, nil, nil, err
	}

	// validate signature over the still-encoded "header.payload" bytes.
	headerPayload := joinParts(header, payload)
	if err := alg.Verify(key, headerPayload, signatureDecoded); err != nil {
		return nil, nil, nil, err
	}

	payload, err = Base64Decode(payload)
	if err != nil {
		return nil, nil, nil, err
	}

	// Only after the signature checks out is the payload handed to the
	// optional injector (e.g. decryption).
	if decrypt != nil {
		payload, err = decrypt(payload)
		if err != nil {
			return nil, nil, nil, err
		}
	}

	return headerDecoded, payload, signatureDecoded, nil
}
// Byte-level building blocks of the compact JWT serialization.
var (
	// padStr is the base64 padding character as a string, for the
	// bytes.Trim* family.
	padStr = "="
	// pad is the base64 padding character as a byte slice.
	pad = []byte(padStr)
	// sep separates the header, payload and signature segments.
	sep = []byte{'.'}
)
// joinParts concatenates the given segments with the "." separator between
// each adjacent pair, e.g. joinParts(h, p) yields "h.p". The result is a
// new slice (bytes.Join semantics), so the inputs are never modified.
func joinParts(parts ...[]byte) (joined []byte) {
	joined = bytes.Join(parts, sep)
	return joined
}
// fixedHeader holds the precomputed header variants for one builtin
// algorithm, so they are not rebuilt or re-encoded on every call (a small
// performance boost). Entries are stored per algorithm name in fixedHeaders
// and only describe the plain header, i.e. when "kid" or any other extra
// header field is not required to be inside.
type fixedHeader struct {
	// raw is the JSON header in the standard order: {"alg":...,"typ":"JWT"}.
	raw []byte
	// encoded is the base64url encoding of raw (no trailing '=').
	encoded []byte
	// reversed is raw with the fields swapped, "typ" first then "alg".
	// Useful to validate external jwt tokens that are not using the
	// standard form order.
	reversed []byte
}
// fixedHeaders maps each builtin algorithm name (from allAlgs) to its
// precomputed header variants: key = alg, value = raw, base64-encoded and
// reversed full header (when kid or any other extra headers are not
// required to be inside). It is filled once by init.
var fixedHeaders = make(map[string]*fixedHeader, len(allAlgs))

// init precomputes the header variants of every algorithm in allAlgs.
func init() {
	for _, a := range allAlgs {
		name := a.Name()

		// Build the entry completely before publishing it: the create*
		// helpers consult fixedHeaders themselves, so they must still see a
		// miss for "name" here and compute the values from scratch.
		entry := &fixedHeader{}
		entry.raw = createHeaderRaw(name)
		entry.encoded = createHeader(name)
		entry.reversed = createHeaderReversed(name)

		fixedHeaders[name] = entry
	}
}
// createHeader returns the base64url-encoded {"alg":<alg>,"typ":"JWT"}
// header. For algorithms present in fixedHeaders the cached slice itself is
// returned, so callers must not modify the result.
func createHeader(alg string) []byte {
	cached := fixedHeaders[alg]
	if cached == nil {
		return Base64Encode([]byte(`{"alg":"` + alg + `","typ":"JWT"}`))
	}
	return cached.encoded
}
// createCustomHeader marshals the caller-supplied header value with the
// package's Marshal function and returns it base64url-encoded, ready to be
// used as the first segment of a token. Marshal errors are returned as-is.
func createCustomHeader(header interface{}) ([]byte, error) {
	jsonHeader, marshalErr := Marshal(header)
	if marshalErr != nil {
		return nil, marshalErr
	}
	encoded := Base64Encode(jsonHeader)
	return encoded, nil
}
// createHeaderRaw returns the JSON header {"alg":<alg>,"typ":"JWT"} in the
// standard field order. Cached algorithms share their slice with
// fixedHeaders, so the result must not be modified.
func createHeaderRaw(alg string) []byte {
	cached := fixedHeaders[alg]
	if cached == nil {
		return []byte(`{"alg":"` + alg + `","typ":"JWT"}`)
	}
	return cached.raw
}
// createHeaderReversed returns the JSON header with "typ" before "alg":
// {"typ":"JWT","alg":<alg>}. Cached algorithms share their slice with
// fixedHeaders, so the result must not be modified.
func createHeaderReversed(alg string) []byte {
	cached := fixedHeaders[alg]
	if cached == nil {
		return []byte(`{"typ":"JWT","alg":"` + alg + `"}`)
	}
	return cached.reversed
}
// createHeaderWithoutTyp returns the minimal JSON header {"alg":<alg>},
// i.e. one that omits the optional "typ" field. It is never cached.
func createHeaderWithoutTyp(alg string) []byte {
	header := `{"alg":"` + alg + `"}`
	return []byte(header)
}
// HeaderValidator is a function which can be used to customize how the header is validated,
// by default it makes sure the algorithm is the same as the "alg" field.
//
// It receives the expected algorithm name ("alg", empty when the caller did
// not fix an algorithm) and the base64url-decoded JSON header of the token.
//
// If the "alg" is empty then this function must return a non-nil algorithm
// based on the token contents, because that algorithm is then used to
// verify the signature.
// It should return a nil PublicKey and a non-nil error on validation failure.
// On success, if the returned PublicKey is not nil then it overrides the
// VerifyXXX method's one.
//
// The returned InjectFunc is optional. If it is not nil, the token decoder
// applies it to the base64url-decoded payload AFTER the signature has been
// verified, and uses its output as the payload. This is the hook used to
// decrypt an encrypted payload (presumably with an AES-GCM key, as the
// encrypted Verify variants do — confirm against InjectFunc's definition).
type HeaderValidator func(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error)
// compareHeader is the built-in header check. It accepts the decoded header
// only if it is byte-for-byte one of the three serializations this package
// knows for the expected algorithm "alg":
//
//	{"alg":"<alg>","typ":"JWT"}   (standard order)
//	{"typ":"JWT","alg":"<alg>"}   (reversed order, emitted by some libraries)
//	{"alg":"<alg>"}               (no "typ")
//
// Any other header (extra fields such as "kid", whitespace, different
// algorithm) yields ErrTokenAlg; use a custom HeaderValidator for those.
// It never returns an algorithm, key or injector, so an empty "alg" is
// always rejected: the caller would otherwise be left with no algorithm to
// verify the signature with.
//
// Note that this check is fully hard coded in terms of the serialized
// format, which keeps it cheap (the first two forms come from the
// fixedHeaders cache for builtin algorithms).
func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error) {
	if alg == "" {
		return nil, nil, nil, ErrTokenAlg
	}

	// Pick the only candidate that can match, using lengths derived from
	// "alg" rather than hard-coded ones, so custom algorithms of any name
	// length are handled. The typ-less form is 12 bytes shorter than the
	// other two, which have equal length and differ at index 2
	// ('t' of "typ" vs 'a' of "alg").
	var expected []byte
	switch {
	case len(headerDecoded) == len(alg)+len(`{"alg":""}`):
		expected = createHeaderWithoutTyp(alg)
	case len(headerDecoded) > 2 && headerDecoded[2] == 't':
		// The specification says otherwise but some other programming
		// languages' libraries don't actually follow the correct order.
		expected = createHeaderReversed(alg)
	default:
		expected = createHeaderRaw(alg)
	}

	if !bytes.Equal(expected, headerDecoded) {
		return nil, nil, nil, ErrTokenAlg
	}
	return nil, nil, nil, nil
}
// createSignature signs headerAndPayload with alg and key and returns the
// signature base64url-encoded, ready to be appended as the third token
// segment. Errors from alg.Sign are returned unchanged.
func createSignature(alg Alg, key PrivateKey, headerAndPayload []byte) ([]byte, error) {
	rawSig, signErr := alg.Sign(key, headerAndPayload)
	if signErr != nil {
		return nil, signErr
	}
	encodedSig := Base64Encode(rawSig)
	return encodedSig, nil
}
// Base64Encode encodes "src" to the base64url alphabet used by JWT, without
// the trailing '=' padding.
// We could use the base64.RawURLEncoding but the below is a bit faster.
func Base64Encode(src []byte) []byte {
	enc := base64.URLEncoding
	dst := make([]byte, enc.EncodedLen(len(src)))
	enc.Encode(dst, src)
	// JWT: no trailing '='.
	return bytes.TrimRight(dst, padStr)
}
// Base64Decode decodes "src" to jwt base64 url format.
// We could use the base64.RawURLEncoding but the below is a bit faster.
func Base64Decode(src []byte) ([]byte, error) {
if n := len(src) % 4; n > 0 {
// JWT: Because of no trailing '=' let's suffix it
// with the correct number of those '=' before decoding.
src = append(src, bytes.Repeat(pad, 4-n)...)
}
buf := make([]byte, base64.URLEncoding.DecodedLen(len(src)))
n, err := base64.URLEncoding.Decode(buf, src)
return buf[:n], err
}
// Decode splits the compact "token" and base64url-decodes its header,
// payload and signature WITHOUT verification and validation.
//
// This function is only useful to read a token's claims
// when the source is trusted and no algorithm verification or direct
// signature and content validation is required.
//
// Use `Verify/VerifyEncrypted` functions instead.
//
// It returns ErrTokenForm when the token does not consist of exactly three
// dot-separated segments, or the first base64 decoding error encountered.
func Decode(token []byte) (*UnverifiedToken, error) {
	segments := bytes.Split(token, sep)
	if len(segments) != 3 {
		return nil, ErrTokenForm
	}

	// decoded[i] holds the decoded form of segments[i]. Segments are decoded
	// in header, signature, payload order and the first failure is returned.
	var decoded [3][]byte
	for _, i := range [...]int{0, 2, 1} {
		part, err := Base64Decode(segments[i])
		if err != nil {
			return nil, err
		}
		decoded[i] = part
	}

	return &UnverifiedToken{
		Header:    decoded[0],
		Payload:   decoded[1],
		Signature: decoded[2],
	}, nil
}
// UnverifiedToken holds the three base64url-decoded parts of a compact
// token, as produced by Decode. Nothing in it has been verified.
// See its `Claims` method to decode the payload into a custom structure.
type UnverifiedToken struct {
	Header, Payload, Signature []byte
}
// Claims unmarshals the (unverified) `Payload` into "dest" using the
// package's Unmarshal function and returns its error, if any.
func (t *UnverifiedToken) Claims(dest interface{}) error {
	payload := t.Payload
	return Unmarshal(payload, dest)
}