Skip to content

Commit

Permalink
refactor: merge shared codes
Browse files Browse the repository at this point in the history
Signed-off-by: huabing zhao <zhaohuabing@gmail.com>
  • Loading branch information
zhaohuabing committed Nov 15, 2023
1 parent 836b805 commit c606a93
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 111 deletions.
7 changes: 7 additions & 0 deletions internal/gatewayapi/securitypolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package gatewayapi
import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"sort"
Expand Down Expand Up @@ -531,6 +532,12 @@ func validateTokenEndpoint(tokenEndpoint string) error {
return fmt.Errorf("token endpoint URL scheme must be https: %s", tokenEndpoint)
}

Check warning on line 533 in internal/gatewayapi/securitypolicy.go

View check run for this annotation

Codecov / codecov/patch

internal/gatewayapi/securitypolicy.go#L532-L533

Added lines #L532 - L533 were not covered by tests

if ip := net.ParseIP(parsedURL.Hostname()); ip != nil {
if v4 := ip.To4(); v4 != nil {
return fmt.Errorf("token endpoint URL must be a domain name: %s", tokenEndpoint)
}

Check warning on line 538 in internal/gatewayapi/securitypolicy.go

View check run for this annotation

Codecov / codecov/patch

internal/gatewayapi/securitypolicy.go#L536-L538

Added lines #L536 - L538 were not covered by tests
}

