diff --git a/go.mod b/go.mod index cf79ef7..f589ab7 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,11 @@ go 1.17 require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/evanphx/json-patch v0.5.2 - github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6 + github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693 github.com/free5gc/util v1.0.5-0.20231001095115-433858e5be94 github.com/gin-gonic/gin v1.9.1 github.com/google/uuid v1.3.0 + github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.8.1 github.com/urfave/cli v1.22.5 gopkg.in/yaml.v2 v2.4.0 @@ -36,7 +37,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect - github.com/pkg/errors v0.9.1 // indirect github.com/russross/blackfriday/v2 v2.0.1 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/tim-ywliu/nested-logrus-formatter v1.3.2 // indirect diff --git a/go.sum b/go.sum index 7c7251a..779caed 100644 --- a/go.sum +++ b/go.sum @@ -60,10 +60,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= -github.com/free5gc/openapi v1.0.7-0.20231112094355-a96c3450377e h1:mXnoioq+fxpChliDl5Uy+m6+Hm7iWrJPZo9mi6BijHE= -github.com/free5gc/openapi v1.0.7-0.20231112094355-a96c3450377e/go.mod h1:qv9KqEucoZSeENPRFGxfTe+33ZWYyiYFx1Rj+H0DoWA= -github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6 h1:8P/wOkTAQMgZJe9pUUNSTE5PWeAdlMrsU9kLsI+VAVE= -github.com/free5gc/openapi v1.0.7-0.20231216094313-e15a4ff046f6/go.mod h1:qv9KqEucoZSeENPRFGxfTe+33ZWYyiYFx1Rj+H0DoWA= +github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693 h1:gFyYBsErQAkx4OVHXYqjO0efO9gPWydQavQcjU0CkHY= +github.com/free5gc/openapi v1.0.7-0.20240117084712-52ad99299693/go.mod h1:qv9KqEucoZSeENPRFGxfTe+33ZWYyiYFx1Rj+H0DoWA= github.com/free5gc/util v1.0.5-0.20231001095115-433858e5be94 h1:tNylIqH/m5Kq+3KuC+jjXGl06Y6EmM8yq61ZUgNrPBY= github.com/free5gc/util v1.0.5-0.20231001095115-433858e5be94/go.mod h1:aMszJZbCkcg5xaGgzya+55jz+OPMsJqPLq5Z3fWDFPE= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= diff --git a/internal/context/context.go b/internal/context/context.go index 48564e7..5899b1e 100644 --- a/internal/context/context.go +++ b/internal/context/context.go @@ -43,6 +43,12 @@ func Init() { nssfContext.NrfUri = fmt.Sprintf("%s://%s:%d", models.UriScheme_HTTPS, nssfContext.RegisterIPv4, 29510) } +type NFContext interface { + AuthorizationCheck(token, serviceName string) error +} + +var _ NFContext = &NSSFContext{} + type NSSFContext struct { NfId string Name string @@ -130,12 +136,22 @@ func GetSelf() *NSSFContext { return &nssfContext } -func (c *NSSFContext) GetTokenCtx(scope, targetNF string) ( +func (c *NSSFContext) GetTokenCtx(scope string, targetNF models.NfType) ( context.Context, *models.ProblemDetails, error, ) { if !c.OAuth2Required { return context.TODO(), nil, nil } - return oauth.GetTokenCtx(models.NfType_NSSF, - c.NfId, c.NrfUri, scope, targetNF) + return oauth.GetTokenCtx(models.NfType_NSSF, targetNF, + c.NfId, c.NrfUri, scope) +} + +func (c *NSSFContext) AuthorizationCheck(token, serviceName string) error { + if !c.OAuth2Required { + logger.UtilLog.Debugf("NSSFContext::AuthorizationCheck: OAuth2 not required\n") + return nil + } + + logger.UtilLog.Debugf("NSSFContext::AuthorizationCheck: token[%s] serviceName[%s]\n", token, serviceName) + return oauth.VerifyOAuth(token, serviceName, c.NrfCertPem) } diff --git a/internal/sbi/consumer/nf_management.go b/internal/sbi/consumer/nf_management.go index 8f5e3e2..1274e46 100644 --- a/internal/sbi/consumer/nf_management.go +++ b/internal/sbi/consumer/nf_management.go @@ -93,7 +93,7 @@ func SendDeregisterNFInstance() (*models.ProblemDetails, error) { var err error - ctx, pd, err := nssf_context.GetSelf().GetTokenCtx("nnrf-nfm", "NRF") + ctx, pd, err := nssf_context.GetSelf().GetTokenCtx("nnrf-nfm", models.NfType_NRF) if err != nil { return pd, err } diff --git a/internal/sbi/nssaiavailability/routers.go b/internal/sbi/nssaiavailability/routers.go index 27cc22a..ee285f7 100644 --- a/internal/sbi/nssaiavailability/routers.go +++ b/internal/sbi/nssaiavailability/routers.go @@ -15,11 +15,16 @@ import ( "github.com/gin-gonic/gin" + nssf_context "github.com/free5gc/nssf/internal/context" "github.com/free5gc/nssf/internal/logger" + "github.com/free5gc/nssf/internal/util" "github.com/free5gc/nssf/pkg/factory" + "github.com/free5gc/openapi/models" logger_util "github.com/free5gc/util/logger" ) +const serviceName string = string(models.ServiceName_NNSSF_NSSAIAVAILABILITY) + // Route is the information for every URI. type Route struct { // Name is the name of this Route. @@ -45,6 +50,11 @@ func NewRouter() *gin.Engine { func AddService(engine *gin.Engine) *gin.RouterGroup { group := engine.Group(factory.NssfNssaiavailResUriPrefix) + routerAuthorizationCheck := util.NewRouterAuthorizationCheck(serviceName) + group.Use(func(c *gin.Context) { + routerAuthorizationCheck.Check(c, nssf_context.GetSelf()) + }) + for _, route := range routes { switch route.Method { case "GET": diff --git a/internal/sbi/nsselection/routers.go b/internal/sbi/nsselection/routers.go index 43fd1cb..c347505 100644 --- a/internal/sbi/nsselection/routers.go +++ b/internal/sbi/nsselection/routers.go @@ -15,11 +15,16 @@ import ( "github.com/gin-gonic/gin" + nssf_context "github.com/free5gc/nssf/internal/context" "github.com/free5gc/nssf/internal/logger" + "github.com/free5gc/nssf/internal/util" "github.com/free5gc/nssf/pkg/factory" + "github.com/free5gc/openapi/models" logger_util "github.com/free5gc/util/logger" ) +const serviceName string = string(models.ServiceName_NNSSF_NSSELECTION) + // Route is the information for every URI. type Route struct { // Name is the name of this Route. @@ -45,6 +50,11 @@ func NewRouter() *gin.Engine { func AddService(engine *gin.Engine) *gin.RouterGroup { group := engine.Group(factory.NssfNsselectResUriPrefix) + routerAuthorizationCheck := util.NewRouterAuthorizationCheck(serviceName) + group.Use(func(c *gin.Context) { + routerAuthorizationCheck.Check(c, nssf_context.GetSelf()) + }) + for _, route := range routes { switch route.Method { case "GET": diff --git a/internal/util/router_auth_check.go b/internal/util/router_auth_check.go new file mode 100644 index 0000000..1943d64 --- /dev/null +++ b/internal/util/router_auth_check.go @@ -0,0 +1,33 @@ +package util + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + nssf_context "github.com/free5gc/nssf/internal/context" + "github.com/free5gc/nssf/internal/logger" +) + +type RouterAuthorizationCheck struct { + serviceName string +} + +func NewRouterAuthorizationCheck(serviceName string) *RouterAuthorizationCheck { + return &RouterAuthorizationCheck{ + serviceName: serviceName, + } +} + +func (rac *RouterAuthorizationCheck) Check(c *gin.Context, nssfContext nssf_context.NFContext) { + token := c.Request.Header.Get("Authorization") + err := nssfContext.AuthorizationCheck(token, rac.serviceName) + if err != nil { + logger.UtilLog.Debugf("RouterAuthorizationCheck: Check Unauthorized: %s", err.Error()) + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + c.Abort() + return + } + + logger.UtilLog.Debugf("RouterAuthorizationCheck: Check Authorized") +} diff --git a/internal/util/router_auth_check_test.go b/internal/util/router_auth_check_test.go new file mode 100644 index 0000000..bf4ecf0 --- /dev/null +++ b/internal/util/router_auth_check_test.go @@ -0,0 +1,91 @@ +package util + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +const ( + Valid = "valid" + Invalid = "invalid" +) + +type mockNSSFContext struct{} + +func newMockNSSFContext() *mockNSSFContext { + return &mockNSSFContext{} +} + +func (m *mockNSSFContext) AuthorizationCheck(token string, serviceName string) error { + if token == Valid { + return nil + } + + return errors.New("invalid token") +} + +func TestRouterAuthorizationCheck_Check(t *testing.T) { + // Mock gin.Context + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + var err error + c.Request, err = http.NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("error on http request: %+v", err) + } + + type Args struct { + token string + } + type Want struct { + statusCode int + } + + tests := []struct { + name string + args Args + want Want + }{ + { + name: "Valid Token", + args: Args{ + token: Valid, + }, + want: Want{ + statusCode: http.StatusOK, + }, + }, + { + name: "Invalid Token", + args: Args{ + token: Invalid, + }, + want: Want{ + statusCode: http.StatusUnauthorized, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w = httptest.NewRecorder() + c, _ = gin.CreateTestContext(w) + c.Request, err = http.NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("error on http request: %+v", err) + } + c.Request.Header.Set("Authorization", tt.args.token) + + rac := NewRouterAuthorizationCheck("testService") + rac.Check(c, newMockNSSFContext()) + if w.Code != tt.want.statusCode { + t.Errorf("StatusCode should be %d, but got %d", tt.want.statusCode, w.Code) + } + }) + } +}