diff --git a/go.mod b/go.mod index e9e59a9..23b08a7 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,12 @@ require ( github.com/antihax/optional v1.0.0 github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/bronze1man/radius v0.0.0-20190516032554-afd8baec892d - github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6 + github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693 github.com/free5gc/util v1.0.5-0.20231205080047-308f623d6808 github.com/gin-gonic/gin v1.9.1 github.com/google/gopacket v1.1.19 github.com/google/uuid v1.3.0 + github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.8.3 github.com/urfave/cli v1.22.5 @@ -39,7 +40,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.0.1 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect diff --git a/go.sum b/go.sum index 1b3bda3..82fdce1 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= -github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6 h1:8P/wOkTAQMgZJe9pUUNSTE5PWeAdlMrsU9kLsI+VAVE= -github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6/go.mod h1:qv9KqEucoZSeENPRFGxfTe+33ZWYyiYFx1Rj+H0DoWA= +github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693 h1:gFyYBsErQAkx4OVHXYqjO0efO9gPWydQavQcjU0CkHY= +github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693/go.mod h1:qv9KqEucoZSeENPRFGxfTe+33ZWYyiYFx1Rj+H0DoWA= github.com/free5gc/util v1.0.5-0.20231205080047-308f623d6808 h1:8/IoWEgcO2DLlLCqbsxwduD7CzXdKe/BFJU2tcAqnxo= github.com/free5gc/util v1.0.5-0.20231205080047-308f623d6808/go.mod h1:d+79g84a3YHhzvjJ2IhurrBOavOA8xWIQ/GCywPXqQk= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= diff --git a/internal/context/context.go b/internal/context/context.go index 53c3b03..017ac74 100644 --- a/internal/context/context.go +++ b/internal/context/context.go @@ -104,6 +104,12 @@ func Init() { InitAusfContext(&ausfContext) } +type NFContext interface { + AuthorizationCheck(token string, serviceName models.ServiceName) error +} + +var _ NFContext = &AUSFContext{} + func NewAusfUeContext(identifier string) (ausfUeContext *AusfUeContext) { ausfUeContext = new(AusfUeContext) ausfUeContext.Supi = identifier // supi @@ -160,12 +166,22 @@ func (a *AUSFContext) GetSelfID() string { return a.NfId } -func (c *AUSFContext) GetTokenCtx(scope, targetNF string) ( +func (c *AUSFContext) GetTokenCtx(serviceName models.ServiceName, targetNF models.NfType) ( context.Context, *models.ProblemDetails, error, ) { if !c.OAuth2Required { return context.TODO(), nil, nil } - return oauth.GetTokenCtx(models.NfType_AUSF, - c.NfId, c.NrfUri, scope, targetNF) + return oauth.GetTokenCtx(models.NfType_AUSF, targetNF, + c.NfId, c.NrfUri, string(serviceName)) +} + +func (c *AUSFContext) AuthorizationCheck(token string, serviceName models.ServiceName) error { + if !c.OAuth2Required { + logger.UtilLog.Debugf("AUSFContext::AuthorizationCheck: OAuth2 not required\n") + return nil + } + + logger.UtilLog.Debugf("AUSFContext::AuthorizationCheck: token[%s] serviceName[%s]\n", token, serviceName) + return oauth.VerifyOAuth(token, string(serviceName), c.NrfCertPem) } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index c715cad..90993e0 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -18,6 +18,7 @@ var ( UeAuthLog *logrus.Entry Auth5gAkaLog *logrus.Entry AuthELog *logrus.Entry + UtilLog *logrus.Entry ) func init() { @@ -37,4 +38,5 @@ func init() { UeAuthLog = NfLog.WithField(logger_util.FieldCategory, "UeAuth") Auth5gAkaLog = NfLog.WithField(logger_util.FieldCategory, "5gAka") AuthELog = NfLog.WithField(logger_util.FieldCategory, "Eap") + UtilLog = NfLog.WithField(logger_util.FieldCategory, "Util") } diff --git a/internal/sbi/consumer/nf_discovery.go b/internal/sbi/consumer/nf_discovery.go index 6278fde..3c97013 100644 --- a/internal/sbi/consumer/nf_discovery.go +++ b/internal/sbi/consumer/nf_discovery.go @@ -13,7 +13,7 @@ import ( func SendSearchNFInstances(nrfUri string, targetNfType, requestNfType models.NfType, param Nnrf_NFDiscovery.SearchNFInstancesParamOpts, ) (*models.SearchResult, error) { - ctx, _, err := ausf_context.GetSelf().GetTokenCtx("nnrf-disc", "NRF") + ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_DISC, models.NfType_NRF) if err != nil { return nil, err } @@ -24,6 +24,7 @@ func SendSearchNFInstances(nrfUri string, targetNfType, requestNfType models.NfT result, rsp, rspErr := client.NFInstancesStoreApi.SearchNFInstances(ctx, targetNfType, requestNfType, ¶m) + if rspErr != nil { return nil, fmt.Errorf("NFInstancesStoreApi Response error: %+w", rspErr) } diff --git a/internal/sbi/consumer/nf_management.go b/internal/sbi/consumer/nf_management.go index 3a156c4..816ae73 100644 --- a/internal/sbi/consumer/nf_management.go +++ b/internal/sbi/consumer/nf_management.go @@ -1,7 +1,6 @@ package consumer import ( - "context" "fmt" "net/http" "strings" @@ -40,9 +39,14 @@ func SendRegisterNFInstance(nrfUri, nfInstanceId string, profile models.NfProfil configuration.SetBasePath(nrfUri) client := Nnrf_NFManagement.NewAPIClient(configuration) + ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_NFM, models.NfType_NRF) + if err != nil { + return "", "", err + } + var res *http.Response for { - nf, resTmp, err := client.NFInstanceIDDocumentApi.RegisterNFInstance(context.TODO(), nfInstanceId, profile) + nf, resTmp, err := client.NFInstanceIDDocumentApi.RegisterNFInstance(ctx, nfInstanceId, profile) if err != nil || resTmp == nil { logger.ConsumerLog.Errorf("AUSF register to NRF Error[%v]", err) time.Sleep(2 * time.Second) @@ -90,7 +94,7 @@ func SendRegisterNFInstance(nrfUri, nfInstanceId string, profile models.NfProfil func SendDeregisterNFInstance() (*models.ProblemDetails, error) { logger.ConsumerLog.Infof("Send Deregister NFInstance") - ctx, pd, err := ausf_context.GetSelf().GetTokenCtx("nnrf-nfm", "NRF") + ctx, pd, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NNRF_NFM, models.NfType_NRF) if err != nil { return pd, err } diff --git a/internal/sbi/producer/functions.go b/internal/sbi/producer/functions.go index 2999659..00aba64 100644 --- a/internal/sbi/producer/functions.go +++ b/internal/sbi/producer/functions.go @@ -1,7 +1,6 @@ package producer import ( - "context" "crypto/hmac" "crypto/sha256" "encoding/base64" @@ -372,7 +371,13 @@ func sendAuthResultToUDM(id string, authType models.AuthType, success bool, serv authEvent.NfInstanceId = self.GetSelfID() client := createClientToUdmUeau(udmUrl) - _, rsp, confirmAuthErr := client.ConfirmAuthApi.ConfirmAuth(context.Background(), id, authEvent) + + ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NUDM_UEAU, models.NfType_UDM) + if err != nil { + return err + } + + _, rsp, confirmAuthErr := client.ConfirmAuthApi.ConfirmAuth(ctx, id, authEvent) defer func() { if rspCloseErr := rsp.Body.Close(); rspCloseErr != nil { logger.ConsumerLog.Errorf("ConfirmAuth Response cannot close: %v", rspCloseErr) diff --git a/internal/sbi/producer/ue_authentication.go b/internal/sbi/producer/ue_authentication.go index 0fbf190..29291c7 100644 --- a/internal/sbi/producer/ue_authentication.go +++ b/internal/sbi/producer/ue_authentication.go @@ -2,7 +2,6 @@ package producer import ( "bytes" - "context" "crypto/sha256" "encoding/base64" "encoding/hex" @@ -124,7 +123,13 @@ func UeAuthPostRequestProcedure(updateAuthenticationInfo models.AuthenticationIn udmUrl := getUdmUrl(self.NrfUri) client := createClientToUdmUeau(udmUrl) - authInfoResult, rsp, err := client.GenerateAuthDataApi.GenerateAuthData(context.Background(), supiOrSuci, authInfoReq) + + ctx, _, err := ausf_context.GetSelf().GetTokenCtx(models.ServiceName_NUDM_UEAU, models.NfType_UDM) + if err != nil { + return nil, "", nil + } + + authInfoResult, rsp, err := client.GenerateAuthDataApi.GenerateAuthData(ctx, supiOrSuci, authInfoReq) if err != nil { logger.UeAuthLog.Infoln(err.Error()) var problemDetails models.ProblemDetails diff --git a/internal/sbi/sorprotection/routers.go b/internal/sbi/sorprotection/routers.go index 88f63ae..45264b2 100644 --- a/internal/sbi/sorprotection/routers.go +++ b/internal/sbi/sorprotection/routers.go @@ -15,8 +15,11 @@ import ( "github.com/gin-gonic/gin" + ausf_context "github.com/free5gc/ausf/internal/context" "github.com/free5gc/ausf/internal/logger" + "github.com/free5gc/ausf/internal/util" "github.com/free5gc/ausf/pkg/factory" + "github.com/free5gc/openapi/models" logger_util "github.com/free5gc/util/logger" ) @@ -45,6 +48,11 @@ func NewRouter() *gin.Engine { func AddService(engine *gin.Engine) *gin.RouterGroup { group := engine.Group(factory.AusfSorprotectionResUriPrefix) + routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_SORPROTECTION) + group.Use(func(c *gin.Context) { + routerAuthorizationCheck.Check(c, ausf_context.GetSelf()) + }) + for _, route := range routes { switch route.Method { case "GET": diff --git a/internal/sbi/ueauthentication/routers.go b/internal/sbi/ueauthentication/routers.go index cbf572e..948f8db 100644 --- a/internal/sbi/ueauthentication/routers.go +++ b/internal/sbi/ueauthentication/routers.go @@ -15,8 +15,11 @@ import ( "github.com/gin-gonic/gin" + ausf_context "github.com/free5gc/ausf/internal/context" "github.com/free5gc/ausf/internal/logger" + "github.com/free5gc/ausf/internal/util" "github.com/free5gc/ausf/pkg/factory" + "github.com/free5gc/openapi/models" logger_util "github.com/free5gc/util/logger" ) @@ -45,6 +48,11 @@ func NewRouter() *gin.Engine { func AddService(engine *gin.Engine) *gin.RouterGroup { group := engine.Group(factory.AusfAuthResUriPrefix) + routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_AUTH) + group.Use(func(c *gin.Context) { + routerAuthorizationCheck.Check(c, ausf_context.GetSelf()) + }) + for _, route := range routes { switch route.Method { case "GET": diff --git a/internal/sbi/upuprotection/routers.go b/internal/sbi/upuprotection/routers.go index f439c84..f3affde 100644 --- a/internal/sbi/upuprotection/routers.go +++ b/internal/sbi/upuprotection/routers.go @@ -15,8 +15,11 @@ import ( "github.com/gin-gonic/gin" + ausf_context "github.com/free5gc/ausf/internal/context" "github.com/free5gc/ausf/internal/logger" + "github.com/free5gc/ausf/internal/util" "github.com/free5gc/ausf/pkg/factory" + "github.com/free5gc/openapi/models" logger_util "github.com/free5gc/util/logger" ) @@ -45,6 +48,11 @@ func NewRouter() *gin.Engine { func AddService(engine *gin.Engine) *gin.RouterGroup { group := engine.Group(factory.AusfAuthResUriPrefix) + routerAuthorizationCheck := util.NewRouterAuthorizationCheck(models.ServiceName_NAUSF_UPUPROTECTION) + group.Use(func(c *gin.Context) { + routerAuthorizationCheck.Check(c, ausf_context.GetSelf()) + }) + for _, route := range routes { switch route.Method { case "GET": diff --git a/internal/util/router_auth_check.go b/internal/util/router_auth_check.go new file mode 100644 index 0000000..754fc83 --- /dev/null +++ b/internal/util/router_auth_check.go @@ -0,0 +1,34 @@ +package util + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + ausf_context "github.com/free5gc/ausf/internal/context" + "github.com/free5gc/ausf/internal/logger" + "github.com/free5gc/openapi/models" +) + +type RouterAuthorizationCheck struct { + serviceName models.ServiceName +} + +func NewRouterAuthorizationCheck(serviceName models.ServiceName) *RouterAuthorizationCheck { + return &RouterAuthorizationCheck{ + serviceName: serviceName, + } +} + +func (rac *RouterAuthorizationCheck) Check(c *gin.Context, ausfContext ausf_context.NFContext) { + token := c.Request.Header.Get("Authorization") + err := ausfContext.AuthorizationCheck(token, rac.serviceName) + if err != nil { + logger.UtilLog.Debugf("RouterAuthorizationCheck: Check Unauthorized: %s", err.Error()) + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + c.Abort() + return + } + + logger.UtilLog.Debugf("RouterAuthorizationCheck: Check Authorized") +} diff --git a/internal/util/router_auth_check_test.go b/internal/util/router_auth_check_test.go new file mode 100644 index 0000000..df1612c --- /dev/null +++ b/internal/util/router_auth_check_test.go @@ -0,0 +1,93 @@ +package util + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + + "github.com/free5gc/openapi/models" +) + +const ( + Valid = "valid" + Invalid = "invalid" +) + +type mockAUSFContext struct{} + +func newMockAUSFContext() *mockAUSFContext { + return &mockAUSFContext{} +} + +func (m *mockAUSFContext) AuthorizationCheck(token string, serviceName models.ServiceName) error { + if token == Valid { + return nil + } + + return errors.New("invalid token") +} + +func TestRouterAuthorizationCheck_Check(t *testing.T) { + // Mock gin.Context + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + var err error + c.Request, err = http.NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("error on http request: %+v", err) + } + + type Args struct { + token string + } + type Want struct { + statusCode int + } + + tests := []struct { + name string + args Args + want Want + }{ + { + name: "Valid Token", + args: Args{ + token: Valid, + }, + want: Want{ + statusCode: http.StatusOK, + }, + }, + { + name: "Invalid Token", + args: Args{ + token: Invalid, + }, + want: Want{ + statusCode: http.StatusUnauthorized, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w = httptest.NewRecorder() + c, _ = gin.CreateTestContext(w) + c.Request, err = http.NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("error on http request: %+v", err) + } + c.Request.Header.Set("Authorization", tt.args.token) + + rac := NewRouterAuthorizationCheck(models.ServiceName("testService")) + rac.Check(c, newMockAUSFContext()) + if w.Code != tt.want.statusCode { + t.Errorf("StatusCode should be %d, but got %d", tt.want.statusCode, w.Code) + } + }) + } +} diff --git a/pkg/factory/config.go b/pkg/factory/config.go index 2545100..983da78 100644 --- a/pkg/factory/config.go +++ b/pkg/factory/config.go @@ -153,8 +153,8 @@ func appendInvalid(err error) error { } func (c *Config) GetVersion() string { - c.RLock() - defer c.RUnlock() + c.RWMutex.RLock() + defer c.RWMutex.RUnlock() if c.Info.Version != "" { return c.Info.Version @@ -163,8 +163,8 @@ func (c *Config) GetVersion() string { } func (c *Config) SetLogEnable(enable bool) { - c.Lock() - defer c.Unlock() + c.RWMutex.Lock() + defer c.RWMutex.Unlock() if c.Logger == nil { logger.CfgLog.Warnf("Logger should not be nil") @@ -178,8 +178,8 @@ func (c *Config) SetLogEnable(enable bool) { } func (c *Config) SetLogLevel(level string) { - c.Lock() - defer c.Unlock() + c.RWMutex.Lock() + defer c.RWMutex.Unlock() if c.Logger == nil { logger.CfgLog.Warnf("Logger should not be nil") @@ -192,8 +192,8 @@ func (c *Config) SetLogLevel(level string) { } func (c *Config) SetLogReportCaller(reportCaller bool) { - c.Lock() - defer c.Unlock() + c.RWMutex.Lock() + defer c.RWMutex.Unlock() if c.Logger == nil { logger.CfgLog.Warnf("Logger should not be nil") @@ -207,8 +207,8 @@ func (c *Config) SetLogReportCaller(reportCaller bool) { } func (c *Config) GetLogEnable() bool { - c.RLock() - defer c.RUnlock() + c.RWMutex.RLock() + defer c.RWMutex.RUnlock() if c.Logger == nil { logger.CfgLog.Warnf("Logger should not be nil") return false @@ -217,8 +217,8 @@ func (c *Config) GetLogEnable() bool { } func (c *Config) GetLogLevel() string { - c.RLock() - defer c.RUnlock() + c.RWMutex.RLock() + defer c.RWMutex.RUnlock() if c.Logger == nil { logger.CfgLog.Warnf("Logger should not be nil") return "info" @@ -227,8 +227,8 @@ func (c *Config) GetLogLevel() string { } func (c *Config) GetLogReportCaller() bool { - c.RLock() - defer c.RUnlock() + c.RWMutex.RLock() + defer c.RWMutex.RUnlock() if c.Logger == nil { logger.CfgLog.Warnf("Logger should not be nil") return false