From 5911ff741e5fbcdf6d1701300731b4445891a31e Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Fri, 3 Feb 2023 22:59:00 +0300 Subject: [PATCH 1/4] Fix 56 and 57 issues. Add tests --- cmd/api-firewall/internal/handlers/openapi.go | 15 +- cmd/api-firewall/tests/main_test.go | 458 ++++++ go.mod | 25 +- go.sum | 45 +- internal/platform/validator/internal.go | 7 +- .../platform/validator/req_resp_decoder.go | 219 ++- .../validator/req_resp_decoder_test.go | 1401 +++++++++++++++++ .../platform/validator/req_resp_encoder.go | 30 +- .../validator/req_resp_encoder_test.go | 43 + .../platform/validator/validate_request.go | 233 ++- .../validator/validate_request_test.go | 224 +++ .../platform/validator/validate_response.go | 98 +- .../validator/validate_response_test.go | 215 +++ internal/platform/web/response.go | 34 + 14 files changed, 2955 insertions(+), 92 deletions(-) create mode 100644 internal/platform/validator/req_resp_decoder_test.go create mode 100644 internal/platform/validator/req_resp_encoder_test.go create mode 100644 internal/platform/validator/validate_request_test.go create mode 100644 internal/platform/validator/validate_response_test.go diff --git a/cmd/api-firewall/internal/handlers/openapi.go b/cmd/api-firewall/internal/handlers/openapi.go index 7781d62..c547e02 100644 --- a/cmd/api-firewall/internal/handlers/openapi.go +++ b/cmd/api-firewall/internal/handlers/openapi.go @@ -1,11 +1,9 @@ package handlers import ( - "bytes" "context" "errors" "fmt" - "io" "net/http" "strings" @@ -277,18 +275,27 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { // Prepare http response headers respHeader := http.Header{} - ctx.Request.Header.VisitAll(func(k, v []byte) { + ctx.Response.Header.VisitAll(func(k, v []byte) { sk := string(k) sv := string(v) respHeader.Set(sk, sv) }) + responseBodyReader, err := web.GetResponseBodyUncompressed(ctx) + if err != nil { + s.logger.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("response body decompression error") + return err + } + responseValidationInput := &openapi3filter.ResponseValidationInput{ RequestValidationInput: requestValidationInput, Status: ctx.Response.StatusCode(), Header: respHeader, - Body: io.NopCloser(bytes.NewReader(ctx.Response.Body())), + Body: responseBodyReader, Options: &openapi3filter.Options{ ExcludeRequestBody: false, ExcludeResponseBody: false, diff --git a/cmd/api-firewall/tests/main_test.go b/cmd/api-firewall/tests/main_test.go index 78cfd47..dda9fd9 100644 --- a/cmd/api-firewall/tests/main_test.go +++ b/cmd/api-firewall/tests/main_test.go @@ -2,6 +2,8 @@ package tests import ( "bytes" + "compress/flate" + "compress/gzip" "encoding/json" "fmt" "io" @@ -13,8 +15,10 @@ import ( "testing" "time" + "github.com/andybalholm/brotli" "github.com/getkin/kin-openapi/openapi3" "github.com/golang/mock/gomock" + "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" "github.com/wallarm/api-firewall/cmd/api-firewall/internal/handlers" @@ -150,6 +154,35 @@ paths: - petstore_auth: - read - write + /test/headers/request: + get: + summary: Get Request to test Request Headers validation + parameters: + - in: header + name: X-Request-Test + schema: + type: string + format: uuid + required: true + responses: + 200: + description: Ok + content: { } + /test/headers/response: + get: + summary: Get Request to test Response Headers validation + responses: + 200: + description: Ok + headers: + X-Response-Test: + schema: + type: string + format: uuid + required: true + 401: + description: Unauthorized + content: {} components: securitySchemes: petstore_auth: @@ -173,6 +206,9 @@ const ( testDeniedToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzb21lIjoicGF5bG9hZDk5OTk5ODUifQ.S9P-DEiWg7dlI81rLjnJWCA6h9Q4ewTizxrsxOPGmNA" testShadowAPIendpoint = "/shadowAPItest" + + testRequestHeader = "X-Request-Test" + testResponseHeader = "X-Response-Test" ) type ServiceTests struct { @@ -185,6 +221,45 @@ type ServiceTests struct { shadowAPI *shadowAPI.MockChecker } +func compressFlate(data []byte) ([]byte, error) { + var b bytes.Buffer + w, err := flate.NewWriter(&b, 9) + if err != nil { + return nil, err + } + if _, err = w.Write(data); err != nil { + return nil, err + } + if err = w.Close(); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func compressBrotli(data []byte) ([]byte, error) { + var b bytes.Buffer + w := brotli.NewWriterLevel(&b, brotli.BestCompression) + if _, err := w.Write(data); err != nil { + return nil, err + } + if err := w.Close(); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func compressGzip(data []byte) ([]byte, error) { + var b bytes.Buffer + w := gzip.NewWriter(&b) + if _, err := w.Write(data); err != nil { + return nil, err + } + if err := w.Close(); err != nil { + return nil, err + } + return b.Bytes(), nil +} + // POST /test/signup <- 200 // POST /test/shadow <- 200 func TestBasic(t *testing.T) { @@ -249,6 +324,13 @@ func TestBasic(t *testing.T) { t.Run("oauthJWTRS256", apifwTests.testOauthJWTRS256) t.Run("oauthJWTHS256", apifwTests.testOauthJWTHS256) + t.Run("requestHeaders", apifwTests.testRequestHeaders) + t.Run("responseHeaders", apifwTests.testResponseHeaders) + + t.Run("responseBodyCompressionGzip", apifwTests.testResponseBodyCompressionGzip) + t.Run("responseBodyCompressionBr", apifwTests.testResponseBodyCompressionBr) + t.Run("responseBodyCompressionDeflate", apifwTests.testResponseBodyCompressionDeflate) + } func (s *ServiceTests) testBlockMode(t *testing.T) { @@ -1303,3 +1385,379 @@ func (s *ServiceTests) testOauthJWTHS256(t *testing.T) { } } + +func (s *ServiceTests) testRequestHeaders(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + + xReqTestValue := uuid.New() + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/headers/request") + req.Header.SetMethod("GET") + req.Header.Set(testRequestHeader, xReqTestValue.String()) + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request without a required header + req.Header.Del(testRequestHeader) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.proxy.EXPECT().Put(s.client) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) testResponseHeaders(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + + xRespTestValue := uuid.New() + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/headers/response") + req.Header.SetMethod("GET") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.Set(testResponseHeader, xRespTestValue.String()) + resp.Header.SetContentType("application/json") + resp.SetBody([]byte("{\"status\":\"success\"}")) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request without a required header + resp.Header.Del(testResponseHeader) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) testResponseBodyCompressionGzip(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.Header.Set("Content-Encoding", "gzip") + + // compress using gzip + body, err := compressGzip([]byte("{\"status\":\"success\"}")) + if err != nil { + t.Fatal(err) + } + resp.SetBody(body) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with wrong JSON in response + + // compress using gzip + body, err = compressGzip([]byte("{\"status\": 123}")) + if err != nil { + t.Fatal(err) + } + resp.SetBody(body) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) testResponseBodyCompressionBr(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.Header.Set("Content-Encoding", "br") + + // compress using brotli + body, err := compressBrotli([]byte("{\"status\":\"success\"}")) + if err != nil { + t.Fatal(err) + } + resp.SetBody(body) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with wrong JSON in response + + // compress using brotli + body, err = compressBrotli([]byte("{\"status\": 123}")) + if err != nil { + t.Fatal(err) + } + resp.SetBody(body) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} + +func (s *ServiceTests) testResponseBodyCompressionDeflate(t *testing.T) { + + var cfg = config.APIFWConfiguration{ + RequestValidation: "BLOCK", + ResponseValidation: "BLOCK", + CustomBlockStatusCode: 403, + AddValidationStatusHeader: false, + ShadowAPI: config.ShadowAPI{ + ExcludeList: []int{404, 401}, + }, + } + + handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + + p, err := json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + + if err != nil { + t.Fatal(err) + } + + req := fasthttp.AcquireRequest() + req.SetRequestURI("/test/signup") + req.Header.SetMethod("POST") + req.SetBodyStream(bytes.NewReader(p), -1) + req.Header.SetContentType("application/json") + + resp := fasthttp.AcquireResponse() + resp.SetStatusCode(fasthttp.StatusOK) + resp.Header.SetContentType("application/json") + resp.Header.Set("Content-Encoding", "deflate") + + // compress using flate + body, err := compressFlate([]byte("{\"status\":\"success\"}")) + if err != nil { + t.Fatal(err) + } + resp.SetBody(body) + + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } + + // Repeat request with wrong JSON in response + + // compress using flate + body, err = compressFlate([]byte("{\"status\": 123}")) + if err != nil { + t.Fatal(err) + } + resp.SetBody(body) + + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } + + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } + +} diff --git a/go.mod b/go.mod index 2fdf7cb..23435fb 100644 --- a/go.mod +++ b/go.mod @@ -5,35 +5,40 @@ go 1.19 require ( github.com/ardanlabs/conf v1.5.0 github.com/dgraph-io/ristretto v0.1.1 - github.com/fasthttp/router v1.4.14 - github.com/getkin/kin-openapi v0.110.0 + github.com/fasthttp/router v1.4.15 + github.com/getkin/kin-openapi v0.112.0 github.com/go-playground/validator v9.31.0+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/mock v1.6.0 + github.com/google/uuid v1.3.0 github.com/karlseguin/ccache/v2 v2.0.8 github.com/pkg/errors v0.9.1 github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d github.com/sirupsen/logrus v1.9.0 - github.com/valyala/fasthttp v1.43.0 - github.com/valyala/fastjson v1.6.3 - golang.org/x/exp v0.0.0-20221212164502-fae10dda9338 + github.com/stretchr/testify v1.8.1 + github.com/valyala/fasthttp v1.44.0 + github.com/valyala/fastjson v1.6.4 + golang.org/x/exp v0.0.0-20230127193734-31bee513bff7 ) require ( github.com/andybalholm/brotli v1.0.4 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/dustin/go-humanize v1.0.0 // indirect - github.com/go-openapi/jsonpointer v0.19.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/swag v0.22.3 // indirect - github.com/go-playground/locales v0.14.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect github.com/golang/glog v1.0.0 // indirect + github.com/gorilla/mux v1.8.0 // indirect github.com/invopop/yaml v0.2.0 // indirect github.com/josharian/intern v1.0.0 // indirect - github.com/klauspost/compress v1.15.13 // indirect + github.com/klauspost/compress v1.15.15 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.3.0 // indirect + golang.org/x/sys v0.4.0 // indirect ) require ( diff --git a/go.sum b/go.sum index 5de7682..a2dd52a 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,7 @@ github.com/ardanlabs/conf v1.5.0/go.mod h1:ILsMo9dMqYzCxDjDXTiwMI0IgxOJd0MOiucbQ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -12,19 +13,22 @@ github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWa github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= -github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/fasthttp/router v1.4.14 h1:+W65VCKgyI4BZszhDiCRfONoFieePZIoQ7D8vGhiuzM= -github.com/fasthttp/router v1.4.14/go.mod h1:+svLaOvqF9Lc0yjX9aHAD4NUMf+mggLPOT4UMdS6fjM= -github.com/getkin/kin-openapi v0.110.0 h1:1GnJALxsltcSzCMqgtqKlLhYQeULv3/jesmV2sC5qE0= -github.com/getkin/kin-openapi v0.110.0/go.mod h1:QtwUNt0PAAgIIBEvFWYfB7dfngxtAaqCX1zYHMZDeK8= -github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fasthttp/router v1.4.15 h1:ERaILezYX6ks1I+Z2v5qY4vqiQKnujauo9nV6M+HIOg= +github.com/fasthttp/router v1.4.15/go.mod h1:NFNlTCilbRVkeLc+E5JDkcxUdkpiJGKDL8Zy7Ey2JTI= +github.com/getkin/kin-openapi v0.112.0 h1:lnLXx3bAG53EJVI4E/w0N8i1Y/vUZUEsnrXkgnfn7/Y= +github.com/getkin/kin-openapi v0.112.0/go.mod h1:QtwUNt0PAAgIIBEvFWYfB7dfngxtAaqCX1zYHMZDeK8= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= +github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= -github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/validator v9.31.0+incompatible h1:UA72EPEogEnq76ehGdEDp4Mit+3FDh548oRqwVgNsHA= @@ -38,6 +42,7 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= @@ -51,13 +56,15 @@ github.com/karlseguin/ccache/v2 v2.0.8/go.mod h1:2BDThcfQMf/c0jnZowt16eW405XIqZP github.com/karlseguin/expect v1.0.2-0.20190806010014-778a5f0c6003 h1:vJ0Snvo+SLMY72r5J4sEfkuE7AFbixEP2qRbEcum/wA= github.com/karlseguin/expect v1.0.2-0.20190806010014-778a5f0c6003/go.mod h1:zNBxMY8P21owkeogJELCLeHIt+voOSduHYTFUbwRAV8= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.15.13 h1:NFn1Wr8cfnenSJSA46lLq4wHCcBzKTSjnBIexDMMOV0= -github.com/klauspost/compress v1.15.13/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= +github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= @@ -87,11 +94,10 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.42.0/go.mod h1:f6VbjjoI3z1NDOZOv17o6RvtRSWxC77seBFc2uWtgiY= -github.com/valyala/fasthttp v1.43.0 h1:Gy4sb32C98fbzVWZlTM1oTMdLWGyvxR03VhM6cBIU4g= -github.com/valyala/fasthttp v1.43.0/go.mod h1:f6VbjjoI3z1NDOZOv17o6RvtRSWxC77seBFc2uWtgiY= -github.com/valyala/fastjson v1.6.3 h1:tAKFnnwmeMGPbwJ7IwxcTPCNr3uIzoIj3/Fh90ra4xc= -github.com/valyala/fastjson v1.6.3/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= +github.com/valyala/fasthttp v1.44.0 h1:R+gLUhldIsfg1HokMuQjdQ5bh9nuXHPIfvkYUu9eR5Q= +github.com/valyala/fasthttp v1.44.0/go.mod h1:f6VbjjoI3z1NDOZOv17o6RvtRSWxC77seBFc2uWtgiY= +github.com/valyala/fastjson v1.6.4 h1:uAUNq9Z6ymTgGhcm0UynUAB6tlbakBrz6CQFax3BXVQ= +github.com/valyala/fastjson v1.6.4/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0 h1:3UeQBvD0TFrlVjOeLOBz+CPAI8dnbqNSVwUwRrkp7vQ= github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0/go.mod h1:IXCdmsXIht47RaVFLEdVnh1t+pgYtTAhQGj73kz+2DM= @@ -99,8 +105,8 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/exp v0.0.0-20221212164502-fae10dda9338 h1:OvjRkcNHnf6/W5FZXSxODbxwD+X7fspczG7Jn/xQVD4= -golang.org/x/exp v0.0.0-20221212164502-fae10dda9338/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230127193734-31bee513bff7 h1:pXR8mGh4q8ooBT7HXruL4Xa2IxoL8XZ6lOgXY/0Ryg8= +golang.org/x/exp v0.0.0-20230127193734-31bee513bff7/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -119,8 +125,8 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -136,6 +142,7 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM= gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/platform/validator/internal.go b/internal/platform/validator/internal.go index 38ef56b..8a40050 100644 --- a/internal/platform/validator/internal.go +++ b/internal/platform/validator/internal.go @@ -40,7 +40,12 @@ func convertToMap(v *fastjson.Value) interface{} { } return a case fastjson.TypeNumber: - return v.GetFloat64() + valueInt := v.GetInt64() + valueFloat := v.GetFloat64() + if valueFloat == float64(int(valueFloat)) { + return valueInt + } + return valueFloat case fastjson.TypeString: return string(v.GetStringBytes()) case fastjson.TypeTrue, fastjson.TypeFalse: diff --git a/internal/platform/validator/req_resp_decoder.go b/internal/platform/validator/req_resp_decoder.go index 3e40e34..6cec077 100644 --- a/internal/platform/validator/req_resp_decoder.go +++ b/internal/platform/validator/req_resp_decoder.go @@ -1,11 +1,12 @@ package validator import ( + "archive/zip" + "bytes" + "encoding/csv" "encoding/json" "errors" "fmt" - "github.com/getkin/kin-openapi/openapi3filter" - "github.com/valyala/fastjson" "io" "io/ioutil" "mime" @@ -16,9 +17,11 @@ import ( "strconv" "strings" + "github.com/valyala/fastjson" "gopkg.in/yaml.v3" "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" ) // ParseErrorKind describes a kind of ParseError. @@ -186,8 +189,17 @@ func defaultContentParameterDecoder(param *openapi3.Parameter, values []string) } outSchema = mt.Schema.Value + unmarshal := func(encoded string, paramSchema *openapi3.SchemaRef) (decoded interface{}, err error) { + if err = json.Unmarshal([]byte(encoded), &decoded); err != nil { + if paramSchema != nil && paramSchema.Value.Type != "object" { + decoded, err = encoded, nil + } + } + return + } + if len(values) == 1 { - if err = json.Unmarshal([]byte(values[0]), &outValue); err != nil { + if outValue, err = unmarshal(values[0], mt.Schema); err != nil { err = fmt.Errorf("error unmarshaling parameter %q", param.Name) return } @@ -195,7 +207,7 @@ func defaultContentParameterDecoder(param *openapi3.Parameter, values []string) outArray := make([]interface{}, 0, len(values)) for _, v := range values { var item interface{} - if err = json.Unmarshal([]byte(v), &item); err != nil { + if item, err = unmarshal(v, outSchema.Items); err != nil { err = fmt.Errorf("error unmarshaling parameter %q", param.Name) return } @@ -322,9 +334,12 @@ func decodeValue(dec valueDecoder, param string, sm *openapi3.SerializationMetho case *pathParamDecoder: _, found = vDecoder.pathParams[param] case *urlValuesDecoder: + if schema.Value.Pattern != "" { + return dec.DecodePrimitive(param, sm, schema) + } _, found = vDecoder.values[param] case *headerParamDecoder: - _, found = vDecoder.header[param] + _, found = vDecoder.header[http.CanonicalHeaderKey(param)] case *cookieParamDecoder: _, err := vDecoder.req.Cookie(param) found = err != http.ErrNoCookie @@ -488,6 +503,10 @@ func (d *urlValuesDecoder) DecodePrimitive(param string, sm *openapi3.Serializat // HTTP request does not contain a value of the target query parameter. return nil, ok, nil } + + if schema.Value.Type == "" && schema.Value.Pattern != "" { + return values[0], ok, nil + } val, err := parsePrimitive(values[0], schema) return val, ok, err } @@ -514,10 +533,91 @@ func (d *urlValuesDecoder) DecodeArray(param string, sm *openapi3.SerializationM } values = strings.Split(values[0], delim) } - val, err := parseArray(values, schema) + val, err := d.parseArray(values, sm, schema) return val, ok, err } +// parseArray returns an array that contains items from a raw array. +// Every item is parsed as a primitive value. +// The function returns an error when an error happened while parse array's items. +func (d *urlValuesDecoder) parseArray(raw []string, sm *openapi3.SerializationMethod, schemaRef *openapi3.SchemaRef) ([]interface{}, error) { + var value []interface{} + + for i, v := range raw { + item, err := d.parseValue(v, schemaRef.Value.Items) + if err != nil { + if v, ok := err.(*ParseError); ok { + return nil, &ParseError{path: []interface{}{i}, Cause: v} + } + return nil, fmt.Errorf("item %d: %w", i, err) + } + + // If the items are nil, then the array is nil. There shouldn't be case where some values are actual primitive + // values and some are nil values. + if item == nil { + return nil, nil + } + value = append(value, item) + } + return value, nil +} + +func (d *urlValuesDecoder) parseValue(v string, schema *openapi3.SchemaRef) (interface{}, error) { + if len(schema.Value.AllOf) > 0 { + var value interface{} + var err error + for _, sr := range schema.Value.AllOf { + value, err = d.parseValue(v, sr) + if value == nil || err != nil { + break + } + } + return value, err + } + + if len(schema.Value.AnyOf) > 0 { + var value interface{} + var err error + for _, sr := range schema.Value.AnyOf { + if value, err = d.parseValue(v, sr); err == nil { + return value, nil + } + } + + return nil, err + } + + if len(schema.Value.OneOf) > 0 { + isMatched := 0 + var value interface{} + var err error + for _, sr := range schema.Value.OneOf { + result, err := d.parseValue(v, sr) + if err == nil { + value = result + isMatched++ + } + } + if isMatched == 1 { + return value, nil + } else if isMatched > 1 { + return nil, fmt.Errorf("decoding oneOf failed: %d schemas matched", isMatched) + } else if isMatched == 0 { + return nil, fmt.Errorf("decoding oneOf failed: %d schemas matched", isMatched) + } + + return nil, err + } + + if schema.Value.Not != nil { + // TODO(decode not): handle decoding "not" JSON Schema + return nil, errors.New("not implemented: decoding 'not'") + } + + return parsePrimitive(v, schema) + +} + func (d *urlValuesDecoder) DecodeObject(param string, sm *openapi3.SerializationMethod, schema *openapi3.SchemaRef) (map[string]interface{}, bool, error) { var propsFn func(url.Values) (map[string]string, error) switch sm.Style { @@ -656,7 +756,7 @@ func (d *cookieParamDecoder) DecodePrimitive(param string, sm *openapi3.Serializ return nil, found, nil } if err != nil { - return nil, found, fmt.Errorf("decoding param %q: %s", param, err) + return nil, found, fmt.Errorf("decoding param %q: %w", param, err) } val, err := parsePrimitive(cookie.Value, schema) @@ -675,7 +775,7 @@ func (d *cookieParamDecoder) DecodeArray(param string, sm *openapi3.Serializatio return nil, found, nil } if err != nil { - return nil, found, fmt.Errorf("decoding param %q: %s", param, err) + return nil, found, fmt.Errorf("decoding param %q: %w", param, err) } val, err := parseArray(strings.Split(cookie.Value, ","), schema) return val, found, err @@ -693,7 +793,7 @@ func (d *cookieParamDecoder) DecodeObject(param string, sm *openapi3.Serializati return nil, found, nil } if err != nil { - return nil, found, fmt.Errorf("decoding param %q: %s", param, err) + return nil, found, fmt.Errorf("decoding param %q: %w", param, err) } props, err := propsFromString(cookie.Value, ",", ",") if err != nil { @@ -755,7 +855,7 @@ func makeObject(props map[string]string, schema *openapi3.SchemaRef) (map[string if v, ok := err.(*ParseError); ok { return nil, &ParseError{path: []interface{}{propName}, Cause: v} } - return nil, fmt.Errorf("property %q: %s", propName, err) + return nil, fmt.Errorf("property %q: %w", propName, err) } obj[propName] = value } @@ -773,7 +873,7 @@ func parseArray(raw []string, schemaRef *openapi3.SchemaRef) ([]interface{}, err if v, ok := err.(*ParseError); ok { return nil, &ParseError{path: []interface{}{i}, Cause: v} } - return nil, fmt.Errorf("item %d: %s", i, err) + return nil, fmt.Errorf("item %d: %w", i, err) } // If the items are nil, then the array is nil. There shouldn't be case where some values are actual primitive @@ -788,14 +888,21 @@ func parseArray(raw []string, schemaRef *openapi3.SchemaRef) ([]interface{}, err // parsePrimitive returns a value that is created by parsing a source string to a primitive type // that is specified by a schema. The function returns nil when the source string is empty. -// The function panics when a schema has a non primitive type. +// The function panics when a schema has a non-primitive type. func parsePrimitive(raw string, schema *openapi3.SchemaRef) (interface{}, error) { if raw == "" { return nil, nil } switch schema.Value.Type { case "integer": - v, err := strconv.ParseFloat(raw, 64) + if schema.Value.Format == "int32" { + v, err := strconv.ParseInt(raw, 0, 32) + if err != nil { + return nil, &ParseError{Kind: KindInvalidFormat, Value: raw, Reason: "an invalid " + schema.Value.Type, Cause: err.(*strconv.NumError).Err} + } + return int32(v), nil + } + v, err := strconv.ParseInt(raw, 0, 64) if err != nil { return nil, &ParseError{Kind: KindInvalidFormat, Value: raw, Reason: "an invalid " + schema.Value.Type, Cause: err.(*strconv.NumError).Err} } @@ -897,14 +1004,17 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, } func init() { - RegisterBodyDecoder("text/plain", plainBodyDecoder) RegisterBodyDecoder("application/json", jsonBodyDecoder) - RegisterBodyDecoder("application/x-yaml", yamlBodyDecoder) - RegisterBodyDecoder("application/yaml", yamlBodyDecoder) + RegisterBodyDecoder("application/json-patch+json", jsonBodyDecoder) + RegisterBodyDecoder("application/octet-stream", FileBodyDecoder) RegisterBodyDecoder("application/problem+json", jsonBodyDecoder) RegisterBodyDecoder("application/x-www-form-urlencoded", urlencodedBodyDecoder) + RegisterBodyDecoder("application/x-yaml", yamlBodyDecoder) + RegisterBodyDecoder("application/yaml", yamlBodyDecoder) + RegisterBodyDecoder("application/zip", zipFileBodyDecoder) RegisterBodyDecoder("multipart/form-data", multipartBodyDecoder) - RegisterBodyDecoder("application/octet-stream", FileBodyDecoder) + RegisterBodyDecoder("text/csv", csvBodyDecoder) + RegisterBodyDecoder("text/plain", plainBodyDecoder) } func plainBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { @@ -929,7 +1039,7 @@ func jsonBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Schema return nil, &ParseError{Kind: KindInvalidFormat, Cause: err} } - return parsedDoc, nil + return convertToMap(parsedDoc), nil } func yamlBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { @@ -1055,7 +1165,7 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S if v, ok := err.(*ParseError); ok { return nil, &ParseError{path: []interface{}{name}, Cause: v} } - return nil, fmt.Errorf("part %s: %s", name, err) + return nil, fmt.Errorf("part %s: %w", name, err) } values[name] = append(values[name], value) } @@ -1094,3 +1204,74 @@ func FileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.Schema } return string(data), nil } + +// zipFileBodyDecoder is a body decoder that decodes a zip file body to a string. +func zipFileBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { + buff := bytes.NewBuffer([]byte{}) + size, err := io.Copy(buff, body) + if err != nil { + return nil, err + } + + zr, err := zip.NewReader(bytes.NewReader(buff.Bytes()), size) + if err != nil { + return nil, err + } + + const bufferSize = 256 + content := make([]byte, 0, bufferSize*len(zr.File)) + buffer := make([]byte /*0,*/, bufferSize) + + for _, f := range zr.File { + err := func() error { + rc, err := f.Open() + if err != nil { + return err + } + defer func() { + _ = rc.Close() + }() + + for { + n, err := rc.Read(buffer) + if 0 < n { + content = append(content, buffer...) + } + if err == io.EOF { + break + } + if err != nil { + return err + } + } + + return nil + }() + + if err != nil { + return nil, err + } + } + + return string(content), nil +} + +// csvBodyDecoder is a body decoder that decodes a csv body to a string. +func csvBodyDecoder(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (interface{}, error) { + r := csv.NewReader(body) + + var content string + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + + content += strings.Join(record, ",") + "\n" + } + + return content, nil +} diff --git a/internal/platform/validator/req_resp_decoder_test.go b/internal/platform/validator/req_resp_decoder_test.go new file mode 100644 index 0000000..c7ea5e4 --- /dev/null +++ b/internal/platform/validator/req_resp_decoder_test.go @@ -0,0 +1,1401 @@ +package validator + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/valyala/fastjson" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "net/textproto" + "net/url" + "reflect" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + legacyrouter "github.com/getkin/kin-openapi/routers/legacy" + "github.com/stretchr/testify/require" +) + +func TestDecodeParameter(t *testing.T) { + var ( + boolPtr = func(b bool) *bool { return &b } + explode = boolPtr(true) + noExplode = boolPtr(false) + arrayOf = func(items *openapi3.SchemaRef) *openapi3.SchemaRef { + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "array", Items: items}} + } + objectOf = func(args ...interface{}) *openapi3.SchemaRef { + s := &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "object", Properties: make(map[string]*openapi3.SchemaRef)}} + if len(args)%2 != 0 { + panic("invalid arguments. must be an even number of arguments") + } + for i := 0; i < len(args)/2; i++ { + propName := args[i*2].(string) + propSchema := args[i*2+1].(*openapi3.SchemaRef) + s.Value.Properties[propName] = propSchema + } + return s + } + + integerSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "integer"}} + numberSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "number"}} + booleanSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "boolean"}} + stringSchema = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: "string"}} + allofSchema = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + AllOf: []*openapi3.SchemaRef{ + integerSchema, + numberSchema, + }}} + anyofSchema = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + AnyOf: []*openapi3.SchemaRef{ + integerSchema, + stringSchema, + }}} + oneofSchema = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + OneOf: []*openapi3.SchemaRef{ + booleanSchema, + integerSchema, + }}} + arraySchema = arrayOf(stringSchema) + objectSchema = objectOf("id", stringSchema, "name", stringSchema) + ) + + type testCase struct { + name string + param *openapi3.Parameter + path string + query string + header string + cookie string + want interface{} + found bool + err error + } + + testGroups := []struct { + name string + testCases []testCase + }{ + { + name: "path primitive", + testCases: []testCase{ + { + name: "simple", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: noExplode, Schema: stringSchema}, + path: "/foo", + want: "foo", + found: true, + }, + { + name: "simple explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: explode, Schema: stringSchema}, + path: "/foo", + want: "foo", + found: true, + }, + { + name: "label", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: stringSchema}, + path: "/.foo", + want: "foo", + found: true, + }, + { + name: "label invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: stringSchema}, + path: "/foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "label explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: stringSchema}, + path: "/.foo", + want: "foo", + found: true, + }, + { + name: "label explode invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: stringSchema}, + path: "/foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "matrix", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: stringSchema}, + path: "/;param=foo", + want: "foo", + found: true, + }, + { + name: "matrix invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: stringSchema}, + path: "/foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "matrix explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: stringSchema}, + path: "/;param=foo", + want: "foo", + found: true, + }, + { + name: "matrix explode invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: stringSchema}, + path: "/foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: stringSchema}, + path: "/foo", + want: "foo", + found: true, + }, + { + name: "string", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: stringSchema}, + path: "/foo", + want: "foo", + found: true, + }, + { + name: "integer", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: integerSchema}, + path: "/1", + want: int64(1), + found: true, + }, + { + name: "integer invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: integerSchema}, + path: "/foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "number", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: numberSchema}, + path: "/1.1", + want: 1.1, + found: true, + }, + { + name: "number invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: numberSchema}, + path: "/foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "boolean", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: booleanSchema}, + path: "/true", + want: true, + found: true, + }, + { + name: "boolean invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: booleanSchema}, + path: "/foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + }, + }, + { + name: "path array", + testCases: []testCase{ + { + name: "simple", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: noExplode, Schema: arraySchema}, + path: "/foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "simple explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: explode, Schema: arraySchema}, + path: "/foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "label", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: arraySchema}, + path: "/.foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "label invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: arraySchema}, + path: "/foo,bar", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo,bar"}, + }, + { + name: "label explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: arraySchema}, + path: "/.foo.bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "label explode invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: arraySchema}, + path: "/foo.bar", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo.bar"}, + }, + { + name: "matrix", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: arraySchema}, + path: "/;param=foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "matrix invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: arraySchema}, + path: "/foo,bar", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo,bar"}, + }, + { + name: "matrix explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: arraySchema}, + path: "/;param=foo;param=bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "matrix explode invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: arraySchema}, + path: "/foo,bar", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo,bar"}, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: arraySchema}, + path: "/foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "invalid integer items", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: arrayOf(integerSchema)}, + path: "/1,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + { + name: "invalid number items", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: arrayOf(numberSchema)}, + path: "/1.1,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + { + name: "invalid boolean items", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: arrayOf(booleanSchema)}, + path: "/true,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + }, + }, + { + name: "path object", + testCases: []testCase{ + { + name: "simple", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: noExplode, Schema: objectSchema}, + path: "/id,foo,name,bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "simple explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "simple", Explode: explode, Schema: objectSchema}, + path: "/id=foo,name=bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "label", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: objectSchema}, + path: "/.id,foo,name,bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "label invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: noExplode, Schema: objectSchema}, + path: "/id,foo,name,bar", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "id,foo,name,bar"}, + }, + { + name: "label explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: objectSchema}, + path: "/.id=foo.name=bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "label explode invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "label", Explode: explode, Schema: objectSchema}, + path: "/id=foo.name=bar", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "id=foo.name=bar"}, + }, + { + name: "matrix", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: objectSchema}, + path: "/;param=id,foo,name,bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "matrix invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: noExplode, Schema: objectSchema}, + path: "/id,foo,name,bar", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "id,foo,name,bar"}, + }, + { + name: "matrix explode", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: objectSchema}, + path: "/;id=foo;name=bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "matrix explode invalid", + param: &openapi3.Parameter{Name: "param", In: "path", Style: "matrix", Explode: explode, Schema: objectSchema}, + path: "/id=foo;name=bar", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "id=foo;name=bar"}, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: objectSchema}, + path: "/id,foo,name,bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "invalid integer prop", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: objectOf("foo", integerSchema)}, + path: "/foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + { + name: "invalid number prop", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: objectOf("foo", numberSchema)}, + path: "/foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + { + name: "invalid boolean prop", + param: &openapi3.Parameter{Name: "param", In: "path", Schema: objectOf("foo", booleanSchema)}, + path: "/foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + }, + }, + { + name: "query primitive", + testCases: []testCase{ + { + name: "form", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: noExplode, Schema: stringSchema}, + query: "param=foo", + want: "foo", + found: true, + }, + { + name: "form explode", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: explode, Schema: stringSchema}, + query: "param=foo", + want: "foo", + found: true, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: stringSchema}, + query: "param=foo", + want: "foo", + found: true, + }, + { + name: "string", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: stringSchema}, + query: "param=foo", + want: "foo", + found: true, + }, + { + name: "integer", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: integerSchema}, + query: "param=1", + want: int64(1), + found: true, + }, + { + name: "integer invalid", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: integerSchema}, + query: "param=foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "number", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: numberSchema}, + query: "param=1.1", + want: 1.1, + found: true, + }, + { + name: "number invalid", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: numberSchema}, + query: "param=foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "boolean", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: booleanSchema}, + query: "param=true", + want: true, + found: true, + }, + { + name: "boolean invalid", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: booleanSchema}, + query: "param=foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + }, + }, + { + name: "query Allof", + testCases: []testCase{ + { + name: "allofSchema integer and number", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: allofSchema}, + query: "param=1", + want: float64(1), + found: true, + }, + { + name: "allofSchema string", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: allofSchema}, + query: "param=abdf", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "abdf"}, + }, + }, + }, + { + name: "query Anyof", + testCases: []testCase{ + { + name: "anyofSchema integer", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: anyofSchema}, + query: "param=1", + want: int64(1), + found: true, + }, + { + name: "anyofSchema string", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: anyofSchema}, + query: "param=abdf", + want: "abdf", + found: true, + }, + }, + }, + { + name: "query Oneof", + testCases: []testCase{ + { + name: "oneofSchema boolean", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: oneofSchema}, + query: "param=true", + want: true, + found: true, + }, + { + name: "oneofSchema int", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: oneofSchema}, + query: "param=1122", + want: int64(1122), + found: true, + }, + { + name: "oneofSchema string", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: oneofSchema}, + query: "param=abcd", + want: nil, + found: true, + }, + }, + }, + { + name: "query array", + testCases: []testCase{ + { + name: "form", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: noExplode, Schema: arraySchema}, + query: "param=foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "form explode", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: explode, Schema: arraySchema}, + query: "param=foo¶m=bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "spaceDelimited", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "spaceDelimited", Explode: noExplode, Schema: arraySchema}, + query: "param=foo bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "spaceDelimited explode", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "spaceDelimited", Explode: explode, Schema: arraySchema}, + query: "param=foo¶m=bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "pipeDelimited", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "pipeDelimited", Explode: noExplode, Schema: arraySchema}, + query: "param=foo|bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "pipeDelimited explode", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "pipeDelimited", Explode: explode, Schema: arraySchema}, + query: "param=foo¶m=bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: arraySchema}, + query: "param=foo¶m=bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "invalid integer items", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: arrayOf(integerSchema)}, + query: "param=1¶m=foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + { + name: "invalid number items", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: arrayOf(numberSchema)}, + query: "param=1.1¶m=foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + { + name: "invalid boolean items", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: arrayOf(booleanSchema)}, + query: "param=true¶m=foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + }, + }, + { + name: "query object", + testCases: []testCase{ + { + name: "form", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: noExplode, Schema: objectSchema}, + query: "param=id,foo,name,bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "form explode", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "form", Explode: explode, Schema: objectSchema}, + query: "id=foo&name=bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "deepObject explode", + param: &openapi3.Parameter{Name: "param", In: "query", Style: "deepObject", Explode: explode, Schema: objectSchema}, + query: "param[id]=foo¶m[name]=bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: objectSchema}, + query: "id=foo&name=bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "invalid integer prop", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: objectOf("foo", integerSchema)}, + query: "foo=bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + { + name: "invalid number prop", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: objectOf("foo", numberSchema)}, + query: "foo=bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + { + name: "invalid boolean prop", + param: &openapi3.Parameter{Name: "param", In: "query", Schema: objectOf("foo", booleanSchema)}, + query: "foo=bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + }, + }, + { + name: "header primitive", + testCases: []testCase{ + { + name: "simple", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: noExplode, Schema: stringSchema}, + header: "X-Param:foo", + want: "foo", + found: true, + }, + { + name: "simple explode", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: explode, Schema: stringSchema}, + header: "X-Param:foo", + want: "foo", + found: true, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: stringSchema}, + header: "X-Param:foo", + want: "foo", + found: true, + }, + { + name: "string", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: stringSchema}, + header: "X-Param:foo", + want: "foo", + found: true, + }, + { + name: "integer", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: integerSchema}, + header: "X-Param:1", + want: int64(1), + found: true, + }, + { + name: "integer invalid", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: integerSchema}, + header: "X-Param:foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "number", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: numberSchema}, + header: "X-Param:1.1", + want: 1.1, + found: true, + }, + { + name: "number invalid", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: numberSchema}, + header: "X-Param:foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "boolean", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: booleanSchema}, + header: "X-Param:true", + want: true, + found: true, + }, + { + name: "boolean invalid", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: booleanSchema}, + header: "X-Param:foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + }, + }, + { + name: "header array", + testCases: []testCase{ + { + name: "simple", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: noExplode, Schema: arraySchema}, + header: "X-Param:foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "simple explode", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: explode, Schema: arraySchema}, + header: "X-Param:foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: arraySchema}, + header: "X-Param:foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "invalid integer items", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: arrayOf(integerSchema)}, + header: "X-Param:1,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + { + name: "invalid number items", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: arrayOf(numberSchema)}, + header: "X-Param:1.1,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + { + name: "invalid boolean items", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: arrayOf(booleanSchema)}, + header: "X-Param:true,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + }, + }, + { + name: "header object", + testCases: []testCase{ + { + name: "simple", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: noExplode, Schema: objectSchema}, + header: "X-Param:id,foo,name,bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "simple explode", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Style: "simple", Explode: explode, Schema: objectSchema}, + header: "X-Param:id=foo,name=bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: objectSchema}, + header: "X-Param:id,foo,name,bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "valid integer prop", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: integerSchema}, + header: "X-Param:88", + found: true, + want: int64(88), + }, + { + name: "invalid integer prop", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: objectOf("foo", integerSchema)}, + header: "X-Param:foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + { + name: "invalid number prop", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: objectOf("foo", numberSchema)}, + header: "X-Param:foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + { + name: "invalid boolean prop", + param: &openapi3.Parameter{Name: "X-Param", In: "header", Schema: objectOf("foo", booleanSchema)}, + header: "X-Param:foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + }, + }, + { + name: "cookie primitive", + testCases: []testCase{ + { + name: "form", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: stringSchema}, + cookie: "X-Param:foo", + want: "foo", + found: true, + }, + { + name: "form explode", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: explode, Schema: stringSchema}, + cookie: "X-Param:foo", + want: "foo", + found: true, + }, + { + name: "default", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Schema: stringSchema}, + cookie: "X-Param:foo", + want: "foo", + found: true, + }, + { + name: "string", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Schema: stringSchema}, + cookie: "X-Param:foo", + want: "foo", + found: true, + }, + { + name: "integer", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Schema: integerSchema}, + cookie: "X-Param:1", + want: int64(1), + found: true, + }, + { + name: "integer invalid", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Schema: integerSchema}, + cookie: "X-Param:foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "number", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Schema: numberSchema}, + cookie: "X-Param:1.1", + want: 1.1, + found: true, + }, + { + name: "number invalid", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Schema: numberSchema}, + cookie: "X-Param:foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + { + name: "boolean", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Schema: booleanSchema}, + cookie: "X-Param:true", + want: true, + found: true, + }, + { + name: "boolean invalid", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Schema: booleanSchema}, + cookie: "X-Param:foo", + found: true, + err: &ParseError{Kind: KindInvalidFormat, Value: "foo"}, + }, + }, + }, + { + name: "cookie array", + testCases: []testCase{ + { + name: "form", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: arraySchema}, + cookie: "X-Param:foo,bar", + want: []interface{}{"foo", "bar"}, + found: true, + }, + { + name: "invalid integer items", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: arrayOf(integerSchema)}, + cookie: "X-Param:1,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + { + name: "invalid number items", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: arrayOf(numberSchema)}, + cookie: "X-Param:1.1,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + { + name: "invalid boolean items", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: arrayOf(booleanSchema)}, + cookie: "X-Param:true,foo", + found: true, + err: &ParseError{path: []interface{}{1}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "foo"}}, + }, + }, + }, + { + name: "cookie object", + testCases: []testCase{ + { + name: "form", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: objectSchema}, + cookie: "X-Param:id,foo,name,bar", + want: map[string]interface{}{"id": "foo", "name": "bar"}, + found: true, + }, + { + name: "invalid integer prop", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: objectOf("foo", integerSchema)}, + cookie: "X-Param:foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + { + name: "invalid number prop", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: objectOf("foo", numberSchema)}, + cookie: "X-Param:foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + { + name: "invalid boolean prop", + param: &openapi3.Parameter{Name: "X-Param", In: "cookie", Style: "form", Explode: noExplode, Schema: objectOf("foo", booleanSchema)}, + cookie: "X-Param:foo,bar", + found: true, + err: &ParseError{path: []interface{}{"foo"}, Cause: &ParseError{Kind: KindInvalidFormat, Value: "bar"}}, + }, + }, + }, + } + + for _, tg := range testGroups { + t.Run(tg.name, func(t *testing.T) { + for _, tc := range tg.testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "http://test.org/test"+tc.path, nil) + require.NoError(t, err, "failed to create a test request") + + if tc.query != "" { + query := req.URL.Query() + for _, param := range strings.Split(tc.query, "&") { + v := strings.Split(param, "=") + query.Add(v[0], v[1]) + } + req.URL.RawQuery = query.Encode() + } + + if tc.header != "" { + v := strings.Split(tc.header, ":") + req.Header.Add(v[0], v[1]) + } + + if tc.cookie != "" { + v := strings.Split(tc.cookie, ":") + req.AddCookie(&http.Cookie{Name: v[0], Value: v[1]}) + } + + path := "/test" + if tc.path != "" { + path += "/{" + tc.param.Name + "}" + tc.param.Required = true + } + + info := &openapi3.Info{ + Title: "MyAPI", + Version: "0.1", + } + spec := &openapi3.T{OpenAPI: "3.0.0", Info: info} + op := &openapi3.Operation{ + OperationID: "test", + Parameters: []*openapi3.ParameterRef{{Value: tc.param}}, + Responses: openapi3.NewResponses(), + } + spec.AddOperation(path, http.MethodGet, op) + err = spec.Validate(context.Background()) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(spec) + require.NoError(t, err) + + route, pathParams, err := router.FindRoute(req) + require.NoError(t, err) + + input := &openapi3filter.RequestValidationInput{Request: req, PathParams: pathParams, Route: route} + got, found, err := decodeStyledParameter(tc.param, input) + + require.Truef(t, found == tc.found, "got found: %t, want found: %t", found, tc.found) + + if tc.err != nil { + require.Error(t, err) + require.Truef(t, matchParseError(err, tc.err), "got error:\n%v\nwant error:\n%v", err, tc.err) + return + } + + require.NoError(t, err) + require.Truef(t, reflect.DeepEqual(got, tc.want), "got %v, want %v", got, tc.want) + }) + } + }) + } +} + +func TestDecodeBody(t *testing.T) { + boolPtr := func(b bool) *bool { return &b } + + urlencodedForm := make(url.Values) + urlencodedForm.Set("a", "a1") + urlencodedForm.Set("b", "10") + urlencodedForm.Add("c", "c1") + urlencodedForm.Add("c", "c2") + + urlencodedSpaceDelim := make(url.Values) + urlencodedSpaceDelim.Set("a", "a1") + urlencodedSpaceDelim.Set("b", "10") + urlencodedSpaceDelim.Add("c", "c1 c2") + + urlencodedPipeDelim := make(url.Values) + urlencodedPipeDelim.Set("a", "a1") + urlencodedPipeDelim.Set("b", "10") + urlencodedPipeDelim.Add("c", "c1|c2") + + d, err := json.Marshal(map[string]interface{}{"d1": "d1"}) + require.NoError(t, err) + multipartForm, multipartFormMime, err := newTestMultipartForm([]*testFormPart{ + {name: "a", contentType: "text/plain", data: strings.NewReader("a1")}, + {name: "b", contentType: "application/json", data: strings.NewReader("10")}, + {name: "c", contentType: "text/plain", data: strings.NewReader("c1")}, + {name: "c", contentType: "text/plain", data: strings.NewReader("c2")}, + {name: "d", contentType: "application/json", data: bytes.NewReader(d)}, + {name: "f", contentType: "application/octet-stream", data: strings.NewReader("foo"), filename: "f1"}, + {name: "g", data: strings.NewReader("g1")}, + }) + require.NoError(t, err) + + multipartFormExtraPart, multipartFormMimeExtraPart, err := newTestMultipartForm([]*testFormPart{ + {name: "a", contentType: "text/plain", data: strings.NewReader("a1")}, + {name: "x", contentType: "text/plain", data: strings.NewReader("x1")}, + }) + require.NoError(t, err) + + multipartAnyAdditionalProps, multipartMimeAnyAdditionalProps, err := newTestMultipartForm([]*testFormPart{ + {name: "a", contentType: "text/plain", data: strings.NewReader("a1")}, + {name: "x", contentType: "text/plain", data: strings.NewReader("x1")}, + }) + require.NoError(t, err) + + multipartAdditionalProps, multipartMimeAdditionalProps, err := newTestMultipartForm([]*testFormPart{ + {name: "a", contentType: "text/plain", data: strings.NewReader("a1")}, + {name: "x", contentType: "text/plain", data: strings.NewReader("x1")}, + }) + require.NoError(t, err) + + multipartAdditionalPropsErr, multipartMimeAdditionalPropsErr, err := newTestMultipartForm([]*testFormPart{ + {name: "a", contentType: "text/plain", data: strings.NewReader("a1")}, + {name: "x", contentType: "text/plain", data: strings.NewReader("x1")}, + {name: "y", contentType: "text/plain", data: strings.NewReader("y1")}, + }) + require.NoError(t, err) + + testCases := []struct { + name string + mime string + body io.Reader + schema *openapi3.Schema + encoding map[string]*openapi3.Encoding + want interface{} + wantErr error + }{ + { + name: prefixUnsupportedCT, + mime: "application/xml", + wantErr: &ParseError{Kind: KindUnsupportedFormat}, + }, + { + name: "invalid body data", + mime: "application/json", + body: strings.NewReader("invalid"), + wantErr: &ParseError{Kind: KindInvalidFormat}, + }, + { + name: "plain text", + mime: "text/plain", + body: strings.NewReader("text"), + want: "text", + }, + { + name: "json", + mime: "application/json", + body: strings.NewReader("\"foo\""), + want: "foo", + }, + { + name: "x-yaml", + mime: "application/x-yaml", + body: strings.NewReader("foo"), + want: "foo", + }, + { + name: "yaml", + mime: "application/yaml", + body: strings.NewReader("foo"), + want: "foo", + }, + { + name: "urlencoded form", + mime: "application/x-www-form-urlencoded", + body: strings.NewReader(urlencodedForm.Encode()), + schema: openapi3.NewObjectSchema(). + WithProperty("a", openapi3.NewStringSchema()). + WithProperty("b", openapi3.NewIntegerSchema()). + WithProperty("c", openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema())), + want: map[string]interface{}{"a": "a1", "b": int64(10), "c": []interface{}{"c1", "c2"}}, + }, + { + name: "urlencoded space delimited", + mime: "application/x-www-form-urlencoded", + body: strings.NewReader(urlencodedSpaceDelim.Encode()), + schema: openapi3.NewObjectSchema(). + WithProperty("a", openapi3.NewStringSchema()). + WithProperty("b", openapi3.NewIntegerSchema()). + WithProperty("c", openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema())), + encoding: map[string]*openapi3.Encoding{ + "c": {Style: openapi3.SerializationSpaceDelimited, Explode: boolPtr(false)}, + }, + want: map[string]interface{}{"a": "a1", "b": int64(10), "c": []interface{}{"c1", "c2"}}, + }, + { + name: "urlencoded pipe delimited", + mime: "application/x-www-form-urlencoded", + body: strings.NewReader(urlencodedPipeDelim.Encode()), + schema: openapi3.NewObjectSchema(). + WithProperty("a", openapi3.NewStringSchema()). + WithProperty("b", openapi3.NewIntegerSchema()). + WithProperty("c", openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema())), + encoding: map[string]*openapi3.Encoding{ + "c": {Style: openapi3.SerializationPipeDelimited, Explode: boolPtr(false)}, + }, + want: map[string]interface{}{"a": "a1", "b": int64(10), "c": []interface{}{"c1", "c2"}}, + }, + { + name: "multipart", + mime: multipartFormMime, + body: multipartForm, + schema: openapi3.NewObjectSchema(). + WithProperty("a", openapi3.NewStringSchema()). + WithProperty("b", openapi3.NewIntegerSchema()). + WithProperty("c", openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema())). + WithProperty("d", openapi3.NewObjectSchema().WithProperty("d1", openapi3.NewStringSchema())). + WithProperty("f", openapi3.NewStringSchema().WithFormat("binary")). + WithProperty("g", openapi3.NewStringSchema()), + want: map[string]interface{}{"a": "a1", "b": int64(10), "c": []interface{}{"c1", "c2"}, "d": map[string]interface{}{"d1": "d1"}, "f": "foo", "g": "g1"}, + }, + { + name: "multipartExtraPart", + mime: multipartFormMimeExtraPart, + body: multipartFormExtraPart, + schema: openapi3.NewObjectSchema(). + WithProperty("a", openapi3.NewStringSchema()), + want: map[string]interface{}{"a": "a1"}, + wantErr: &ParseError{Kind: KindOther}, + }, + { + name: "multipartAnyAdditionalProperties", + mime: multipartMimeAnyAdditionalProps, + body: multipartAnyAdditionalProps, + schema: openapi3.NewObjectSchema(). + WithAnyAdditionalProperties(). + WithProperty("a", openapi3.NewStringSchema()), + want: map[string]interface{}{"a": "a1"}, + }, + { + name: "multipartWithAdditionalProperties", + mime: multipartMimeAdditionalProps, + body: multipartAdditionalProps, + schema: openapi3.NewObjectSchema(). + WithAdditionalProperties(openapi3.NewObjectSchema(). + WithProperty("x", openapi3.NewStringSchema())). + WithProperty("a", openapi3.NewStringSchema()), + want: map[string]interface{}{"a": "a1", "x": "x1"}, + }, + { + name: "multipartWithAdditionalPropertiesError", + mime: multipartMimeAdditionalPropsErr, + body: multipartAdditionalPropsErr, + schema: openapi3.NewObjectSchema(). + WithAdditionalProperties(openapi3.NewObjectSchema(). + WithProperty("x", openapi3.NewStringSchema())). + WithProperty("a", openapi3.NewStringSchema()), + want: map[string]interface{}{"a": "a1", "x": "x1"}, + wantErr: &ParseError{Kind: KindOther}, + }, + { + name: "file", + mime: "application/octet-stream", + body: strings.NewReader("foo"), + want: "foo", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + h := make(http.Header) + h.Set(headerCT, tc.mime) + var schemaRef *openapi3.SchemaRef + if tc.schema != nil { + schemaRef = tc.schema.NewRef() + } + encFn := func(name string) *openapi3.Encoding { + if tc.encoding == nil { + return nil + } + return tc.encoding[name] + } + _, got, err := decodeBody(tc.body, h, schemaRef, encFn, &fastjson.Parser{}) + + if tc.wantErr != nil { + require.Error(t, err) + require.Truef(t, matchParseError(err, tc.wantErr), "got error:\n%v\nwant error:\n%v", err, tc.wantErr) + return + } + + require.NoError(t, err) + require.Truef(t, reflect.DeepEqual(got, tc.want), "got %v, want %v", got, tc.want) + }) + } +} + +type testFormPart struct { + name string + contentType string + data io.Reader + filename string +} + +func newTestMultipartForm(parts []*testFormPart) (io.Reader, string, error) { + form := &bytes.Buffer{} + w := multipart.NewWriter(form) + defer w.Close() + + for _, p := range parts { + var disp string + if p.filename == "" { + disp = fmt.Sprintf("form-data; name=%q", p.name) + } else { + disp = fmt.Sprintf("form-data; name=%q; filename=%q", p.name, p.filename) + } + + h := make(textproto.MIMEHeader) + h.Set(headerCT, p.contentType) + h.Set("Content-Disposition", disp) + pw, err := w.CreatePart(h) + if err != nil { + return nil, "", err + } + if _, err = io.Copy(pw, p.data); err != nil { + return nil, "", err + } + } + return form, w.FormDataContentType(), nil +} + +func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { + var decoder BodyDecoder + decoder = func(body io.Reader, h http.Header, schema *openapi3.SchemaRef, encFn EncodingFn, jsonParser *fastjson.Parser) (decoded interface{}, err error) { + var data []byte + if data, err = ioutil.ReadAll(body); err != nil { + return + } + return strings.Split(string(data), ","), nil + } + contentType := "application/csv" + h := make(http.Header) + h.Set(headerCT, contentType) + + originalDecoder := RegisteredBodyDecoder(contentType) + require.Nil(t, originalDecoder) + + RegisterBodyDecoder(contentType, decoder) + require.Equal(t, fmt.Sprintf("%v", decoder), fmt.Sprintf("%v", RegisteredBodyDecoder(contentType))) + + body := strings.NewReader("foo,bar") + schema := openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema()).NewRef() + encFn := func(string) *openapi3.Encoding { return nil } + _, got, err := decodeBody(body, h, schema, encFn, &fastjson.Parser{}) + + require.NoError(t, err) + require.Equal(t, []string{"foo", "bar"}, got) + + UnregisterBodyDecoder(contentType) + + originalDecoder = RegisteredBodyDecoder(contentType) + require.Nil(t, originalDecoder) + + _, _, err = decodeBody(body, h, schema, encFn, &fastjson.Parser{}) + require.Equal(t, &ParseError{ + Kind: KindUnsupportedFormat, + Reason: prefixUnsupportedCT + ` "application/csv"`, + }, err) +} + +func matchParseError(got, want error) bool { + wErr, ok := want.(*ParseError) + if !ok { + return false + } + gErr, ok := got.(*ParseError) + if !ok { + return false + } + if wErr.Kind != gErr.Kind { + return false + } + if !reflect.DeepEqual(wErr.Value, gErr.Value) { + return false + } + if !reflect.DeepEqual(wErr.Path(), gErr.Path()) { + return false + } + if wErr.Cause != nil { + return matchParseError(gErr.Cause, wErr.Cause) + } + return true +} diff --git a/internal/platform/validator/req_resp_encoder.go b/internal/platform/validator/req_resp_encoder.go index 472c687..870f06e 100644 --- a/internal/platform/validator/req_resp_encoder.go +++ b/internal/platform/validator/req_resp_encoder.go @@ -16,8 +16,34 @@ func encodeBody(body interface{}, mediaType string) ([]byte, error) { return encoder(body) } -type bodyEncoder func(body interface{}) ([]byte, error) +type BodyEncoder func(body interface{}) ([]byte, error) -var bodyEncoders = map[string]bodyEncoder{ +var bodyEncoders = map[string]BodyEncoder{ "application/json": json.Marshal, } + +func RegisterBodyEncoder(contentType string, encoder BodyEncoder) { + if contentType == "" { + panic("contentType is empty") + } + if encoder == nil { + panic("encoder is not defined") + } + bodyEncoders[contentType] = encoder +} + +// This call is not thread-safe: body encoders should not be created/destroyed by multiple goroutines. +func UnregisterBodyEncoder(contentType string) { + if contentType == "" { + panic("contentType is empty") + } + delete(bodyEncoders, contentType) +} + +// RegisteredBodyEncoder returns the registered body encoder for the given content type. +// +// If no encoder was registered for the given content type, nil is returned. +// This call is not thread-safe: body encoders should not be created/destroyed by multiple goroutines. +func RegisteredBodyEncoder(contentType string) BodyEncoder { + return bodyEncoders[contentType] +} diff --git a/internal/platform/validator/req_resp_encoder_test.go b/internal/platform/validator/req_resp_encoder_test.go new file mode 100644 index 0000000..8f746f0 --- /dev/null +++ b/internal/platform/validator/req_resp_encoder_test.go @@ -0,0 +1,43 @@ +package validator + +import ( + "fmt" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRegisterAndUnregisterBodyEncoder(t *testing.T) { + var encoder BodyEncoder + encoder = func(body interface{}) (data []byte, err error) { + return []byte(strings.Join(body.([]string), ",")), nil + } + contentType := "text/csv" + h := make(http.Header) + h.Set(headerCT, contentType) + + originalEncoder := RegisteredBodyEncoder(contentType) + require.Nil(t, originalEncoder) + + RegisterBodyEncoder(contentType, encoder) + require.Equal(t, fmt.Sprintf("%v", encoder), fmt.Sprintf("%v", RegisteredBodyEncoder(contentType))) + + body := []string{"foo", "bar"} + got, err := encodeBody(body, contentType) + + require.NoError(t, err) + require.Equal(t, []byte("foo,bar"), got) + + UnregisterBodyEncoder(contentType) + + originalEncoder = RegisteredBodyEncoder(contentType) + require.Nil(t, originalEncoder) + + _, err = encodeBody(body, contentType) + require.Equal(t, &ParseError{ + Kind: KindUnsupportedFormat, + Reason: prefixUnsupportedCT + ` "text/csv"`, + }, err) +} diff --git a/internal/platform/validator/validate_request.go b/internal/platform/validator/validate_request.go index 8ec7445..15ae15f 100644 --- a/internal/platform/validator/validate_request.go +++ b/internal/platform/validator/validate_request.go @@ -4,14 +4,25 @@ import ( "bytes" "context" "fmt" + "io" + "net/http" + "sort" + "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" + "github.com/pkg/errors" "github.com/valyala/fastjson" - "io" - "net/http" ) -const prefixInvalidCT = "header Content-Type has unexpected value" +// ErrAuthenticationServiceMissing is returned when no authentication service +// is defined for the request validator +var ErrAuthenticationServiceMissing = errors.New("missing AuthenticationFunc") + +// ErrInvalidRequired is returned when a required value of a parameter or request body is not defined. +var ErrInvalidRequired = errors.New("value is required but missing") + +// ErrInvalidEmptyValue is returned when a value of a parameter or request body is empty while it's not allowed. +var ErrInvalidEmptyValue = errors.New("empty value is not allowed") // ValidateRequest is used to validate the given input according to previous // loaded OpenAPIv3 spec. If the input does not match the OpenAPIv3 spec, a @@ -19,11 +30,8 @@ const prefixInvalidCT = "header Content-Type has unexpected value" // // Note: One can tune the behavior of uniqueItems: true verification // by registering a custom function with openapi3.RegisterArrayUniqueItemsChecker -func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidationInput, jsonParser *fastjson.Parser) error { - var ( - err error - me openapi3.MultiError - ) +func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidationInput, jsonParser *fastjson.Parser) (err error) { + var me openapi3.MultiError options := input.Options if options == nil { @@ -42,8 +50,8 @@ func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidatio security = &route.Spec.Security } if security != nil { - if err = openapi3filter.ValidateSecurityRequirements(ctx, input, *security); err != nil && !options.MultiError { - return err + if err = ValidateSecurityRequirements(ctx, input, *security); err != nil && !options.MultiError { + return } if err != nil { @@ -60,8 +68,8 @@ func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidatio } } - if err = openapi3filter.ValidateParameter(ctx, input, parameter); err != nil && !options.MultiError { - return err + if err = ValidateParameter(ctx, input, parameter); err != nil && !options.MultiError { + return } if err != nil { @@ -71,8 +79,8 @@ func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidatio // For each parameter of the Operation for _, parameter := range operationParameters { - if err = openapi3filter.ValidateParameter(ctx, input, parameter.Value); err != nil && !options.MultiError { - return err + if err = ValidateParameter(ctx, input, parameter.Value); err != nil && !options.MultiError { + return } if err != nil { @@ -84,7 +92,7 @@ func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidatio requestBody := operation.RequestBody if requestBody != nil && !options.ExcludeRequestBody { if err = ValidateRequestBody(ctx, input, requestBody.Value, jsonParser); err != nil && !options.MultiError { - return err + return } if err != nil { @@ -95,10 +103,95 @@ func ValidateRequest(ctx context.Context, input *openapi3filter.RequestValidatio if len(me) > 0 { return me } + return +} + +// ValidateParameter validates a parameter's value by JSON schema. +// The function returns RequestError with a ParseError cause when unable to parse a value. +// The function returns RequestError with ErrInvalidRequired cause when a value of a required parameter is not defined. +// The function returns RequestError with ErrInvalidEmptyValue cause when a value of a required parameter is not defined. +// The function returns RequestError with a openapi3.SchemaError cause when a value is invalid by JSON schema. +func ValidateParameter(ctx context.Context, input *openapi3filter.RequestValidationInput, parameter *openapi3.Parameter) error { + if parameter.Schema == nil && parameter.Content == nil { + // We have no schema for the parameter. Assume that everything passes + // a schema-less check, but this could also be an error. The OpenAPI + // validation allows this to happen. + return nil + } + + options := input.Options + if options == nil { + options = openapi3filter.DefaultOptions + } + + var value interface{} + var err error + var found bool + var schema *openapi3.Schema + + // Validation will ensure that we either have content or schema. + if parameter.Content != nil { + if value, schema, found, err = decodeContentParameter(parameter, input); err != nil { + return &openapi3filter.RequestError{Input: input, Parameter: parameter, Err: err} + } + } else { + if value, found, err = decodeStyledParameter(parameter, input); err != nil { + return &openapi3filter.RequestError{Input: input, Parameter: parameter, Err: err} + } + schema = parameter.Schema.Value + } + + // Set default value if needed + if !options.SkipSettingDefaults && value == nil && schema != nil && schema.Default != nil { + value = schema.Default + req := input.Request + switch parameter.In { + case openapi3.ParameterInPath: + // Path parameters are required. + // Next check `parameter.Required && !found` will catch this. + case openapi3.ParameterInQuery: + q := req.URL.Query() + q.Add(parameter.Name, fmt.Sprintf("%v", value)) + req.URL.RawQuery = q.Encode() + case openapi3.ParameterInHeader: + req.Header.Add(parameter.Name, fmt.Sprintf("%v", value)) + case openapi3.ParameterInCookie: + req.AddCookie(&http.Cookie{ + Name: parameter.Name, + Value: fmt.Sprintf("%v", value), + }) + } + } + // Validate a parameter's value and presence. + if parameter.Required && !found { + return &openapi3filter.RequestError{Input: input, Parameter: parameter, Reason: ErrInvalidRequired.Error(), Err: ErrInvalidRequired} + } + + if isNilValue(value) { + if !parameter.AllowEmptyValue && found { + return &openapi3filter.RequestError{Input: input, Parameter: parameter, Reason: ErrInvalidEmptyValue.Error(), Err: ErrInvalidEmptyValue} + } + return nil + } + if schema == nil { + // A parameter's schema is not defined so skip validation of a parameter's value. + return nil + } + + var opts []openapi3.SchemaValidationOption + if options.MultiError { + opts = make([]openapi3.SchemaValidationOption, 0, 1) + opts = append(opts, openapi3.MultiErrors()) + } + if err = schema.VisitJSON(value, opts...); err != nil { + return &openapi3filter.RequestError{Input: input, Parameter: parameter, Err: err} + } return nil } +const prefixInvalidCT = "header Content-Type has unexpected value" + // ValidateRequestBody validates data of a request's body. // // The function returns RequestError with ErrInvalidRequired cause when a value is required but not defined. @@ -126,7 +219,19 @@ func ValidateRequestBody(ctx context.Context, input *openapi3filter.RequestValid } } // Put the data back into the input - req.Body = io.NopCloser(bytes.NewReader(data)) + req.Body = nil + if req.GetBody != nil { + if req.Body, err = req.GetBody(); err != nil { + req.Body = nil + } + } + if req.Body == nil { + req.ContentLength = int64(len(data)) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil + } + req.Body, _ = req.GetBody() // no error return + } } if len(data) == 0 { @@ -171,23 +276,21 @@ func ValidateRequestBody(ctx context.Context, input *openapi3filter.RequestValid defaultsSet := false opts := make([]openapi3.SchemaValidationOption, 0, 3) // 3 potential opts here opts = append(opts, openapi3.VisitAsRequest()) - opts = append(opts, openapi3.DefaultsSet(func() { defaultsSet = true })) + if !options.SkipSettingDefaults { + opts = append(opts, openapi3.DefaultsSet(func() { defaultsSet = true })) + } if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } - // prepare map[string]interface{} structure for json validation - fastjsonValue, ok := value.(*fastjson.Value) - if ok { - value = convertToMap(fastjsonValue) - } - // Validate JSON with the schema if err := contentType.Schema.Value.VisitJSON(value, opts...); err != nil { + schemaId := getSchemaIdentifier(contentType.Schema) + schemaId = prependSpaceIfNeeded(schemaId) return &openapi3filter.RequestError{ Input: input, RequestBody: requestBody, - Reason: "doesn't match the schema", + Reason: fmt.Sprintf("doesn't match schema%s", schemaId), Err: err, } } @@ -203,8 +306,88 @@ func ValidateRequestBody(ctx context.Context, input *openapi3filter.RequestValid } } // Put the data back into the input - req.Body = io.NopCloser(bytes.NewReader(data)) + if req.Body != nil { + req.Body.Close() + } + req.ContentLength = int64(len(data)) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil + } + req.Body, _ = req.GetBody() // no error return + } + + return nil +} + +// ValidateSecurityRequirements goes through multiple OpenAPI 3 security +// requirements in order and returns nil on the first valid requirement. +// If no requirement is met, errors are returned in order. +func ValidateSecurityRequirements(ctx context.Context, input *openapi3filter.RequestValidationInput, srs openapi3.SecurityRequirements) error { + if len(srs) == 0 { + return nil + } + var errs []error + for _, sr := range srs { + if err := validateSecurityRequirement(ctx, input, sr); err != nil { + if len(errs) == 0 { + errs = make([]error, 0, len(srs)) + } + errs = append(errs, err) + continue + } + return nil + } + return &openapi3filter.SecurityRequirementsError{ + SecurityRequirements: srs, + Errors: errs, } +} + +// validateSecurityRequirement validates a single OpenAPI 3 security requirement +func validateSecurityRequirement(ctx context.Context, input *openapi3filter.RequestValidationInput, securityRequirement openapi3.SecurityRequirement) error { + doc := input.Route.Spec + securitySchemes := doc.Components.SecuritySchemes + // Ensure deterministic order + names := make([]string, 0, len(securityRequirement)) + for name := range securityRequirement { + names = append(names, name) + } + sort.Strings(names) + + // Get authentication function + options := input.Options + if options == nil { + options = openapi3filter.DefaultOptions + } + f := options.AuthenticationFunc + if f == nil { + return ErrAuthenticationServiceMissing + } + + // For each scheme for the requirement + for _, name := range names { + var securityScheme *openapi3.SecurityScheme + if securitySchemes != nil { + if ref := securitySchemes[name]; ref != nil { + securityScheme = ref.Value + } + } + if securityScheme == nil { + return &openapi3filter.RequestError{ + Input: input, + Err: fmt.Errorf("security scheme %q is not declared", name), + } + } + scopes := securityRequirement[name] + if err := f(ctx, &openapi3filter.AuthenticationInput{ + RequestValidationInput: input, + SecuritySchemeName: name, + SecurityScheme: securityScheme, + Scopes: scopes, + }); err != nil { + return err + } + } return nil } diff --git a/internal/platform/validator/validate_request_test.go b/internal/platform/validator/validate_request_test.go new file mode 100644 index 0000000..069e3e5 --- /dev/null +++ b/internal/platform/validator/validate_request_test.go @@ -0,0 +1,224 @@ +package validator + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" +) + +func setupTestRouter(t *testing.T, spec string) routers.Router { + t.Helper() + loader := openapi3.NewLoader() + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(loader.Context) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + return router +} + +func TestValidateRequest(t *testing.T) { + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /category: + post: + parameters: + - name: category + in: query + schema: + type: string + required: true + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - subCategory + properties: + subCategory: + type: string + category: + type: string + default: Sweets + responses: + '201': + description: Created + security: + - apiKey: [] +components: + securitySchemes: + apiKey: + type: apiKey + name: Api-Key + in: header +` + + router := setupTestRouter(t, spec) + + verifyAPIKeyPresence := func(c context.Context, input *openapi3filter.AuthenticationInput) error { + if input.SecurityScheme.Type == "apiKey" { + var found bool + switch input.SecurityScheme.In { + case "query": + _, found = input.RequestValidationInput.GetQueryParams()[input.SecurityScheme.Name] + case "header": + _, found = input.RequestValidationInput.Request.Header[http.CanonicalHeaderKey(input.SecurityScheme.Name)] + case "cookie": + _, err := input.RequestValidationInput.Request.Cookie(input.SecurityScheme.Name) + found = !errors.Is(err, http.ErrNoCookie) + } + if !found { + return fmt.Errorf("%v not found in %v", input.SecurityScheme.Name, input.SecurityScheme.In) + } + } + return nil + } + + type testRequestBody struct { + SubCategory string `json:"subCategory"` + Category string `json:"category,omitempty"` + } + type args struct { + requestBody *testRequestBody + url string + apiKey string + } + tests := []struct { + name string + args args + expectedModification bool + expectedErr error + }{ + { + name: "Valid request with all fields set", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate", Category: "Food"}, + url: "/category?category=cookies", + apiKey: "SomeKey", + }, + expectedModification: false, + expectedErr: nil, + }, + { + name: "Valid request without certain fields", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/category?category=cookies", + apiKey: "SomeKey", + }, + expectedModification: true, + expectedErr: nil, + }, + { + name: "Invalid operation params", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/category?invalidCategory=badCookie", + apiKey: "SomeKey", + }, + expectedModification: false, + expectedErr: &openapi3filter.RequestError{}, + }, + { + name: "Invalid request body", + args: args{ + requestBody: nil, + url: "/category?category=cookies", + apiKey: "SomeKey", + }, + expectedModification: false, + expectedErr: &openapi3filter.RequestError{}, + }, + { + name: "Invalid security", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/category?category=cookies", + apiKey: "", + }, + expectedModification: false, + expectedErr: &openapi3filter.SecurityRequirementsError{}, + }, + { + name: "Invalid request body and security", + args: args{ + requestBody: nil, + url: "/category?category=cookies", + apiKey: "", + }, + expectedModification: false, + expectedErr: &openapi3filter.SecurityRequirementsError{}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var requestBody io.Reader + var originalBodySize int + if tc.args.requestBody != nil { + testingBody, err := json.Marshal(tc.args.requestBody) + require.NoError(t, err) + requestBody = bytes.NewReader(testingBody) + originalBodySize = len(testingBody) + } + req, err := http.NewRequest(http.MethodPost, tc.args.url, requestBody) + require.NoError(t, err) + req.Header.Add("Content-Type", "application/json") + if tc.args.apiKey != "" { + req.Header.Add("Api-Key", tc.args.apiKey) + } + + route, pathParams, err := router.FindRoute(req) + require.NoError(t, err) + + validationInput := &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + Options: &openapi3filter.Options{ + AuthenticationFunc: verifyAPIKeyPresence, + }, + } + err = ValidateRequest(context.Background(), validationInput, &fastjson.Parser{}) + assert.IsType(t, tc.expectedErr, err, "ValidateRequest(): error = %v, expectedError %v", err, tc.expectedErr) + if tc.expectedErr != nil { + return + } + body, err := io.ReadAll(validationInput.Request.Body) + contentLen := int(validationInput.Request.ContentLength) + bodySize := len(body) + assert.NoError(t, err, "unable to read request body: %v", err) + assert.Equal(t, contentLen, bodySize, "expect ContentLength %d to equal body size %d", contentLen, bodySize) + bodyModified := originalBodySize != bodySize + assert.Equal(t, bodyModified, tc.expectedModification, "expect request body modification happened: %t, expected %t", bodyModified, tc.expectedModification) + + validationInput.Request.Body, err = validationInput.Request.GetBody() + assert.NoError(t, err, "unable to re-generate body by GetBody(): %v", err) + body2, err := io.ReadAll(validationInput.Request.Body) + assert.NoError(t, err, "unable to read request body: %v", err) + assert.Equal(t, body, body2, "body by GetBody() is not matched") + }) + } +} diff --git a/internal/platform/validator/validate_response.go b/internal/platform/validator/validate_response.go index 6b398fc..a8e796b 100644 --- a/internal/platform/validator/validate_response.go +++ b/internal/platform/validator/validate_response.go @@ -4,12 +4,14 @@ import ( "bytes" "context" "fmt" - "github.com/valyala/fastjson" "io" "net/http" + "sort" + "strings" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" + "github.com/valyala/fastjson" ) // ValidateResponse is used to validate the given input according to previous @@ -62,6 +64,25 @@ func ValidateResponse(ctx context.Context, input *openapi3filter.ResponseValidat return &openapi3filter.ResponseError{Input: input, Reason: "response has not been resolved"} } + opts := make([]openapi3.SchemaValidationOption, 0, 2) + if options.MultiError { + opts = append(opts, openapi3.MultiErrors()) + } + + headers := make([]string, 0, len(response.Headers)) + for k := range response.Headers { + if k != headerCT { + headers = append(headers, k) + } + } + sort.Strings(headers) + for _, headerName := range headers { + headerRef := response.Headers[headerName] + if err := validateResponseHeader(headerName, headerRef, input, opts); err != nil { + return err + } + } + if options.ExcludeResponseBody { // A user turned off validation of a response's body. return nil @@ -121,25 +142,78 @@ func ValidateResponse(ctx context.Context, input *openapi3filter.ResponseValidat } } - opts := make([]openapi3.SchemaValidationOption, 0, 2) // 2 potential opts here - opts = append(opts, openapi3.VisitAsRequest()) - if options.MultiError { - opts = append(opts, openapi3.MultiErrors()) + // Validate data with the schema. + if err := contentType.Schema.Value.VisitJSON(value, append(opts, openapi3.VisitAsResponse())...); err != nil { + schemaId := getSchemaIdentifier(contentType.Schema) + schemaId = prependSpaceIfNeeded(schemaId) + return &openapi3filter.ResponseError{ + Input: input, + Reason: fmt.Sprintf("response body doesn't match schema%s", schemaId), + Err: err, + } } + return nil +} + +func validateResponseHeader(headerName string, headerRef *openapi3.HeaderRef, input *openapi3filter.ResponseValidationInput, opts []openapi3.SchemaValidationOption) error { + var err error + var decodedValue interface{} + var found bool + var sm *openapi3.SerializationMethod + dec := &headerParamDecoder{header: input.Header} - // prepare map[string]interface{} structure for json validation - fastjsonValue, ok := value.(*fastjson.Value) - if ok { - value = convertToMap(fastjsonValue) + if sm, err = headerRef.Value.SerializationMethod(); err != nil { + return &openapi3filter.ResponseError{ + Input: input, + Reason: fmt.Sprintf("unable to get header %q serialization method", headerName), + Err: err, + } } - // Validate data with the schema. - if err := contentType.Schema.Value.VisitJSON(value, opts...); err != nil { + if decodedValue, found, err = decodeValue(dec, headerName, sm, headerRef.Value.Schema, headerRef.Value.Required); err != nil { return &openapi3filter.ResponseError{ Input: input, - Reason: "response body doesn't match the schema", + Reason: fmt.Sprintf("unable to decode header %q value", headerName), Err: err, } } + + if found { + if err = headerRef.Value.Schema.Value.VisitJSON(decodedValue, opts...); err != nil { + return &openapi3filter.ResponseError{ + Input: input, + Reason: fmt.Sprintf("response header %q doesn't match schema", headerName), + Err: err, + } + } + } else if headerRef.Value.Required { + return &openapi3filter.ResponseError{ + Input: input, + Reason: fmt.Sprintf("response header %q missing", headerName), + } + } return nil } + +// getSchemaIdentifier gets something by which a schema could be identified. +// A schema by itself doesn't have a true identity field. This function makes +// a best effort to get a value that can fill that void. +func getSchemaIdentifier(schema *openapi3.SchemaRef) string { + var id string + + if schema != nil { + id = strings.TrimSpace(schema.Ref) + } + if id == "" && schema.Value != nil { + id = strings.TrimSpace(schema.Value.Title) + } + + return id +} + +func prependSpaceIfNeeded(value string) string { + if len(value) > 0 { + value = " " + value + } + return value +} diff --git a/internal/platform/validator/validate_response_test.go b/internal/platform/validator/validate_response_test.go new file mode 100644 index 0000000..d1caaae --- /dev/null +++ b/internal/platform/validator/validate_response_test.go @@ -0,0 +1,215 @@ +package validator + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/stretchr/testify/require" +) + +func Test_validateResponseHeader(t *testing.T) { + type args struct { + headerName string + headerRef *openapi3.HeaderRef + } + tests := []struct { + name string + args args + isHeaderPresent bool + headerVals []string + wantErr bool + wantErrMsg string + }{ + { + name: "test required string header with single string value", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewStringSchema(), true), + }, + isHeaderPresent: true, + headerVals: []string{"blab"}, + wantErr: false, + }, + { + name: "test required string header with single, empty string value", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewStringSchema(), true), + }, + isHeaderPresent: true, + headerVals: []string{""}, + wantErr: true, + wantErrMsg: `response header "X-Blab" doesn't match schema: Value is not nullable`, + }, + { + name: "test optional string header with single string value", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewStringSchema(), false), + }, + isHeaderPresent: false, + headerVals: []string{"blab"}, + wantErr: false, + }, + { + name: "test required, but missing string header", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewStringSchema(), true), + }, + isHeaderPresent: false, + headerVals: nil, + wantErr: true, + wantErrMsg: `response header "X-Blab" missing`, + }, + { + name: "test integer header with single integer value", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewIntegerSchema(), true), + }, + isHeaderPresent: true, + headerVals: []string{"88"}, + wantErr: false, + }, + { + name: "test integer header with single string value", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewIntegerSchema(), true), + }, + isHeaderPresent: true, + headerVals: []string{"blab"}, + wantErr: true, + wantErrMsg: `unable to decode header "X-Blab" value: value blab: an invalid integer: invalid syntax`, + }, + { + name: "test int64 header with single int64 value", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewInt64Schema(), true), + }, + isHeaderPresent: true, + headerVals: []string{"88"}, + wantErr: false, + }, + { + name: "test int32 header with single int32 value", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewInt32Schema(), true), + }, + isHeaderPresent: true, + headerVals: []string{"88"}, + wantErr: false, + }, + { + name: "test float64 header with single float64 value", + args: args{ + headerName: "X-Blab", + headerRef: newHeaderRef(openapi3.NewFloat64Schema(), true), + }, + isHeaderPresent: true, + headerVals: []string{"88.87"}, + wantErr: false, + }, + { + name: "test integer header with multiple csv integer values", + args: args{ + headerName: "X-blab", + headerRef: newHeaderRef(newArraySchema(openapi3.NewIntegerSchema()), true), + }, + isHeaderPresent: true, + headerVals: []string{"87,88"}, + wantErr: false, + }, + { + name: "test integer header with multiple integer values", + args: args{ + headerName: "X-blab", + headerRef: newHeaderRef(newArraySchema(openapi3.NewIntegerSchema()), true), + }, + isHeaderPresent: true, + headerVals: []string{"87", "88"}, + wantErr: false, + }, + { + name: "test non-typed, nullable header with single string value", + args: args{ + headerName: "X-blab", + headerRef: newHeaderRef(&openapi3.Schema{Nullable: true}, true), + }, + isHeaderPresent: true, + headerVals: []string{"blab"}, + wantErr: false, + }, + { + name: "test required non-typed, nullable header not present", + args: args{ + headerName: "X-blab", + headerRef: newHeaderRef(&openapi3.Schema{Nullable: true}, true), + }, + isHeaderPresent: false, + headerVals: []string{"blab"}, + wantErr: true, + wantErrMsg: `response header "X-blab" missing`, + }, + { + name: "test non-typed, non-nullable header with single string value", + args: args{ + headerName: "X-blab", + headerRef: newHeaderRef(&openapi3.Schema{Nullable: false}, true), + }, + isHeaderPresent: true, + headerVals: []string{"blab"}, + wantErr: true, + wantErrMsg: `response header "X-blab" doesn't match schema: Value is not nullable`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := newInputDefault() + opts := []openapi3.SchemaValidationOption(nil) + if tt.isHeaderPresent { + input.Header = map[string][]string{http.CanonicalHeaderKey(tt.args.headerName): tt.headerVals} + } + + err := validateResponseHeader(tt.args.headerName, tt.args.headerRef, input, opts) + if tt.wantErr { + require.NotEmpty(t, tt.wantErrMsg, "wanted error message is not populated") + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErrMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func newInputDefault() *openapi3filter.ResponseValidationInput { + return &openapi3filter.ResponseValidationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: nil, + PathParams: nil, + Route: nil, + }, + Status: 200, + Header: nil, + Body: io.NopCloser(strings.NewReader(`{}`)), + } +} + +func newHeaderRef(schema *openapi3.Schema, required bool) *openapi3.HeaderRef { + return &openapi3.HeaderRef{Value: &openapi3.Header{Parameter: openapi3.Parameter{Schema: &openapi3.SchemaRef{Value: schema}, Required: required}}} +} + +func newArraySchema(schema *openapi3.Schema) *openapi3.Schema { + arraySchema := openapi3.NewArraySchema() + arraySchema.Items = openapi3.NewSchemaRef("", schema) + + return arraySchema +} diff --git a/internal/platform/web/response.go b/internal/platform/web/response.go index c690794..da2d688 100644 --- a/internal/platform/web/response.go +++ b/internal/platform/web/response.go @@ -1,12 +1,46 @@ package web import ( + "bytes" + "compress/flate" "encoding/json" + "io" "net/http" "github.com/valyala/fasthttp" ) +var ( + gzip = []byte("gzip") + deflate = []byte("deflate") + br = []byte("br") +) + +func GetResponseBodyUncompressed(ctx *fasthttp.RequestCtx) (io.ReadCloser, error) { + + bodyBytes := ctx.Response.Body() + bodyReader := io.NopCloser(bytes.NewReader(bodyBytes)) + compression := ctx.Response.Header.ContentEncoding() + + if compression != nil { + for _, sc := range [][]byte{gzip, deflate, br} { + if bytes.Equal(sc, compression) { + var body []byte + var err error + if body, err = ctx.Response.BodyUncompressed(); err != nil { + if bytes.Equal(compression, []byte("deflate")) { + return flate.NewReader(bytes.NewReader(bodyBytes)), nil + } + return nil, err + } + return io.NopCloser(bytes.NewReader(body)), nil + } + } + } + + return bodyReader, nil +} + // Respond converts a Go value to JSON and sends it to the client. func Respond(ctx *fasthttp.RequestCtx, data interface{}, statusCode int) error { // If there is nothing to marshal then set status code and return. From eb063fef7bc08174a1e708247bed7306c7df6af0 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Sat, 4 Feb 2023 10:59:48 +0300 Subject: [PATCH 2/4] Update version --- Makefile | 2 +- demo/docker-compose/docker-compose.yml | 2 +- helm/api-firewall/Chart.yaml | 4 ++-- internal/platform/web/response.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index bc6e1bb..aca57d0 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -VERSION := 0.6.10 +VERSION := 0.6.11 .DEFAULT_GOAL := build diff --git a/demo/docker-compose/docker-compose.yml b/demo/docker-compose/docker-compose.yml index 00e5ce0..a5d81fc 100644 --- a/demo/docker-compose/docker-compose.yml +++ b/demo/docker-compose/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.8' services: api-firewall: container_name: api-firewall - image: wallarm/api-firewall:v0.6.10 + image: wallarm/api-firewall:v0.6.11 restart: on-failure environment: APIFW_URL: "http://0.0.0.0:8080" diff --git a/helm/api-firewall/Chart.yaml b/helm/api-firewall/Chart.yaml index 1f9a5fc..de28fa9 100644 --- a/helm/api-firewall/Chart.yaml +++ b/helm/api-firewall/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v1 name: api-firewall -version: 0.6.10 -appVersion: 0.6.10 +version: 0.6.11 +appVersion: 0.6.11 description: Wallarm OpenAPI-based API Firewall home: https://github.com/wallarm/api-firewall icon: https://static.wallarm.com/wallarm-logo.svg diff --git a/internal/platform/web/response.go b/internal/platform/web/response.go index da2d688..8ca7691 100644 --- a/internal/platform/web/response.go +++ b/internal/platform/web/response.go @@ -28,7 +28,7 @@ func GetResponseBodyUncompressed(ctx *fasthttp.RequestCtx) (io.ReadCloser, error var body []byte var err error if body, err = ctx.Response.BodyUncompressed(); err != nil { - if bytes.Equal(compression, []byte("deflate")) { + if bytes.Equal(compression, deflate) { return flate.NewReader(bytes.NewReader(bodyBytes)), nil } return nil, err From 42705d0da63a698a1bd7f37684f05483cf4fe829 Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Thu, 9 Feb 2023 00:34:53 +0300 Subject: [PATCH 3/4] Update GetResponseBodyUncompressed --- internal/platform/web/response.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/internal/platform/web/response.go b/internal/platform/web/response.go index 8ca7691..974868b 100644 --- a/internal/platform/web/response.go +++ b/internal/platform/web/response.go @@ -2,24 +2,27 @@ package web import ( "bytes" - "compress/flate" "encoding/json" + "errors" "io" "net/http" + "github.com/klauspost/compress/flate" + "github.com/klauspost/compress/zlib" "github.com/valyala/fasthttp" ) +// List of the supported compression schemes var ( gzip = []byte("gzip") deflate = []byte("deflate") br = []byte("br") ) +// GetResponseBodyUncompressed function returns the Reader of the uncompressed body func GetResponseBodyUncompressed(ctx *fasthttp.RequestCtx) (io.ReadCloser, error) { bodyBytes := ctx.Response.Body() - bodyReader := io.NopCloser(bytes.NewReader(bodyBytes)) compression := ctx.Response.Header.ContentEncoding() if compression != nil { @@ -28,17 +31,23 @@ func GetResponseBodyUncompressed(ctx *fasthttp.RequestCtx) (io.ReadCloser, error var body []byte var err error if body, err = ctx.Response.BodyUncompressed(); err != nil { - if bytes.Equal(compression, deflate) { + if errors.Is(zlib.ErrHeader, err) && bytes.Equal(compression, deflate) { + // deflate rfc 1951 implementation return flate.NewReader(bytes.NewReader(bodyBytes)), nil } + // got error while body decompression return nil, err } + // body has been successfully uncompressed return io.NopCloser(bytes.NewReader(body)), nil } } + // body compression schema not supported + return nil, fasthttp.ErrContentEncodingUnsupported } - return bodyReader, nil + // body without compression + return io.NopCloser(bytes.NewReader(bodyBytes)), nil } // Respond converts a Go value to JSON and sends it to the client. From 047c65bb1a08a20aab84b935d5e01f373850028c Mon Sep 17 00:00:00 2001 From: Nikolay Tkachenko Date: Fri, 10 Feb 2023 19:42:32 +0300 Subject: [PATCH 4/4] Add request body decompression. Add conf var that allows to delete accept-encoding header. Update tests. --- cmd/api-firewall/internal/handlers/openapi.go | 16 +- cmd/api-firewall/tests/main_test.go | 292 ++++++++---------- internal/config/config.go | 19 +- internal/mid/proxy.go | 31 +- internal/platform/web/response.go | 101 ++++-- 5 files changed, 248 insertions(+), 211 deletions(-) diff --git a/cmd/api-firewall/internal/handlers/openapi.go b/cmd/api-firewall/internal/handlers/openapi.go index c547e02..e0e277a 100644 --- a/cmd/api-firewall/internal/handlers/openapi.go +++ b/cmd/api-firewall/internal/handlers/openapi.go @@ -186,6 +186,19 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { return web.RespondError(ctx, fasthttp.StatusBadRequest, nil) } + // decode request body + requestContentEncoding := string(ctx.Request.Header.ContentEncoding()) + if requestContentEncoding != "" { + req.Body, err = web.GetDecompressedRequestBody(&ctx.Request, requestContentEncoding) + if err != nil { + s.logger.WithFields(logrus.Fields{ + "error": err, + "request_id": fmt.Sprintf("#%016X", ctx.ID()), + }).Error("request body decompression error") + return err + } + } + // Validate request requestValidationInput := &openapi3filter.RequestValidationInput{ Request: &req, @@ -282,7 +295,8 @@ func (s *openapiWaf) openapiWafHandler(ctx *fasthttp.RequestCtx) error { respHeader.Set(sk, sv) }) - responseBodyReader, err := web.GetResponseBodyUncompressed(ctx) + // decode response body + responseBodyReader, err := web.GetDecompressedResponseBody(&ctx.Response, string(ctx.Response.Header.ContentEncoding())) if err != nil { s.logger.WithFields(logrus.Fields{ "error": err, diff --git a/cmd/api-firewall/tests/main_test.go b/cmd/api-firewall/tests/main_test.go index dda9fd9..85d9f13 100644 --- a/cmd/api-firewall/tests/main_test.go +++ b/cmd/api-firewall/tests/main_test.go @@ -5,6 +5,7 @@ import ( "compress/flate" "compress/gzip" "encoding/json" + "errors" "fmt" "io" "net" @@ -195,6 +196,8 @@ components: write: write ` +var testSupportedEncodingSchemas = []string{"gzip", "deflate", "br"} + const ( testOauthBearerToken = "testtesttest" testOauthJWTTokenRS = "eyJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJqd3QudGVzdC5naXRodWIuaW8iLCJzdWIiOiJldmFuZGVyIiwiYXVkIjoibmFpbWlzaCIsImlhdCI6MTYzODUwNjIxNywiZXhwIjozNTMxOTM3ODc1LCJzY29wZSI6InJlYWQgd3JpdGUifQ.MPC35ZX52qWE4AktY1Bs-HVEWUUYrByfRVUSL9GbzZhZfXlfcNkF-qNRK_EDG2eviE4UHb6CFVZeYTsO5MyKg0H3shp79LeZTA2XzCuCZvzAqA7EQrpUKiKof-9af5g3jIRU4YFxvtpp8XxXGHaMvbIy4gqQJ7WEsOksYOytEsbLtsCs880zxCJb1iM4Bu9Q_Nl-wW1NeYSZyHYZP7es7gVvb9Bbm6qYW4qcVbt20pW4dguBGEvUvLM6axqeTZe7JgtqU__uUwkcIS6bu711Y7Zi-TpeZAMp506Wx8qZrhi7Ea0QFZUMjoF0O7jgRtps_BlbqBXNoleMO-kKnSkd6A" @@ -260,6 +263,19 @@ func compressGzip(data []byte) ([]byte, error) { return b.Bytes(), nil } +func compressData(data []byte, encodingSchema string) ([]byte, error) { + switch encodingSchema { + case "br": + return compressBrotli(data) + case "deflate": + return compressFlate(data) + case "gzip": + return compressGzip(data) + } + + return nil, errors.New("encoding schema not supported") +} + // POST /test/signup <- 200 // POST /test/shadow <- 200 func TestBasic(t *testing.T) { @@ -310,7 +326,7 @@ func TestBasic(t *testing.T) { t.Run("basicBlockLogOnlyMode", apifwTests.testBlockLogOnlyMode) t.Run("basicLogOnlyBlockMode", apifwTests.testLogOnlyBlockMode) - t.Run("commonParamters", apifwTests.testCommonParameters) + t.Run("commonParameters", apifwTests.testCommonParameters) t.Run("basicDenylist", apifwTests.testDenylist) t.Run("basicShadowAPI", apifwTests.testShadowAPI) @@ -327,10 +343,8 @@ func TestBasic(t *testing.T) { t.Run("requestHeaders", apifwTests.testRequestHeaders) t.Run("responseHeaders", apifwTests.testResponseHeaders) - t.Run("responseBodyCompressionGzip", apifwTests.testResponseBodyCompressionGzip) - t.Run("responseBodyCompressionBr", apifwTests.testResponseBodyCompressionBr) - t.Run("responseBodyCompressionDeflate", apifwTests.testResponseBodyCompressionDeflate) - + t.Run("reqBodyCompression", apifwTests.testRequestBodyCompression) + t.Run("respBodyCompression", apifwTests.testResponseBodyCompression) } func (s *ServiceTests) testBlockMode(t *testing.T) { @@ -1507,7 +1521,7 @@ func (s *ServiceTests) testResponseHeaders(t *testing.T) { } -func (s *ServiceTests) testResponseBodyCompressionGzip(t *testing.T) { +func (s *ServiceTests) testRequestBodyCompression(t *testing.T) { var cfg = config.APIFWConfiguration{ RequestValidation: "BLOCK", @@ -1521,163 +1535,103 @@ func (s *ServiceTests) testResponseBodyCompressionGzip(t *testing.T) { handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) - p, err := json.Marshal(map[string]interface{}{ - "firstname": "test", - "lastname": "test", - "job": "test", - "email": "test@wallarm.com", - "url": "http://wallarm.com", - }) - - if err != nil { - t.Fatal(err) - } - req := fasthttp.AcquireRequest() req.SetRequestURI("/test/signup") req.Header.SetMethod("POST") - req.SetBodyStream(bytes.NewReader(p), -1) - req.Header.SetContentType("application/json") resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) resp.Header.SetContentType("application/json") - resp.Header.Set("Content-Encoding", "gzip") - - // compress using gzip - body, err := compressGzip([]byte("{\"status\":\"success\"}")) - if err != nil { - t.Fatal(err) - } - resp.SetBody(body) - - reqCtx := fasthttp.RequestCtx{ - Request: *req, - } - - s.proxy.EXPECT().Get().Return(s.client, nil) - s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client).Return(nil) - - handler(&reqCtx) - - if reqCtx.Response.StatusCode() != 200 { - t.Errorf("Incorrect response status code. Expected: 200 and got %d", - reqCtx.Response.StatusCode()) - } - - // Repeat request with wrong JSON in response - - // compress using gzip - body, err = compressGzip([]byte("{\"status\": 123}")) - if err != nil { - t.Fatal(err) - } - resp.SetBody(body) - - reqCtx = fasthttp.RequestCtx{ - Request: *req, - } - - s.proxy.EXPECT().Get().Return(s.client, nil) - s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client) - - handler(&reqCtx) + resp.SetBody([]byte("{\"status\":\"success\"}")) - if reqCtx.Response.StatusCode() != 403 { - t.Errorf("Incorrect response status code. Expected: 403 and got %d", - reqCtx.Response.StatusCode()) - } + var p []byte + var err error -} + for _, encSchema := range testSupportedEncodingSchemas { -func (s *ServiceTests) testResponseBodyCompressionBr(t *testing.T) { + p, err = json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "test@wallarm.com", + "url": "http://wallarm.com", + }) + if err != nil { + t.Fatal(err) + } - var cfg = config.APIFWConfiguration{ - RequestValidation: "BLOCK", - ResponseValidation: "BLOCK", - CustomBlockStatusCode: 403, - AddValidationStatusHeader: false, - ShadowAPI: config.ShadowAPI{ - ExcludeList: []int{404, 401}, - }, - } + // compress request body using gzip + reqBodyRaw, err := io.ReadAll(bytes.NewReader(p)) + if err != nil { + t.Fatal(err) + } - handler := handlers.OpenapiProxy(&cfg, s.serverUrl, s.shutdown, s.logger, s.proxy, s.swagRouter, nil, s.shadowAPI) + reqBody, err := compressData(reqBodyRaw, encSchema) + if err != nil { + t.Fatal(err) + } - p, err := json.Marshal(map[string]interface{}{ - "firstname": "test", - "lastname": "test", - "job": "test", - "email": "test@wallarm.com", - "url": "http://wallarm.com", - }) + req.SetBody(reqBody) + req.Header.SetContentEncoding(encSchema) + req.Header.SetContentType("application/json") - if err != nil { - t.Fatal(err) - } + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } - req := fasthttp.AcquireRequest() - req.SetRequestURI("/test/signup") - req.Header.SetMethod("POST") - req.SetBodyStream(bytes.NewReader(p), -1) - req.Header.SetContentType("application/json") + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) - resp := fasthttp.AcquireResponse() - resp.SetStatusCode(fasthttp.StatusOK) - resp.Header.SetContentType("application/json") - resp.Header.Set("Content-Encoding", "br") + handler(&reqCtx) - // compress using brotli - body, err := compressBrotli([]byte("{\"status\":\"success\"}")) - if err != nil { - t.Fatal(err) - } - resp.SetBody(body) + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } - reqCtx := fasthttp.RequestCtx{ - Request: *req, - } + // Repeat request with wrong JSON in request - s.proxy.EXPECT().Get().Return(s.client, nil) - s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client).Return(nil) + p, err = json.Marshal(map[string]interface{}{ + "firstname": "test", + "lastname": "test", + "job": "test", + "email": "wrong_email_test", + "url": "http://wallarm.com", + }) - handler(&reqCtx) + // compress request body using gzip + reqBodyRaw, err = io.ReadAll(bytes.NewReader(p)) + if err != nil { + t.Fatal(err) + } - if reqCtx.Response.StatusCode() != 200 { - t.Errorf("Incorrect response status code. Expected: 200 and got %d", - reqCtx.Response.StatusCode()) - } + reqBody, err = compressData(reqBodyRaw, encSchema) + if err != nil { + t.Fatal(err) + } - // Repeat request with wrong JSON in response + req.SetBody(reqBody) - // compress using brotli - body, err = compressBrotli([]byte("{\"status\": 123}")) - if err != nil { - t.Fatal(err) - } - resp.SetBody(body) + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } - reqCtx = fasthttp.RequestCtx{ - Request: *req, - } + s.proxy.EXPECT().Get().Return(s.client, nil) + s.proxy.EXPECT().Put(s.client) - s.proxy.EXPECT().Get().Return(s.client, nil) - s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client) + handler(&reqCtx) - handler(&reqCtx) + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } - if reqCtx.Response.StatusCode() != 403 { - t.Errorf("Incorrect response status code. Expected: 403 and got %d", - reqCtx.Response.StatusCode()) } } -func (s *ServiceTests) testResponseBodyCompressionDeflate(t *testing.T) { +func (s *ServiceTests) testResponseBodyCompression(t *testing.T) { var cfg = config.APIFWConfiguration{ RequestValidation: "BLOCK", @@ -1712,52 +1666,56 @@ func (s *ServiceTests) testResponseBodyCompressionDeflate(t *testing.T) { resp := fasthttp.AcquireResponse() resp.SetStatusCode(fasthttp.StatusOK) resp.Header.SetContentType("application/json") - resp.Header.Set("Content-Encoding", "deflate") - // compress using flate - body, err := compressFlate([]byte("{\"status\":\"success\"}")) - if err != nil { - t.Fatal(err) - } - resp.SetBody(body) + for _, encSchema := range testSupportedEncodingSchemas { - reqCtx := fasthttp.RequestCtx{ - Request: *req, - } + // compress response body using gzip + body, err := compressData([]byte("{\"status\":\"success\"}"), encSchema) + if err != nil { + t.Fatal(err) + } - s.proxy.EXPECT().Get().Return(s.client, nil) - s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client).Return(nil) + resp.SetBody(body) + resp.Header.SetContentEncoding(encSchema) - handler(&reqCtx) + reqCtx := fasthttp.RequestCtx{ + Request: *req, + } - if reqCtx.Response.StatusCode() != 200 { - t.Errorf("Incorrect response status code. Expected: 200 and got %d", - reqCtx.Response.StatusCode()) - } + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client).Return(nil) - // Repeat request with wrong JSON in response + handler(&reqCtx) - // compress using flate - body, err = compressFlate([]byte("{\"status\": 123}")) - if err != nil { - t.Fatal(err) - } - resp.SetBody(body) + if reqCtx.Response.StatusCode() != 200 { + t.Errorf("Incorrect response status code. Expected: 200 and got %d", + reqCtx.Response.StatusCode()) + } - reqCtx = fasthttp.RequestCtx{ - Request: *req, - } + // Repeat request with wrong JSON in response - s.proxy.EXPECT().Get().Return(s.client, nil) - s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) - s.proxy.EXPECT().Put(s.client) + // compress using gzip + body, err = compressData([]byte("{\"status\": 123}"), encSchema) + if err != nil { + t.Fatal(err) + } + resp.SetBody(body) - handler(&reqCtx) + reqCtx = fasthttp.RequestCtx{ + Request: *req, + } - if reqCtx.Response.StatusCode() != 403 { - t.Errorf("Incorrect response status code. Expected: 403 and got %d", - reqCtx.Response.StatusCode()) + s.proxy.EXPECT().Get().Return(s.client, nil) + s.client.EXPECT().Do(gomock.Any(), gomock.Any()).SetArg(1, *resp) + s.proxy.EXPECT().Put(s.client) + + handler(&reqCtx) + + if reqCtx.Response.StatusCode() != 403 { + t.Errorf("Incorrect response status code. Expected: 403 and got %d", + reqCtx.Response.StatusCode()) + } } } diff --git a/internal/config/config.go b/internal/config/config.go index 47929ec..704de6b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,15 +13,16 @@ type TLS struct { } type Server struct { - URL string `conf:"default:http://localhost:3000/v1/" validate:"required,url"` - ClientPoolCapacity int `conf:"default:1000" validate:"gt=0"` - InsecureConnection bool `conf:"default:false"` - RootCA string `conf:""` - MaxConnsPerHost int `conf:"default:512"` - ReadTimeout time.Duration `conf:"default:5s"` - WriteTimeout time.Duration `conf:"default:5s"` - DialTimeout time.Duration `conf:"default:200ms"` - Oauth Oauth + URL string `conf:"default:http://localhost:3000/v1/" validate:"required,url"` + ClientPoolCapacity int `conf:"default:1000" validate:"gt=0"` + InsecureConnection bool `conf:"default:false"` + RootCA string `conf:""` + MaxConnsPerHost int `conf:"default:512"` + ReadTimeout time.Duration `conf:"default:5s"` + WriteTimeout time.Duration `conf:"default:5s"` + DialTimeout time.Duration `conf:"default:200ms"` + DeleteAcceptEncoding bool `conf:"default:false"` + Oauth Oauth } type JWT struct { diff --git a/internal/mid/proxy.go b/internal/mid/proxy.go index 9d7da4e..3c06fd6 100644 --- a/internal/mid/proxy.go +++ b/internal/mid/proxy.go @@ -3,6 +3,7 @@ package mid import ( "bytes" "fmt" + "net/http" "net/url" "github.com/savsgio/gotils/strconv" @@ -18,17 +19,20 @@ const apifwHeaderName = "APIFW-Request-Id" // Connection header field. These are the headers defined by the // obsoleted RFC 2616 (section 13.5.1) and are used for backward // compatibility. -var hopHeaders = []string{ - "Connection", // Connection - "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google - "Keep-Alive", // Keep-Alive - "Proxy-Authenticate", // Proxy-Authenticate - "Proxy-Authorization", // Proxy-Authorization - "Te", // canonicalized version of "TE" - "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 - "Transfer-Encoding", // Transfer-Encoding - "Upgrade", // Upgrade -} +var ( + hopHeaders = []string{ + "Connection", // Connection + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", // Keep-Alive + "Proxy-Authenticate", // Proxy-Authenticate + "Proxy-Authorization", // Proxy-Authorization + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", // Transfer-Encoding + "Upgrade", // Upgrade + } + acHeader = http.CanonicalHeaderKey("Accept-Encoding") +) // Proxy changes request scheme before request func Proxy(cfg *config.APIFWConfiguration, serverUrl *url.URL) web.Middleware { @@ -66,6 +70,11 @@ func Proxy(cfg *config.APIFWConfiguration, serverUrl *url.URL) web.Middleware { ctx.Request.Header.Set("X-Forwarded-For", ctx.RemoteIP().String()) } + // delete Accept-Encoding header + if cfg.Server.DeleteAcceptEncoding { + ctx.Request.Header.Del(acHeader) + } + err := before(ctx) for _, h := range hopHeaders { diff --git a/internal/platform/web/response.go b/internal/platform/web/response.go index 974868b..7f3c4aa 100644 --- a/internal/platform/web/response.go +++ b/internal/platform/web/response.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "golang.org/x/exp/slices" "io" "net/http" @@ -14,33 +15,87 @@ import ( // List of the supported compression schemes var ( - gzip = []byte("gzip") - deflate = []byte("deflate") - br = []byte("br") + supportedEncodings = []string{"gzip", "deflate", "br"} ) -// GetResponseBodyUncompressed function returns the Reader of the uncompressed body -func GetResponseBodyUncompressed(ctx *fasthttp.RequestCtx) (io.ReadCloser, error) { - - bodyBytes := ctx.Response.Body() - compression := ctx.Response.Header.ContentEncoding() - - if compression != nil { - for _, sc := range [][]byte{gzip, deflate, br} { - if bytes.Equal(sc, compression) { - var body []byte - var err error - if body, err = ctx.Response.BodyUncompressed(); err != nil { - if errors.Is(zlib.ErrHeader, err) && bytes.Equal(compression, deflate) { - // deflate rfc 1951 implementation - return flate.NewReader(bytes.NewReader(bodyBytes)), nil - } - // got error while body decompression - return nil, err +//// GetDecompressedBody function returns the Reader of the decompressed body +//func GetDecompressedBody(ctx *fasthttp.RequestCtx) (io.ReadCloser, error) { +// +// bodyBytes := ctx.Response.Body() +// compression := ctx.Response.Header.ContentEncoding() +// +// if compression != nil { +// for _, sc := range [][]byte{gzip, deflate, br} { +// if bytes.Equal(sc, compression) { +// var body []byte +// var err error +// if body, err = ctx.Response.BodyUncompressed(); err != nil { +// if errors.Is(zlib.ErrHeader, err) && bytes.Equal(compression, deflate) { +// // deflate rfc 1951 implementation +// return flate.NewReader(bytes.NewReader(bodyBytes)), nil +// } +// // got error while body decompression +// return nil, err +// } +// // body has been successfully uncompressed +// return io.NopCloser(bytes.NewReader(body)), nil +// } +// } +// // body compression schema not supported +// return nil, fasthttp.ErrContentEncodingUnsupported +// } +// +// // body without compression +// return io.NopCloser(bytes.NewReader(bodyBytes)), nil +//} + +// GetDecompressedResponseBody function returns the Reader of the decompressed response body +func GetDecompressedResponseBody(resp *fasthttp.Response, contentEncoding string) (io.ReadCloser, error) { + + bodyBytes := resp.Body() + + if contentEncoding != "" { + if slices.Contains(supportedEncodings, contentEncoding) { + var body []byte + var err error + if body, err = resp.BodyUncompressed(); err != nil { + if errors.Is(zlib.ErrHeader, err) && contentEncoding == "deflate" { + // deflate rfc 1951 implementation + return flate.NewReader(bytes.NewReader(bodyBytes)), nil } - // body has been successfully uncompressed - return io.NopCloser(bytes.NewReader(body)), nil + // got error while body decompression + return nil, err } + // body has been successfully uncompressed + return io.NopCloser(bytes.NewReader(body)), nil + } + // body compression schema not supported + return nil, fasthttp.ErrContentEncodingUnsupported + } + + // body without compression + return io.NopCloser(bytes.NewReader(bodyBytes)), nil +} + +// GetDecompressedRequestBody function returns the Reader of the decompressed request body +func GetDecompressedRequestBody(req *fasthttp.Request, contentEncoding string) (io.ReadCloser, error) { + + bodyBytes := req.Body() + + if contentEncoding != "" { + if slices.Contains(supportedEncodings, contentEncoding) { + var body []byte + var err error + if body, err = req.BodyUncompressed(); err != nil { + if errors.Is(zlib.ErrHeader, err) && contentEncoding == "deflate" { + // deflate rfc 1951 implementation + return flate.NewReader(bytes.NewReader(bodyBytes)), nil + } + // got error while body decompression + return nil, err + } + // body has been successfully uncompressed + return io.NopCloser(bytes.NewReader(body)), nil } // body compression schema not supported return nil, fasthttp.ErrContentEncodingUnsupported