From daadcdf307bfa8d52c7f53c67136c67dc30c3c50 Mon Sep 17 00:00:00 2001 From: Chrstopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Fri, 30 Aug 2024 09:54:07 -0700 Subject: [PATCH] feat: override status codes --- cmd/muxt/testdata/generate/form.txtar | 10 ++--- generate.go | 8 ++-- generate_test.go | 3 -- name.go | 26 +++++++++---- name_internal_test.go | 16 ++++++++ name_test.go | 54 +++++++++++++++------------ 6 files changed, 73 insertions(+), 44 deletions(-) diff --git a/cmd/muxt/testdata/generate/form.txtar b/cmd/muxt/testdata/generate/form.txtar index c1f999d..29ef7f6 100644 --- a/cmd/muxt/testdata/generate/form.txtar +++ b/cmd/muxt/testdata/generate/form.txtar @@ -5,7 +5,7 @@ cat template_routes.go exec go test -cover -- template.gohtml -- -{{define "POST / Method(form)" }}{{end}} +{{define "POST / Method(form) 201" }}{{end}} -- go.mod -- module server @@ -26,8 +26,8 @@ var formHTML embed.FS var templates = template.Must(template.ParseFS(formHTML, "*")) type Form struct { - Count []int `json:"count"` - Str string `name:"some-string" json:"str"` + Count []int `json:"count"` + Str string `name:"some-string" json:"str"` } type T struct { @@ -78,8 +78,8 @@ func Test(t *testing.T) { res := rec.Result() - if res.StatusCode != http.StatusOK { - t.Error("expected OK") + if res.StatusCode != http.StatusCreated { + t.Error("exp", http.StatusText(http.StatusCreated), "got", http.StatusText(res.StatusCode)) } body, err := io.ReadAll(res.Body) diff --git a/generate.go b/generate.go index 1dac618..3f3058f 100644 --- a/generate.go +++ b/generate.go @@ -137,7 +137,7 @@ func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStm func (def TemplateName) funcLit(method *ast.FuncType, files []*ast.File) (*ast.FuncLit, []*ast.ImportSpec, error) { if method == nil { - return def.httpRequestReceiverTemplateHandlerFunc(), nil, nil + return def.httpRequestReceiverTemplateHandlerFunc(def.statusCode), nil, nil } lit := &ast.FuncLit{ Type: httpHandlerFuncType(), @@ -340,7 +340,7 @@ func (def TemplateName) funcLit(method *ast.FuncType, files []*ast.File) (*ast.F } else { lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}) } - lit.Body.List = append(lit.Body.List, def.executeCall(source.HTTPStatusCode(httpPackageIdent, http.StatusOK), ast.NewIdent(dataVarIdent), writeHeader)) + lit.Body.List = append(lit.Body.List, def.executeCall(source.HTTPStatusCode(httpPackageIdent, def.statusCode), ast.NewIdent(dataVarIdent), writeHeader)) return lit, imports, nil } @@ -855,10 +855,10 @@ func (def TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *as }} } -func (def TemplateName) httpRequestReceiverTemplateHandlerFunc() *ast.FuncLit { +func (def TemplateName) httpRequestReceiverTemplateHandlerFunc(statusCode int) *ast.FuncLit { return &ast.FuncLit{ Type: httpHandlerFuncType(), - Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(source.HTTPStatusCode(httpPackageIdent, http.StatusOK), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, + Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(source.HTTPStatusCode(httpPackageIdent, statusCode), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, } } diff --git a/generate_test.go b/generate_test.go index b6337a2..6898882 100644 --- a/generate_test.go +++ b/generate_test.go @@ -1462,9 +1462,6 @@ func execute(response http.ResponseWriter, request *http.Request, writeHeader bo if tt.ExpectedError == "" { assert.NoError(t, err) assert.Equal(t, tt.ExpectedFile, out) - if t.Failed() { - t.Logf(out) - } } else { assert.ErrorContains(t, err, tt.ExpectedError) } diff --git a/name.go b/name.go index a92d9b9..7e0d1b8 100644 --- a/name.go +++ b/name.go @@ -10,6 +10,7 @@ import ( "net/http" "regexp" "slices" + "strconv" "strings" "github.com/crhntr/muxt/internal/source" @@ -40,6 +41,7 @@ type TemplateName struct { name string method, host, path, endpoint string handler string + statusCode int fun *ast.Ident call *ast.CallExpr @@ -57,13 +59,21 @@ func newTemplate(in string) (TemplateName, error, bool) { } matches := templateNameMux.FindStringSubmatch(in) p := TemplateName{ - name: in, - method: matches[templateNameMux.SubexpIndex("method")], - host: matches[templateNameMux.SubexpIndex("host")], - path: matches[templateNameMux.SubexpIndex("path")], - handler: strings.TrimSpace(matches[templateNameMux.SubexpIndex("handler")]), - endpoint: matches[templateNameMux.SubexpIndex("endpoint")], - fileSet: token.NewFileSet(), + name: in, + method: matches[templateNameMux.SubexpIndex("method")], + host: matches[templateNameMux.SubexpIndex("host")], + path: matches[templateNameMux.SubexpIndex("path")], + handler: strings.TrimSpace(matches[templateNameMux.SubexpIndex("handler")]), + endpoint: matches[templateNameMux.SubexpIndex("endpoint")], + fileSet: token.NewFileSet(), + statusCode: http.StatusOK, + } + if s := matches[templateNameMux.SubexpIndex("code")]; s != "" { + code, err := strconv.Atoi(strings.TrimSpace(s)) + if err != nil { + return TemplateName{}, fmt.Errorf("failed to parse status code: %w", err), true + } + p.statusCode = code } switch p.method { @@ -86,7 +96,7 @@ func newTemplate(in string) (TemplateName, error, bool) { var ( pathSegmentPattern = regexp.MustCompile(`/\{([^}]*)}`) - templateNameMux = regexp.MustCompile(`^(?P(((?P[A-Z]+)\s+)?)(?P([^/])*)(?P(/(\S)*)))(?P.*)$`) + templateNameMux = regexp.MustCompile(`^(?P(((?P[A-Z]+)\s+)?)(?P([^/])*)(?P(/(\S)*)))(?P\PL+.*\(.*\))?(?P\s\d+)?$`) ) func (def TemplateName) parsePathValueNames() ([]string, error) { diff --git a/name_internal_test.go b/name_internal_test.go index 621d22b..53521bd 100644 --- a/name_internal_test.go +++ b/name_internal_test.go @@ -244,6 +244,22 @@ func TestNewTemplateName(t *testing.T) { assert.ErrorContains(t, err, `forbidden repeated path parameter names: found at least 2 path parameters with name "name"`) }, }, + { + Name: "with status code", + In: "POST / 202", + ExpMatch: true, + TemplateName: func(t *testing.T, pat TemplateName) { + assert.Equal(t, http.StatusAccepted, pat.statusCode) + }, + }, + { + Name: "with code", + In: "POST /", + ExpMatch: true, + TemplateName: func(t *testing.T, pat TemplateName) { + assert.Equal(t, http.StatusOK, pat.statusCode) + }, + }, } { t.Run(tt.Name, func(t *testing.T) { pat, err, match := NewTemplateName(tt.In) diff --git a/name_test.go b/name_test.go index ac89f46..dbce63e 100644 --- a/name_test.go +++ b/name_test.go @@ -25,46 +25,52 @@ func TestTemplateNames(t *testing.T) { func TestPattern_parseHandler(t *testing.T) { for _, tt := range []struct { - Name string - In string - ExpErr string + Name string + In string + ExpErr string + ExpMatch bool }{ { - Name: "no arg post", - In: "GET / F()", + Name: "no arg post", + In: "GET / F()", + ExpMatch: true, }, { - Name: "no arg get", - In: "POST / F()", + Name: "no arg get", + In: "POST / F()", + ExpMatch: true, }, { - Name: "int as handler", - In: "POST / 1", - ExpErr: "expected call, got: 1", + Name: "float64 as handler", + In: "POST / 1.2", + ExpMatch: false, }, { - Name: "not an expression", - In: "GET / package main", - ExpErr: "failed to parse handler expression: ", + Name: "not an expression", + In: "GET / package main", + ExpMatch: false, }, { - Name: "function literal", - In: "GET / func() {} ()", - ExpErr: "expected function identifier", + Name: "function literal", + In: "GET / func() {} ()", + ExpMatch: true, + ExpErr: "expected function identifier", }, { - Name: "call ellipsis", - In: "GET /{fileName} F(fileName...)", - ExpErr: "unexpected ellipsis", + Name: "call ellipsis", + In: "GET /{fileName} F(fileName...)", + ExpMatch: true, + ExpErr: "unexpected ellipsis", }, } { t.Run(tt.Name, func(t *testing.T) { _, err, ok := muxt.NewTemplateName(tt.In) - require.True(t, ok) - if tt.ExpErr != "" { - assert.ErrorContains(t, err, tt.ExpErr) - } else { - assert.NoError(t, err) + if assert.Equal(t, tt.ExpMatch, ok) { + if tt.ExpErr != "" { + assert.ErrorContains(t, err, tt.ExpErr) + } else { + assert.NoError(t, err) + } } }) }