From c1658b36960a657e8f806b6633408e8fd95b6fd8 Mon Sep 17 00:00:00 2001 From: lestrrat <49281+lestrrat@users.noreply.github.com> Date: Sat, 26 Oct 2024 09:40:49 +0900 Subject: [PATCH] allow passing context.Contex to jwe.Decrypt (#1222) * allow passing context.Contex to jwe.Decrypt * really, this is OK --------- Co-authored-by: Daisuke Maki --- jwe/jwe.go | 6 ++++-- jwe/options.yaml | 6 ++++++ jwe/options_gen.go | 12 ++++++++++++ jwe/options_gen_test.go | 1 + tools/cmd/genoptions/main.go | 2 +- 5 files changed, 24 insertions(+), 3 deletions(-) diff --git a/jwe/jwe.go b/jwe/jwe.go index 38ac60399..122a148f6 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -542,6 +542,7 @@ func decrypt(buf []byte, options ...DecryptOption) ([]byte, error) { var cek *[]byte var dst *Message perCallMaxDecompressBufferSize := maxDecompressBufferSize + ctx := context.Background() //nolint:forcetypeassert for _, option := range options { switch option.Ident() { @@ -565,6 +566,9 @@ func decrypt(buf []byte, options ...DecryptOption) ([]byte, error) { cek = option.Value().(*[]byte) case identMaxDecompressBufferSize{}: perCallMaxDecompressBufferSize = option.Value().(int64) + case identContext{}: + //nolint:fatcontext + ctx = option.Value().(context.Context) } } @@ -625,8 +629,6 @@ func decrypt(buf []byte, options ...DecryptOption) ([]byte, error) { dctx.cek = cek dctx.maxDecompressBufferSize = perCallMaxDecompressBufferSize - ctx := context.TODO() - errs := make([]error, 0, len(recipients)) for _, recipient := range recipients { decrypted, err := dctx.try(ctx, recipient, keyUsed) diff --git a/jwe/options.yaml b/jwe/options.yaml index 1e54b1f3f..2d84793ad 100644 --- a/jwe/options.yaml +++ b/jwe/options.yaml @@ -52,6 +52,12 @@ options: - ident: KeyProvider interface: DecryptOption argument_type: KeyProvider + - ident: Context + interface: DecryptOption + argument_type: context.Context + comment: | + WithContext specifies the context.Context object to use when decrypting a JWE message. + If not provided, context.Background() will be used. - ident: Serialization option_name: WithCompact interface: EncryptOption diff --git a/jwe/options_gen.go b/jwe/options_gen.go index 98ba1c7a6..b92b76fed 100644 --- a/jwe/options_gen.go +++ b/jwe/options_gen.go @@ -3,6 +3,7 @@ package jwe import ( + "context" "io/fs" "github.com/lestrrat-go/jwx/v3/jwa" @@ -141,6 +142,7 @@ type identCBCBufferSize struct{} type identCEK struct{} type identCompress struct{} type identContentEncryptionAlgorithm struct{} +type identContext struct{} type identFS struct{} type identKey struct{} type identKeyProvider struct{} @@ -171,6 +173,10 @@ func (identContentEncryptionAlgorithm) String() string { return "WithContentEncryption" } +func (identContext) String() string { + return "WithContext" +} + func (identFS) String() string { return "WithFS" } @@ -257,6 +263,12 @@ func WithContentEncryption(v jwa.ContentEncryptionAlgorithm) EncryptOption { return &encryptOption{option.New(identContentEncryptionAlgorithm{}, v)} } +// WithContext specifies the context.Context object to use when decrypting a JWE message. +// If not provided, context.Background() will be used. +func WithContext(v context.Context) DecryptOption { + return &decryptOption{option.New(identContext{}, v)} +} + // WithFS specifies the source `fs.FS` object to read the file from. func WithFS(v fs.FS) ReadFileOption { return &readFileOption{option.New(identFS{}, v)} diff --git a/jwe/options_gen_test.go b/jwe/options_gen_test.go index 7af86a2ec..d0de70900 100644 --- a/jwe/options_gen_test.go +++ b/jwe/options_gen_test.go @@ -13,6 +13,7 @@ func TestOptionIdent(t *testing.T) { require.Equal(t, "WithCEK", identCEK{}.String()) require.Equal(t, "WithCompress", identCompress{}.String()) require.Equal(t, "WithContentEncryption", identContentEncryptionAlgorithm{}.String()) + require.Equal(t, "WithContext", identContext{}.String()) require.Equal(t, "WithFS", identFS{}.String()) require.Equal(t, "WithKey", identKey{}.String()) require.Equal(t, "WithKeyProvider", identKeyProvider{}.String()) diff --git a/tools/cmd/genoptions/main.go b/tools/cmd/genoptions/main.go index 2035906be..11b57ccc7 100644 --- a/tools/cmd/genoptions/main.go +++ b/tools/cmd/genoptions/main.go @@ -119,7 +119,7 @@ func _main() error { }) if err := genOptions(&objects); err != nil { - return fmt.Errorf(`failed to generate %q`, objects.Output) + return fmt.Errorf(`failed to generate %q: %w`, objects.Output, err) } if err := genOptionTests(&objects); err != nil {