Skip to content

Commit

Permalink
Add support for extra audiences in JWT SVID (#213)
Browse files Browse the repository at this point in the history
...

Signed-off-by: Emiliano Spinella <emilianofs@gmail.com>
  • Loading branch information
eminwux authored Nov 12, 2024
1 parent 093b63f commit f4872f1
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 15 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ The configuration file is an [HCL](https://github.com/hashicorp/hcl) formatted f
| `svid_file_name` | File name to be used to store the X.509 SVID public certificate in PEM format. | `"svid.pem"` |
| `svid_key_file_name` | File name to be used to store the X.509 SVID private key and public certificate in PEM format. | `"svid_key.pem"` |
| `svid_bundle_file_name` | File name to be used to store the X.509 SVID Bundle in PEM format. | `"svid_bundle.pem"` |
| `jwt_svids` | An array with the audience and file name to store the JWT SVIDs. File is Base64-encoded string). | `[{jwt_audience="your-audience", jwt_svid_file_name="jwt_svid.token"}]` |
| `jwt_svids` | An array with the audience, optional extra audiences array, and file name to store the JWT SVIDs. File is Base64-encoded string). | `[{jwt_audience="your-audience", jwt_extra_audiences=["your-extra-audience-1", "your-extra-audience-2"], jwt_svid_file_name="jwt_svid.token"}]` |
| `jwt_bundle_file_name` | File name to be used to store JWT Bundle in JSON format. | `"jwt_bundle.json"` |
| `include_federated_domains` | Include trust domains from federated servers in the CA bundle. | `true` |
| `cert_file_mode` | The octal file mode to use when saving the X.509 public certificate file. | `0644` |
Expand All @@ -48,7 +48,7 @@ renew_signal = "SIGUSR1"
svid_file_name = "svid.pem"
svid_key_file_name = "svid_key.pem"
svid_bundle_file_name = "svid_bundle.pem"
jwt_svids = [{jwt_audience="your-audience", jwt_svid_file_name="jwt_svid.token"}]
jwt_svids = [{jwt_audience="your-audience",jwt_extra_audiences=["your-extra-audience-1", "your-extra-audience-2"], jwt_svid_file_name="jwt_svid.token"}]
jwt_bundle_file_name = "bundle.json"
cert_file_mode = 0444
key_file_mode = 0444
Expand All @@ -63,6 +63,6 @@ cert_dir = "certs"
svid_file_name = "svid.pem"
svid_key_file_name = "svid_key.pem"
svid_bundle_file_name = "svid_bundle.pem"
jwt_svids = [{jwt_audience="your-audience", jwt_svid_file_name="jwt_svid.token"}]
jwt_svids = [{jwt_audience="your-audience",jwt_extra_audiences=["your-extra-audience-1", "your-extra-audience-2"], jwt_svid_file_name="jwt_svid.token"}]
jwt_bundle_file_name = "bundle.json"
```
10 changes: 6 additions & 4 deletions cmd/spiffe-helper/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ type Config struct {
}

type JWTConfig struct {
JWTAudience string `hcl:"jwt_audience"`
JWTSVIDFilename string `hcl:"jwt_svid_file_name"`
JWTAudience string `hcl:"jwt_audience"`
JWTExtraAudiences []string `hcl:"jwt_extra_audiences"`
JWTSVIDFilename string `hcl:"jwt_svid_file_name"`

UnusedKeyPositions map[string][]token.Pos `hcl:",unusedKeyPositions"`
}
Expand Down Expand Up @@ -188,8 +189,9 @@ func NewSidecarConfig(config *Config, log logrus.FieldLogger) *sidecar.Config {

for _, jwtSVID := range config.JWTSVIDs {
sidecarConfig.JWTSVIDs = append(sidecarConfig.JWTSVIDs, sidecar.JWTConfig{
JWTAudience: jwtSVID.JWTAudience,
JWTSVIDFilename: jwtSVID.JWTSVIDFilename,
JWTAudience: jwtSVID.JWTAudience,
JWTExtraAudiences: jwtSVID.JWTExtraAudiences,
JWTSVIDFilename: jwtSVID.JWTSVIDFilename,
})
}

Expand Down
2 changes: 2 additions & 0 deletions cmd/spiffe-helper/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func TestParseConfig(t *testing.T) {
expectedJWTSVIDFileName := "jwt_svid.token"
expectedJWTBundleFileName := "jwt_bundle.json"
expectedJWTAudience := "your-audience"
expectedJWTExtraAudiences := []string{"your-extra-audience-1", "your-extra-audience-2"}

assert.Equal(t, expectedAgentAddress, c.AgentAddress)
assert.Equal(t, expectedCmd, c.Cmd)
Expand All @@ -41,6 +42,7 @@ func TestParseConfig(t *testing.T) {
assert.Equal(t, expectedJWTSVIDFileName, c.JWTSVIDs[0].JWTSVIDFilename)
assert.Equal(t, expectedJWTBundleFileName, c.JWTBundleFilename)
assert.Equal(t, expectedJWTAudience, c.JWTSVIDs[0].JWTAudience)
assert.Equal(t, expectedJWTExtraAudiences, c.JWTSVIDs[0].JWTExtraAudiences)
assert.True(t, c.AddIntermediatesToBundle)
assert.Equal(t, 444, c.CertFileMode)
assert.Equal(t, 444, c.KeyFileMode)
Expand Down
2 changes: 2 additions & 0 deletions cmd/spiffe-helper/config/testdata/helper.conf
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ jwt_svids = [
{
jwt_svid_file_name = "jwt_svid.token"
jwt_audience = "your-audience"
jwt_extra_audiences = ["your-extra-audience-1", "your-extra-audience-2"]

}
]
timeout = "10s"
Expand Down
3 changes: 3 additions & 0 deletions pkg/sidecar/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ type JWTConfig struct {
// The audience for the JWT SVID to fetch
JWTAudience string

// The extra audiences for the JWT SVID to fetch
JWTExtraAudiences []string

// The filename to save the JWT SVID to
JWTSVIDFilename string
}
16 changes: 8 additions & 8 deletions pkg/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (s *Sidecar) RunDaemon(ctx context.Context) error {
wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx, jwtConfig.JWTAudience, jwtConfig.JWTSVIDFilename)
s.updateJWTSVID(ctx, jwtConfig.JWTAudience, jwtConfig.JWTExtraAudiences, jwtConfig.JWTSVIDFilename)
}()
}
}
Expand Down Expand Up @@ -253,8 +253,8 @@ func (s *Sidecar) checkProcessExit() {
atomic.StoreInt32(&s.processRunning, 0)
}

func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: jwtAudience})
func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string, jwtExtraAudiences []string) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: jwtAudience, ExtraAudiences: jwtExtraAudiences})
if err != nil {
s.config.Log.Errorf("Unable to fetch JWT SVID: %v", err)
return nil, err
Expand Down Expand Up @@ -291,10 +291,10 @@ func getRefreshInterval(svid *jwtsvid.SVID) time.Duration {
return time.Until(svid.Expiry)/2 + time.Second
}

func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtSVIDFilename string) (*jwtsvid.SVID, error) {
func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtExtraAudiences []string, jwtSVIDFilename string) (*jwtsvid.SVID, error) {
s.config.Log.Debug("Updating JWT SVID")

jwtSVID, err := s.fetchJWTSVIDs(ctx, jwtAudience)
jwtSVID, err := s.fetchJWTSVIDs(ctx, jwtAudience, jwtExtraAudiences)
if err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
Expand All @@ -309,10 +309,10 @@ func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string,
return jwtSVID, nil
}

func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtSVIDFilename string) {
func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtExtraAudiences []string, jwtSVIDFilename string) {
retryInterval := createRetryIntervalFunc()
var initialInterval time.Duration
jwtSVID, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSVIDFilename)
jwtSVID, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtExtraAudiences, jwtSVIDFilename)
if err != nil {
// If the first update fails, use the retry interval
initialInterval = retryInterval()
Expand All @@ -328,7 +328,7 @@ func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtSVID
case <-ctx.Done():
return
case <-ticker.C:
jwtSVID, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSVIDFilename)
jwtSVID, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtExtraAudiences, jwtSVIDFilename)
if err == nil {
retryInterval = createRetryIntervalFunc()
ticker.Reset(getRefreshInterval(jwtSVID))
Expand Down

0 comments on commit f4872f1

Please sign in to comment.