Skip to content

Commit

Permalink
Merge pull request #72 from k1LoW/skip-circular-reference-check
Browse files Browse the repository at this point in the history
Add SkipCircularReferenceCheck option
  • Loading branch information
k1LoW authored Aug 31, 2024
2 parents 229f8f0 + 6ec7618 commit 531d765
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 21 deletions.
19 changes: 16 additions & 3 deletions httpstub.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ type Router struct {
useTLS bool
cacert, cert, key []byte
clientCacert, clientCert, clientKey []byte
openAPI3Doc *libopenapi.Document
openAPI3Validator *validator.Validator
openAPI3Doc libopenapi.Document
openAPI3Validator validator.Validator
skipValidateRequest bool
skipValidateResponse bool
mu sync.RWMutex
Expand Down Expand Up @@ -113,6 +113,19 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func NewRouter(t TB, opts ...Option) *Router {
t.Helper()
c := &config{}

// Set skipCircularReferenceCheck first
for _, opt := range opts {
tmp := &config{}
_ = opt(tmp)
if tmp.skipCircularReferenceCheck {
if err := opt(c); err != nil {
t.Fatal(err)
}
break
}
}

for _, opt := range opts {
if err := opt(c); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -469,7 +482,7 @@ func (m *matcher) ResponseExample(opts ...responseExampleOption) {
return
}
}
doc := *m.router.openAPI3Doc
doc := m.router.openAPI3Doc
v3m, errs := doc.BuildV3Model()
if errs != nil {
m.router.t.Errorf("failed to build OpenAPI v3 model: %v", errors.Join(errs...))
Expand Down
12 changes: 6 additions & 6 deletions openapi3.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ func (rt *Router) setOpenApi3Vaildator() error {
}
mw := func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
v := *rt.openAPI3Validator
v := rt.openAPI3Validator
if !rt.skipValidateRequest {
_, errs := v.ValidateHttpRequest(r)
if len(errs) > 0 {
{
// renew validator (workaround)
// ref: https://github.com/k1LoW/runn/issues/882
vv, errrs := validator.NewValidator(*rt.openAPI3Doc)
vv, errrs := validator.NewValidator(rt.openAPI3Doc)
if len(errrs) > 0 {
rt.t.Errorf("failed to renew validator: %v", errors.Join(errrs...))
return
}
rt.openAPI3Validator = &vv
v = *rt.openAPI3Validator
rt.openAPI3Validator = vv
v = rt.openAPI3Validator
}
var err error
for _, e := range errs {
Expand All @@ -97,12 +97,12 @@ func (rt *Router) setOpenApi3Vaildator() error {
{
// renew validator (workaround)
// ref: https://github.com/k1LoW/runn/issues/882
vv, errrs := validator.NewValidator(*rt.openAPI3Doc)
vv, errrs := validator.NewValidator(rt.openAPI3Doc)
if len(errrs) > 0 {
rt.t.Errorf("failed to renew validator: %v", errors.Join(errrs...))
return
}
rt.openAPI3Validator = &vv
rt.openAPI3Validator = vv
}
var err error
for _, e := range errs {
Expand Down
39 changes: 36 additions & 3 deletions openapi3_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package httpstub

import (
"fmt"
"net/http"
"strings"
"testing"
Expand All @@ -25,12 +26,11 @@ func TestOpenAPI3(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockTB := mock_httpstub.NewMockTB(ctrl)
mockTB.EXPECT().Helper()
mockTB.EXPECT().Helper().AnyTimes()
if tt.wantErr {
mockTB.EXPECT().Errorf(gomock.Any(), gomock.Any())
}
rt := NewRouter(t, OpenApi3("testdata/openapi3.yml"))
rt.t = mockTB
rt := NewRouter(mockTB, OpenApi3("testdata/openapi3.yml"))
rt.Method(http.MethodPost).Path("/api/v1/users").Header("Content-Type", "application/json").ResponseString(http.StatusCreated, `{"name":"alice"}`)
// invalid response
rt.Method(http.MethodGet).Path("/api/v1/users").Header("Content-Type", "application/json").ResponseString(http.StatusBadRequest, `{"invalid":"data"}`)
Expand All @@ -46,6 +46,39 @@ func TestOpenAPI3(t *testing.T) {
}
}

func TestSkipCircularReferenceCheck(t *testing.T) {
tests := []struct {
skipCircularReferenceCheck bool
wantErr bool
}{
{false, true},
{true, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("skipCircularReferenceCheck=%v", tt.skipCircularReferenceCheck), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockTB := mock_httpstub.NewMockTB(ctrl)
mockTB.EXPECT().Helper().AnyTimes()
if tt.wantErr {
mockTB.EXPECT().Fatal(gomock.Any())
}
rt := NewRouter(mockTB, OpenApi3("testdata/openapi3-circular-references.yml"), SkipCircularReferenceCheck(tt.skipCircularReferenceCheck), SkipValidateResponse(true))
// invalid response
rt.Method(http.MethodGet).Path("/api/hello").Header("Content-Type", "application/json").ResponseString(http.StatusOK, `{"rows":[]}`)
ts := rt.Server()
t.Cleanup(func() {
ts.Close()
})
tc := ts.Client()
req := newRequest(t, http.MethodGet, "/api/hello", "")
if _, err := tc.Do(req); err != nil {
t.Error(err)
}
})
}
}

func newRequest(t *testing.T, method string, path string, body string) *http.Request {
t.Helper()
req, err := http.NewRequest(method, path, strings.NewReader(body))
Expand Down
33 changes: 24 additions & 9 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ type config struct {
useTLS bool
cacert, cert, key []byte
clientCacert, clientCert, clientKey []byte
openAPI3Doc *libopenapi.Document
openAPI3Validator *validator.Validator
openAPI3Doc libopenapi.Document
openAPI3Validator validator.Validator
skipValidateRequest bool
skipValidateResponse bool
skipCircularReferenceCheck bool
}

type Option func(*config) error
Expand All @@ -29,8 +30,9 @@ func OpenApi3(l string) Option {
return func(c *config) error {
var doc libopenapi.Document
dc := &datamodel.DocumentConfiguration{
AllowFileReferences: true,
AllowRemoteReferences: true,
AllowFileReferences: true,
AllowRemoteReferences: true,
SkipCircularReferenceCheck: c.skipCircularReferenceCheck,
}
switch {
case strings.HasPrefix(l, "https://") || strings.HasPrefix(l, "http://"):
Expand Down Expand Up @@ -68,16 +70,21 @@ func OpenApi3(l string) Option {
}
return err
}
c.openAPI3Doc = &doc
c.openAPI3Validator = &v
c.openAPI3Doc = doc
c.openAPI3Validator = v
return nil
}
}

// OpenApi3FromData sets OpenAPI Document from bytes
func OpenApi3FromData(b []byte) Option {
return func(c *config) error {
doc, err := libopenapi.NewDocument(b)
dc := &datamodel.DocumentConfiguration{
AllowFileReferences: true,
AllowRemoteReferences: true,
SkipCircularReferenceCheck: c.skipCircularReferenceCheck,
}
doc, err := libopenapi.NewDocumentWithConfiguration(b, dc)
if err != nil {
return err
}
Expand All @@ -92,8 +99,8 @@ func OpenApi3FromData(b []byte) Option {
}
return err
}
c.openAPI3Doc = &doc
c.openAPI3Validator = &v
c.openAPI3Doc = doc
c.openAPI3Validator = v
return nil
}
}
Expand All @@ -114,6 +121,14 @@ func SkipValidateResponse(skip bool) Option {
}
}

// SkipCircularReferenceCheck sets whether to skip circular reference check in OpenAPI Document.
func SkipCircularReferenceCheck(skip bool) Option {
return func(c *config) error {
c.skipCircularReferenceCheck = skip
return nil
}
}

// UseTLS enable TLS
func UseTLS() Option {
return func(c *config) error {
Expand Down
39 changes: 39 additions & 0 deletions testdata/openapi3-circular-references.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
openapi: 3.1.0
info:
title: For debugging
version: 1.0.0
paths:
/api/hello:
get:
responses:
"200":
content:
application/json:
schema:
$ref: "#/components/schemas/Response"
description: Debugging
components:
schemas:
Response:
type: object
properties:
rows:
type: array
items:
$ref: "#/components/schemas/Row"
required:
- rows
Row:
type: object
properties:
name:
description: Name of the row
type: string
rows:
description: A collection of row
type: array
items:
$ref: "#/components/schemas/Row"
required:
- name
- rows

0 comments on commit 531d765

Please sign in to comment.