From daadcdf307bfa8d52c7f53c67136c67dc30c3c50 Mon Sep 17 00:00:00 2001
From: Chrstopher Hunter <8398225+crhntr@users.noreply.github.com>
Date: Fri, 30 Aug 2024 09:54:07 -0700
Subject: [PATCH] feat: override status codes
---
cmd/muxt/testdata/generate/form.txtar | 10 ++---
generate.go | 8 ++--
generate_test.go | 3 --
name.go | 26 +++++++++----
name_internal_test.go | 16 ++++++++
name_test.go | 54 +++++++++++++++------------
6 files changed, 73 insertions(+), 44 deletions(-)
diff --git a/cmd/muxt/testdata/generate/form.txtar b/cmd/muxt/testdata/generate/form.txtar
index c1f999d..29ef7f6 100644
--- a/cmd/muxt/testdata/generate/form.txtar
+++ b/cmd/muxt/testdata/generate/form.txtar
@@ -5,7 +5,7 @@ cat template_routes.go
exec go test -cover
-- template.gohtml --
-{{define "POST / Method(form)" }}{{end}}
+{{define "POST / Method(form) 201" }}{{end}}
-- go.mod --
module server
@@ -26,8 +26,8 @@ var formHTML embed.FS
var templates = template.Must(template.ParseFS(formHTML, "*"))
type Form struct {
- Count []int `json:"count"`
- Str string `name:"some-string" json:"str"`
+ Count []int `json:"count"`
+ Str string `name:"some-string" json:"str"`
}
type T struct {
@@ -78,8 +78,8 @@ func Test(t *testing.T) {
res := rec.Result()
- if res.StatusCode != http.StatusOK {
- t.Error("expected OK")
+ if res.StatusCode != http.StatusCreated {
+ t.Error("exp", http.StatusText(http.StatusCreated), "got", http.StatusText(res.StatusCode))
}
body, err := io.ReadAll(res.Body)
diff --git a/generate.go b/generate.go
index 1dac618..3f3058f 100644
--- a/generate.go
+++ b/generate.go
@@ -137,7 +137,7 @@ func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStm
func (def TemplateName) funcLit(method *ast.FuncType, files []*ast.File) (*ast.FuncLit, []*ast.ImportSpec, error) {
if method == nil {
- return def.httpRequestReceiverTemplateHandlerFunc(), nil, nil
+ return def.httpRequestReceiverTemplateHandlerFunc(def.statusCode), nil, nil
}
lit := &ast.FuncLit{
Type: httpHandlerFuncType(),
@@ -340,7 +340,7 @@ func (def TemplateName) funcLit(method *ast.FuncType, files []*ast.File) (*ast.F
} else {
lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}})
}
- lit.Body.List = append(lit.Body.List, def.executeCall(source.HTTPStatusCode(httpPackageIdent, http.StatusOK), ast.NewIdent(dataVarIdent), writeHeader))
+ lit.Body.List = append(lit.Body.List, def.executeCall(source.HTTPStatusCode(httpPackageIdent, def.statusCode), ast.NewIdent(dataVarIdent), writeHeader))
return lit, imports, nil
}
@@ -855,10 +855,10 @@ func (def TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *as
}}
}
-func (def TemplateName) httpRequestReceiverTemplateHandlerFunc() *ast.FuncLit {
+func (def TemplateName) httpRequestReceiverTemplateHandlerFunc(statusCode int) *ast.FuncLit {
return &ast.FuncLit{
Type: httpHandlerFuncType(),
- Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(source.HTTPStatusCode(httpPackageIdent, http.StatusOK), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}},
+ Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(source.HTTPStatusCode(httpPackageIdent, statusCode), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}},
}
}
diff --git a/generate_test.go b/generate_test.go
index b6337a2..6898882 100644
--- a/generate_test.go
+++ b/generate_test.go
@@ -1462,9 +1462,6 @@ func execute(response http.ResponseWriter, request *http.Request, writeHeader bo
if tt.ExpectedError == "" {
assert.NoError(t, err)
assert.Equal(t, tt.ExpectedFile, out)
- if t.Failed() {
- t.Logf(out)
- }
} else {
assert.ErrorContains(t, err, tt.ExpectedError)
}
diff --git a/name.go b/name.go
index a92d9b9..7e0d1b8 100644
--- a/name.go
+++ b/name.go
@@ -10,6 +10,7 @@ import (
"net/http"
"regexp"
"slices"
+ "strconv"
"strings"
"github.com/crhntr/muxt/internal/source"
@@ -40,6 +41,7 @@ type TemplateName struct {
name string
method, host, path, endpoint string
handler string
+ statusCode int
fun *ast.Ident
call *ast.CallExpr
@@ -57,13 +59,21 @@ func newTemplate(in string) (TemplateName, error, bool) {
}
matches := templateNameMux.FindStringSubmatch(in)
p := TemplateName{
- name: in,
- method: matches[templateNameMux.SubexpIndex("method")],
- host: matches[templateNameMux.SubexpIndex("host")],
- path: matches[templateNameMux.SubexpIndex("path")],
- handler: strings.TrimSpace(matches[templateNameMux.SubexpIndex("handler")]),
- endpoint: matches[templateNameMux.SubexpIndex("endpoint")],
- fileSet: token.NewFileSet(),
+ name: in,
+ method: matches[templateNameMux.SubexpIndex("method")],
+ host: matches[templateNameMux.SubexpIndex("host")],
+ path: matches[templateNameMux.SubexpIndex("path")],
+ handler: strings.TrimSpace(matches[templateNameMux.SubexpIndex("handler")]),
+ endpoint: matches[templateNameMux.SubexpIndex("endpoint")],
+ fileSet: token.NewFileSet(),
+ statusCode: http.StatusOK,
+ }
+ if s := matches[templateNameMux.SubexpIndex("code")]; s != "" {
+ code, err := strconv.Atoi(strings.TrimSpace(s))
+ if err != nil {
+ return TemplateName{}, fmt.Errorf("failed to parse status code: %w", err), true
+ }
+ p.statusCode = code
}
switch p.method {
@@ -86,7 +96,7 @@ func newTemplate(in string) (TemplateName, error, bool) {
var (
pathSegmentPattern = regexp.MustCompile(`/\{([^}]*)}`)
- templateNameMux = regexp.MustCompile(`^(?P(((?P[A-Z]+)\s+)?)(?P([^/])*)(?P(/(\S)*)))(?P.*)$`)
+ templateNameMux = regexp.MustCompile(`^(?P(((?P[A-Z]+)\s+)?)(?P([^/])*)(?P(/(\S)*)))(?P\PL+.*\(.*\))?(?P\s\d+)?$`)
)
func (def TemplateName) parsePathValueNames() ([]string, error) {
diff --git a/name_internal_test.go b/name_internal_test.go
index 621d22b..53521bd 100644
--- a/name_internal_test.go
+++ b/name_internal_test.go
@@ -244,6 +244,22 @@ func TestNewTemplateName(t *testing.T) {
assert.ErrorContains(t, err, `forbidden repeated path parameter names: found at least 2 path parameters with name "name"`)
},
},
+ {
+ Name: "with status code",
+ In: "POST / 202",
+ ExpMatch: true,
+ TemplateName: func(t *testing.T, pat TemplateName) {
+ assert.Equal(t, http.StatusAccepted, pat.statusCode)
+ },
+ },
+ {
+ Name: "with code",
+ In: "POST /",
+ ExpMatch: true,
+ TemplateName: func(t *testing.T, pat TemplateName) {
+ assert.Equal(t, http.StatusOK, pat.statusCode)
+ },
+ },
} {
t.Run(tt.Name, func(t *testing.T) {
pat, err, match := NewTemplateName(tt.In)
diff --git a/name_test.go b/name_test.go
index ac89f46..dbce63e 100644
--- a/name_test.go
+++ b/name_test.go
@@ -25,46 +25,52 @@ func TestTemplateNames(t *testing.T) {
func TestPattern_parseHandler(t *testing.T) {
for _, tt := range []struct {
- Name string
- In string
- ExpErr string
+ Name string
+ In string
+ ExpErr string
+ ExpMatch bool
}{
{
- Name: "no arg post",
- In: "GET / F()",
+ Name: "no arg post",
+ In: "GET / F()",
+ ExpMatch: true,
},
{
- Name: "no arg get",
- In: "POST / F()",
+ Name: "no arg get",
+ In: "POST / F()",
+ ExpMatch: true,
},
{
- Name: "int as handler",
- In: "POST / 1",
- ExpErr: "expected call, got: 1",
+ Name: "float64 as handler",
+ In: "POST / 1.2",
+ ExpMatch: false,
},
{
- Name: "not an expression",
- In: "GET / package main",
- ExpErr: "failed to parse handler expression: ",
+ Name: "not an expression",
+ In: "GET / package main",
+ ExpMatch: false,
},
{
- Name: "function literal",
- In: "GET / func() {} ()",
- ExpErr: "expected function identifier",
+ Name: "function literal",
+ In: "GET / func() {} ()",
+ ExpMatch: true,
+ ExpErr: "expected function identifier",
},
{
- Name: "call ellipsis",
- In: "GET /{fileName} F(fileName...)",
- ExpErr: "unexpected ellipsis",
+ Name: "call ellipsis",
+ In: "GET /{fileName} F(fileName...)",
+ ExpMatch: true,
+ ExpErr: "unexpected ellipsis",
},
} {
t.Run(tt.Name, func(t *testing.T) {
_, err, ok := muxt.NewTemplateName(tt.In)
- require.True(t, ok)
- if tt.ExpErr != "" {
- assert.ErrorContains(t, err, tt.ExpErr)
- } else {
- assert.NoError(t, err)
+ if assert.Equal(t, tt.ExpMatch, ok) {
+ if tt.ExpErr != "" {
+ assert.ErrorContains(t, err, tt.ExpErr)
+ } else {
+ assert.NoError(t, err)
+ }
}
})
}