if parsedURL.Port() != "" {
_, err = strconv.Atoi(parsedURL.Port())
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/xds/translator/accesslog.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func processClusterForAccessLog(tCtx *types.ResourceVersionTable, al *ir.AccessL
name: clusterName,
settings: []*ir.DestinationSetting{ds},
tSocket: nil,
endpointType: DefaultEndpointType,
endpointType: EndpointTypeDNS,
}); err != nil && !errors.Is(err, ErrXdsClusterExists) {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/xds/translator/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func buildXdsCluster(args *xdsClusterArgs) *clusterv3.Cluster {
cluster.TransportSocket = args.tSocket
}

if args.endpointType == Static {
if args.endpointType == EndpointTypeStatic {
cluster.ClusterDiscoveryType = &clusterv3.Cluster_Type{Type: clusterv3.Cluster_EDS}
cluster.EdsClusterConfig = &clusterv3.Cluster_EdsClusterConfig{
ServiceName: args.name,
Expand Down
2 changes: 1 addition & 1 deletion internal/xds/translator/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestBuildXdsCluster(t *testing.T) {
args := &xdsClusterArgs{
name: bootstrapXdsCluster.Name,
tSocket: bootstrapXdsCluster.TransportSocket,
endpointType: DefaultEndpointType,
endpointType: EndpointTypeDNS,
}
dynamicXdsCluster := buildXdsCluster(args)

Expand Down
67 changes: 3 additions & 64 deletions internal/xds/translator/jwt_authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ package translator
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
Expand All @@ -23,7 +19,6 @@ import (
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/durationpb"

"github.com/envoyproxy/gateway/api/v1alpha1"
"github.com/envoyproxy/gateway/internal/ir"
"github.com/envoyproxy/gateway/internal/utils/ptr"
"github.com/envoyproxy/gateway/internal/xds/types"
Expand Down Expand Up @@ -102,7 +97,7 @@ func buildJWTAuthn(irListener *ir.HTTPListener) (*jwtauthnv3.JwtAuthentication,
for i := range route.JWT.Providers {
irProvider := route.JWT.Providers[i]
// Create the cluster for the remote jwks, if it doesn't exist.
jwksCluster, err := newJWKSCluster(&irProvider)
jwksCluster, err := url2Cluster(irProvider.RemoteJWKS.URI)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -228,13 +223,6 @@ func patchRouteWithJWT(route *routev3.Route, irRoute *ir.HTTPRoute) error {
return nil
}

type jwksCluster struct {
name string
hostname string
port uint32
isStatic bool
}

// createJWKSClusters creates JWKS clusters from the provided routes, if needed.
func createJWKSClusters(tCtx *types.ResourceVersionTable, routes []*ir.HTTPRoute) error {
if tCtx == nil ||
Expand All @@ -248,11 +236,7 @@ func createJWKSClusters(tCtx *types.ResourceVersionTable, routes []*ir.HTTPRoute
if routeContainsJWTAuthn(route) {
for i := range route.JWT.Providers {
provider := route.JWT.Providers[i]
jwks, err := newJWKSCluster(&provider)
epType := DefaultEndpointType
if jwks.isStatic {
epType = Static
}
jwks, err := url2Cluster(provider.RemoteJWKS.URI)
if err != nil {
return err
}
Expand All @@ -268,7 +252,7 @@ func createJWKSClusters(tCtx *types.ResourceVersionTable, routes []*ir.HTTPRoute
name: jwks.name,
settings: []*ir.DestinationSetting{ds},
tSocket: tSocket,
endpointType: epType,
endpointType: jwks.endpointType,
}); err != nil && !errors.Is(err, ErrXdsClusterExists) {
return err
}
Expand All @@ -279,51 +263,6 @@ func createJWKSClusters(tCtx *types.ResourceVersionTable, routes []*ir.HTTPRoute
return nil
}

// newJWKSCluster returns a jwksCluster from the provided provider.
func newJWKSCluster(provider *v1alpha1.JWTProvider) (*jwksCluster, error) {
static := false
if provider == nil {
return nil, errors.New("nil provider")
}

u, err := url.Parse(provider.RemoteJWKS.URI)
if err != nil {
return nil, err
}

var strPort string
switch u.Scheme {
case "https":
strPort = "443"
default:
return nil, fmt.Errorf("unsupported JWKS URI scheme %s", u.Scheme)
}

if u.Port() != "" {
strPort = u.Port()
}

name := fmt.Sprintf("%s_%s", strings.ReplaceAll(u.Hostname(), ".", "_"), strPort)

port, err := strconv.Atoi(strPort)
if err != nil {
return nil, err
}

if ip := net.ParseIP(u.Hostname()); ip != nil {
if v4 := ip.To4(); v4 != nil {
static = true
}
}

return &jwksCluster{
name: name,
hostname: u.Hostname(),
port: uint32(port),
isStatic: static,
}, nil
}

// listenerContainsJWTAuthn returns true if JWT authentication exists for the
// provided listener.
func listenerContainsJWTAuthn(irListener *ir.HTTPListener) bool {
Expand Down
59 changes: 35 additions & 24 deletions internal/xds/translator/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
"crypto/rand"
"errors"
"fmt"
"net/url"
"strconv"

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
Expand All @@ -30,7 +28,6 @@ import (

const (
oauth2Filter = "envoy.filters.http.oauth2"
defaultTokenEndpointPort = 443
defaultTokenEndpointTimeout = 10
redirectURL = "%REQ(x-forwarded-proto)%://%REQ(:authority)%/oauth2/callback"
redirectPathMatcher = "/oauth2/callback"
Expand Down Expand Up @@ -73,7 +70,10 @@ func patchHCMWithOAuth2Filter(mgr *hcmv3.HttpConnectionManager, irListener *ir.H

// buildHCMOAuth2Filter returns an OAuth2 HTTP filter from the provided IR HTTPRoute.
func buildHCMOAuth2Filter(route *ir.HTTPRoute) (*hcmv3.HttpFilter, error) {
oauth2Proto := oauth2Config(route)
oauth2Proto, err := oauth2Config(route)
if err != nil {
return nil, err
}

Check warning on line 76 in internal/xds/translator/oidc.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/oidc.go#L75-L76

Added lines #L75 - L76 were not covered by tests

if err := oauth2Proto.ValidateAll(); err != nil {
return nil, err
Expand All @@ -96,17 +96,23 @@ func oauth2FilterName(route *ir.HTTPRoute) string {
return fmt.Sprintf("%s_%s", oauth2Filter, route.Name)
}

func oauth2Config(route *ir.HTTPRoute) *oauth2v3.OAuth2 {
// Ignore the errors because we already validate the token endpoint
// URL in the gateway API translator.
tokenEndpointURL, _ := url.Parse(route.OIDC.Provider.TokenEndpoint)
func oauth2Config(route *ir.HTTPRoute) (*oauth2v3.OAuth2, error) {
cluster, err := url2Cluster(route.OIDC.Provider.TokenEndpoint)
if err != nil {
return nil, err
}

Check warning on line 103 in internal/xds/translator/oidc.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/oidc.go#L102-L103

Added lines #L102 - L103 were not covered by tests
if cluster.endpointType == EndpointTypeStatic {
return nil, fmt.Errorf(
"static IP cluster is not allowed: %s",
route.OIDC.Provider.TokenEndpoint)
}

Check warning on line 108 in internal/xds/translator/oidc.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/oidc.go#L105-L108

Added lines #L105 - L108 were not covered by tests

oauth2 := &oauth2v3.OAuth2{
Config: &oauth2v3.OAuth2Config{
TokenEndpoint: &corev3.HttpUri{
Uri: route.OIDC.Provider.TokenEndpoint,
HttpUpstreamType: &corev3.HttpUri_Cluster{
Cluster: oauth2TokenEndpointClusterName(tokenEndpointURL),
Cluster: cluster.name,
},
Timeout: &duration.Duration{
Seconds: defaultTokenEndpointTimeout,
Expand Down Expand Up @@ -150,7 +156,7 @@ func oauth2Config(route *ir.HTTPRoute) *oauth2v3.OAuth2 {
AuthScopes: route.OIDC.Scopes,
},
}
return oauth2
return oauth2, nil
}

// routeContainsOIDC returns true if OIDC exists for the provided route.
Expand Down Expand Up @@ -181,16 +187,20 @@ func createOAuth2TokenEndpointClusters(tCtx *types.ResourceVersionTable, routes
continue
}

// Ignore the errors because we already validate the token endpoint
// URL in the gateway API translator.
tokenEndpointURL, _ := url.Parse(route.OIDC.Provider.TokenEndpoint)
port := defaultTokenEndpointPort
if tokenEndpointURL.Port() != "" {
port, _ = strconv.Atoi(tokenEndpointURL.Port())
cluster, err := url2Cluster(route.OIDC.Provider.TokenEndpoint)
if err != nil {
return err
}

Check warning on line 193 in internal/xds/translator/oidc.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/oidc.go#L192-L193

Added lines #L192 - L193 were not covered by tests

// EG does not support static IP clusters for token endpoint clusters.
if cluster.endpointType == EndpointTypeStatic {
return fmt.Errorf(
"static IP cluster is not allowed: %s",
route.OIDC.Provider.TokenEndpoint)
}

Check warning on line 200 in internal/xds/translator/oidc.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/oidc.go#L197-L200

Added lines #L197 - L200 were not covered by tests

tlsContext := &tlsv3.UpstreamTlsContext{
Sni: tokenEndpointURL.Hostname(),
Sni: cluster.hostname,
}

tlsContextAny, err := anypb.New(tlsContext)
Expand All @@ -207,15 +217,16 @@ func createOAuth2TokenEndpointClusters(tCtx *types.ResourceVersionTable, routes
ds := &ir.DestinationSetting{
Weight: ptr.To(uint32(1)),
Endpoints: []*ir.DestinationEndpoint{ir.NewDestEndpoint(
tokenEndpointURL.Hostname(),
uint32(port))},
cluster.hostname,
cluster.port),
},
}

if err := addXdsCluster(tCtx, &xdsClusterArgs{
name: oauth2TokenEndpointClusterName(tokenEndpointURL),
name: cluster.name,
settings: []*ir.DestinationSetting{ds},
tSocket: tSocket,
endpointType: DefaultEndpointType, // TODO support static endpoint
endpointType: cluster.endpointType,
}); err != nil && !errors.Is(err, ErrXdsClusterExists) {
return err
}

Check warning on line 232 in internal/xds/translator/oidc.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/oidc.go#L231-L232

Added lines #L231 - L232 were not covered by tests
Expand All @@ -224,9 +235,9 @@ func createOAuth2TokenEndpointClusters(tCtx *types.ResourceVersionTable, routes
return nil
}

func oauth2TokenEndpointClusterName(tokenEndpointURL *url.URL) string {
return fmt.Sprintf("oauth2_token_endpoint_%s", tokenEndpointURL.Hostname())
}
/*func oauth2TokenEndpointClusterName(cluster *Cluster) string {
return fmt.Sprintf("oauth2_token_endpoint_%s", cluster.hostname)
}*/

// createOAuth2Secrets creates OAuth2 client and HMAC secrets from the provided
// routes, if needed.
Expand Down
3 changes: 1 addition & 2 deletions internal/xds/translator/ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package translator
import (
"bytes"
"errors"

"net/url"
"strconv"
"strings"
Expand Down Expand Up @@ -446,7 +445,7 @@ func (t *Translator) createRateLimitServiceCluster(tCtx *types.ResourceVersionTa
name: clusterName,
settings: []*ir.DestinationSetting{ds},
tSocket: tSocket,
endpointType: DefaultEndpointType,
endpointType: EndpointTypeDNS,
}); err != nil && !errors.Is(err, ErrXdsClusterExists) {
return err
}
Expand Down
64 changes: 64 additions & 0 deletions internal/xds/translator/shared_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright Envoy Gateway Authors
// SPDX-License-Identifier: Apache-2.0
// The full text of the Apache license is available in the LICENSE file at
// the root of the repo.

package translator

import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
)

const (
defaultPort = 443
)

// urlCluster is a cluster that is created from a URL.
type urlCluster struct {
name string
hostname string
port uint32
endpointType EndpointType
}

// url2Cluster returns a urlCluster from the provided url.
func url2Cluster(strURL string) (*urlCluster, error) {
epType := EndpointTypeDNS

// The URL should have already been validated in the gateway API translator.
u, err := url.Parse(strURL)
if err != nil {
return nil, err
}

Check warning on line 36 in internal/xds/translator/shared_types.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/shared_types.go#L35-L36

Added lines #L35 - L36 were not covered by tests

if u.Scheme != "https" {
return nil, fmt.Errorf("unsupported URI scheme %s", u.Scheme)
}

Check warning on line 40 in internal/xds/translator/shared_types.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/shared_types.go#L39-L40

Added lines #L39 - L40 were not covered by tests

port := defaultPort
if u.Port() != "" {
port, err = strconv.Atoi(u.Port())
if err != nil {
return nil, err
}

Check warning on line 47 in internal/xds/translator/shared_types.go

View check run for this annotation

Codecov / codecov/patch

internal/xds/translator/shared_types.go#L46-L47

Added lines #L46 - L47 were not covered by tests
}

name := fmt.Sprintf("%s_%d", strings.ReplaceAll(u.Hostname(), ".", "_"), port)

if ip := net.ParseIP(u.Hostname()); ip != nil {
if v4 := ip.To4(); v4 != nil {
epType = EndpointTypeStatic
}
}

return &urlCluster{
name: name,
hostname: u.Hostname(),
port: uint32(port),
endpointType: epType,
}, nil
}
Loading

0 comments on commit c606a93

Please sign in to comment.