diff --git a/README.md b/README.md
index acb3b0a..4a511eb 100644
--- a/README.md
+++ b/README.md
@@ -2,16 +2,35 @@
This is especially helpful when you are writing HTMX.
-## Example
+Given the following files
-The "define" blocks in the following template register handlers with the server mux.
+### main.go
-The http method, http host, and path semantics match those of in the HTTP package.
+```go
+package main
-This library extends this to add custom data handler invocations see "PATCH /fruits/{fruit}". It is configured to call EditRow on template parse time provided receiver.
+import (
+ "html/template"
+ "log"
+ "net/http"
+)
-When no handler method is specified in the "declare" string (as is the case with "GET /fruits/{fruit}/edit" in the example), the template receives the *http.Request.
+//go:embed *.gohtml
+var templateSource embed.FS
+var templates = template.Must(template.ParseFS(templateSource, "*"))
+
+type Backend struct {}
+
+func main() {
+ mux := http.NewServeMux()
+ muxt := Routes(mux, templates)
+
+
+}
+```
+
+### index.gohtml
```html
@@ -36,7 +55,6 @@ When no handler method is specified in the "declare" string (as is the case with
- {{- range . -}}
{{- block "fruit row" . -}}
{{ .Fruit }}
@@ -45,9 +63,11 @@ When no handler method is specified in the "declare" string (as is the case with
{{- end -}}
+
+ {{- define "GET /{} List(ctx)" -}}
+ {{template "index.gohtml" .}}
{{- end -}}
-
{{- define "GET /fruits/{fruit}/edit" -}}
{{ .PathValue "fruit" }}
diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go
index 9b59293..0a6d62f 100644
--- a/cmd/muxt/generate.go
+++ b/cmd/muxt/generate.go
@@ -27,6 +27,7 @@ type Generate struct {
templatesVariable string
outputFilename string
routesFunction string
+ receiverIdent string
}
func (g Generate) ImportReceiverMethods(tp, method string) (*ast.FuncType, []*ast.ImportSpec, bool) {
@@ -43,9 +44,11 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
flagSet.StringVar(&config.templatesVariable, "templates-variable", muxt.DefaultTemplatesVariableName, "templates variable name")
flagSet.StringVar(&config.outputFilename, "output-file", "template_routes.go", "file name of generated output")
flagSet.StringVar(&config.routesFunction, "routes-func", muxt.DefaultRoutesFunctionName, "file name of generated output")
+ flagSet.StringVar(&config.receiverIdent, "receiver", "", "static receiver type identifier")
if err := flagSet.Parse(args); err != nil {
return err
}
+ _ = os.Remove(filepath.Join(workingDirectory, config.outputFilename))
list, err := packages.Load(&packages.Config{
Mode: packages.NeedFiles | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles,
Dir: workingDirectory,
@@ -72,7 +75,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
return err
}
out := log.New(stdout, "", 0)
- s, err := muxt.Generate(patterns, config.goPackage, config.templatesVariable, config.routesFunction, "", config.Package.Fset, config.Package.Syntax, config.Package.Syntax, out)
+ s, err := muxt.Generate(patterns, config.goPackage, config.templatesVariable, config.routesFunction, config.receiverIdent, config.Package.Fset, config.Package.Syntax, config.Package.Syntax, out)
if err != nil {
return err
}
diff --git a/cmd/muxt/main.go b/cmd/muxt/main.go
index bb407b9..32275ed 100644
--- a/cmd/muxt/main.go
+++ b/cmd/muxt/main.go
@@ -28,7 +28,7 @@ func command(wd string, args []string, getEnv func(string) string, stdout, stder
func handleError(err error) int {
if err != nil {
- _, _ = os.Stderr.WriteString(err.Error())
+ _, _ = os.Stderr.WriteString(err.Error() + "\n")
return 1
}
return 0
diff --git a/example/index.gohtml b/example/index.gohtml
new file mode 100644
index 0000000..3a34dfa
--- /dev/null
+++ b/example/index.gohtml
@@ -0,0 +1,76 @@
+
+
+{{block "head" "example"}}
+
+
+ {{.}}
+
+
+
+
+
+{{end}}
+
+
+
+
+
+ Fruit
+ Count
+
+
+
+
+ {{- define "fruit row" -}}
+
+ {{ .Name }}
+ {{ .Value }}
+
+ {{- end -}}
+
+ {{range .}}
+ {{template "fruit row" .}}
+ {{end}}
+
+ {{- define "GET /{$} List(ctx)" -}}
+ {{template "index.gohtml" .}}
+ {{- end -}}
+
+ {{- define "GET /fruits/{fruit}/edit GetFormEditRow(fruit)" -}}
+
+ {{ .Row.Name }}
+
+
+ {{.Error}}
+
+
+ {{- end -}}
+
+ {{- define "PATCH /fruits/{fruit} SubmitFormEditRow(request, fruit)" }}
+ {{- if .Error -}}
+ {{template "GET /fruits/{fruit}/edit GetFormEditRow(fruit)" .}}
+ {{- else -}}
+ {{template "fruit row" .Row}}
+ {{- end -}}
+ {{ end -}}
+
+
+
+
+
+
+
+{{define "GET /help"}}
+
+
+{{template "head" "Help"}}
+
+
+ Hello, help!
+
+
+
+{{end}}
\ No newline at end of file
diff --git a/example/main.go b/example/main.go
new file mode 100644
index 0000000..6fe186b
--- /dev/null
+++ b/example/main.go
@@ -0,0 +1,70 @@
+package main
+
+import (
+ "context"
+ "embed"
+ "fmt"
+ "html/template"
+ "log"
+ "net/http"
+ "strconv"
+)
+
+//go:embed *.gohtml
+var templateSource embed.FS
+
+var templates = template.Must(template.ParseFS(templateSource, "*"))
+
+type Backend struct {
+ data []Row
+}
+
+type EditRowPage struct {
+ Row Row
+ Error error
+}
+
+func (b *Backend) SubmitFormEditRow(request *http.Request, fruit string) EditRowPage {
+ count, err := strconv.Atoi(request.FormValue("count"))
+ if err != nil {
+ return EditRowPage{Error: err, Row: Row{Name: fruit}}
+ }
+ for i := range b.data {
+ if b.data[i].Name == fruit {
+ b.data[i].Value = count
+ return EditRowPage{Error: nil, Row: b.data[i]}
+ }
+ }
+ return EditRowPage{Error: fmt.Errorf("fruit not found")}
+}
+
+func (b *Backend) GetFormEditRow(fruit string) EditRowPage {
+ for i := range b.data {
+ if b.data[i].Name == fruit {
+ return EditRowPage{Error: nil, Row: b.data[i]}
+ }
+ }
+ return EditRowPage{Error: fmt.Errorf("fruit not found")}
+}
+
+type Row struct {
+ Name string
+ Value int
+}
+
+func (b *Backend) List(_ context.Context) []Row { return b.data }
+
+//go:generate muxt generate --receiver Backend
+
+func main() {
+ backend := &Backend{
+ data: []Row{
+ {Name: "Peach", Value: 10},
+ {Name: "Plum", Value: 20},
+ {Name: "Pineapple", Value: 2},
+ },
+ }
+ mux := http.NewServeMux()
+ Routes(mux, backend)
+ log.Fatal(http.ListenAndServe(":8080", mux))
+}
diff --git a/example/template_routes.go b/example/template_routes.go
new file mode 100644
index 0000000..a7ba808
--- /dev/null
+++ b/example/template_routes.go
@@ -0,0 +1,46 @@
+// Code generated by muxt. DO NOT EDIT.
+
+package main
+
+import (
+ "context"
+ "net/http"
+ "bytes"
+ "html/template"
+)
+
+type RoutesReceiver interface {
+ SubmitFormEditRow(request *http.Request, fruit string) EditRowPage
+ GetFormEditRow(fruit string) EditRowPage
+ List(_ context.Context) []Row
+}
+
+func Routes(mux *http.ServeMux, receiver RoutesReceiver) {
+ mux.HandleFunc("PATCH /fruits/{fruit}", func(response http.ResponseWriter, request *http.Request) {
+ fruit := request.PathValue("fruit")
+ data := receiver.SubmitFormEditRow(request, fruit)
+ execute(response, request, templates.Lookup("PATCH /fruits/{fruit} SubmitFormEditRow(request, fruit)"), http.StatusOK, data)
+ })
+ mux.HandleFunc("GET /fruits/{fruit}/edit", func(response http.ResponseWriter, request *http.Request) {
+ fruit := request.PathValue("fruit")
+ data := receiver.GetFormEditRow(fruit)
+ execute(response, request, templates.Lookup("GET /fruits/{fruit}/edit GetFormEditRow(fruit)"), http.StatusOK, data)
+ })
+ mux.HandleFunc("GET /help", func(response http.ResponseWriter, request *http.Request) {
+ execute(response, request, templates.Lookup("GET /help"), http.StatusOK, request)
+ })
+ mux.HandleFunc("GET /{$}", func(response http.ResponseWriter, request *http.Request) {
+ ctx := request.Context()
+ data := receiver.List(ctx)
+ execute(response, request, templates.Lookup("GET /{$} List(ctx)"), http.StatusOK, data)
+ })
+}
+func execute(response http.ResponseWriter, request *http.Request, t *template.Template, code int, data any) {
+ buf := bytes.NewBuffer(nil)
+ if err := t.Execute(buf, data); err != nil {
+ http.Error(response, err.Error(), http.StatusOK)
+ return
+ }
+ response.WriteHeader(code)
+ _, _ = buf.WriteTo(response)
+}
diff --git a/generate.go b/generate.go
index 35630b4..134e8f9 100644
--- a/generate.go
+++ b/generate.go
@@ -77,11 +77,11 @@ func Generate(patterns []Pattern, packageName, templatesVariableName, routesFunc
Type: method,
})
}
- handlerFunc, handerSignatureImports, err := pattern.funcLit(templatesVariableName, method)
+ handlerFunc, methodImports, err := pattern.funcLit(templatesVariableName, method)
if err != nil {
return "", err
}
- imports = source.SortImports(append(imports, handerSignatureImports...))
+ imports = source.SortImports(append(imports, methodImports...))
routes.Body.List = append(routes.Body.List, pattern.callHandleFunc(handlerFunc))
log.Printf("%s has route for %s", routesFunctionName, pattern.String())
}
@@ -163,12 +163,12 @@ func (def Pattern) funcLit(templatesVariableIdent string, method *ast.FuncType)
}
}
+ const dataVarIdent = "data"
if method != nil && len(method.Results.List) > 1 {
- dataVar := ast.NewIdent(dataVarIdent)
errVar := ast.NewIdent("err")
lit.Body.List = append(lit.Body.List,
- &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVar.Name), ast.NewIdent(errVar.Name)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}},
+ &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errVar.Name)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}},
&ast.IfStmt{
Cond: &ast.BinaryExpr{X: ast.NewIdent(errVar.Name), Op: token.NEQ, Y: ast.NewIdent("nil")},
Body: &ast.BlockStmt{
@@ -194,11 +194,10 @@ func (def Pattern) funcLit(templatesVariableIdent string, method *ast.FuncType)
},
},
)
+ } else {
+ lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}})
}
-
- data := ast.NewIdent(dataVarIdent)
- lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(data.Name)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}})
- lit.Body.List = append(lit.Body.List, def.executeCall(ast.NewIdent(templatesVariableIdent), httpStatusCode(httpStatusCode200Ident), data))
+ lit.Body.List = append(lit.Body.List, def.executeCall(ast.NewIdent(templatesVariableIdent), httpStatusCode(httpStatusCode200Ident), ast.NewIdent(dataVarIdent)))
return lit, imports, nil
}
@@ -422,7 +421,16 @@ func (def Pattern) httpRequestReceiverTemplateHandlerFunc(templatesVariableName
}
func (def Pattern) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool {
- return funcDecl.Name.Name == def.fun.Name && funcDecl.Recv != nil && len(funcDecl.Recv.List) == 1 && funcDecl.Recv.List[0].Type.(*ast.Ident).Name == receiverTypeIdent
+ if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != def.fun.Name || funcDecl.Recv == nil && len(funcDecl.Recv.List) < 1 {
+ return false
+ }
+ exp := funcDecl.Recv.List[0].Type
+ if star, ok := exp.(*ast.StarExpr); ok {
+ exp = star.X
+ }
+ ident, ok := exp.(*ast.Ident)
+ return ok && ident.Name == receiverTypeIdent
+
}
func executeFuncDecl() *ast.FuncDecl {
diff --git a/reflect.go b/reflect.go
deleted file mode 100644
index 9b61299..0000000
--- a/reflect.go
+++ /dev/null
@@ -1,277 +0,0 @@
-package muxt
-
-import (
- "bytes"
- "context"
- "fmt"
- "go/ast"
- "html/template"
- "io"
- "log/slog"
- "net/http"
- "reflect"
- "sync"
-)
-
-type Options struct {
- logger *slog.Logger
- receiver any
- execute ExecuteFunc[any]
- error ExecuteFunc[error]
-}
-
-func newOptions() Options {
- return Options{
- logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{
- Level: slog.LevelError,
- })),
- receiver: nil,
- execute: defaultExecute,
- error: internalServerErrorErrorFunc,
- }
-}
-
-func WithStructuredLogger(log *slog.Logger) Options { return newOptions().WithStructuredLogger(log) }
-func WithReceiver(r any) Options { return newOptions().WithReceiver(r) }
-func WithDataFunc(ex ExecuteFunc[any]) Options { return newOptions().WithDataFunc(ex) }
-func WithErrorFunc(ex ExecuteFunc[error]) Options { return newOptions().WithErrorFunc(ex) }
-func WithNoopErrorFunc() Options { return newOptions().WithNoopErrorFunc() }
-func With500ErrorFunc() Options { return newOptions().With500ErrorFunc() }
-
-func (o Options) WithStructuredLogger(log *slog.Logger) Options {
- o.logger = log
- return o
-}
-
-func (o Options) WithReceiver(r any) Options {
- o.receiver = r
- return o
-}
-
-func (o Options) WithDataFunc(ex ExecuteFunc[any]) Options {
- o.execute = ex
- return o
-}
-
-func (o Options) WithErrorFunc(ex ExecuteFunc[error]) Options {
- o.error = ex
- return o
-}
-
-func (o Options) WithNoopErrorFunc() Options {
- o.error = noopErrorFunc
- return o
-}
-
-func (o Options) With500ErrorFunc() Options {
- o.error = internalServerErrorErrorFunc
- return o
-}
-
-func noopErrorFunc(http.ResponseWriter, *http.Request, *template.Template, *slog.Logger, error) {}
-
-func internalServerErrorErrorFunc(res http.ResponseWriter, _ *http.Request, t *template.Template, logger *slog.Logger, err error) {
- logger.Error("handler error", "error", err, "template", t.Name())
- http.Error(res, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
-}
-
-func applyOptions(options []Options) *Options {
- result := newOptions()
- for _, o := range options {
- if o.logger != nil {
- result.logger = o.logger
- }
- if o.receiver != nil {
- result.receiver = o.receiver
- }
- if o.execute != nil {
- result.execute = o.execute
- }
- if o.error != nil {
- result.error = o.error
- }
- }
- return &result
-}
-
-func Handlers(mux *http.ServeMux, ts *template.Template, options ...Options) error {
- o := applyOptions(options)
- patterns, err := TemplatePatterns(ts)
- if err != nil {
- return err
- }
- for _, pat := range patterns {
- t := ts.Lookup(pat.String())
- if pat.Handler == "" {
- mux.HandleFunc(pat.Route, simpleTemplateHandler(o.execute, t, o.logger))
- continue
- }
- h, err := newReflectHandlerFunc(o, t, pat)
- if err != nil {
- return fmt.Errorf("failed to create handler for %q: %w", pat.String(), err)
- }
- mux.HandleFunc(pat.Route, h)
- }
- return nil
-}
-
-func newReflectHandlerFunc(o *Options, t *template.Template, pat Pattern) (http.HandlerFunc, error) {
- m, err := serviceMethod(o, pat.FunIdent())
- if err != nil {
- return nil, err
- }
- inputs, err := generateInputsFunction(o, t, m.Type(), pat)
- if err != nil {
- return nil, err
- }
- return generateOutputsFunction(o, t, m, inputs)
-}
-
-type inputsFunc = func(res http.ResponseWriter, req *http.Request) []reflect.Value
-
-func generateOutputsFunction(o *Options, t *template.Template, method reflect.Value, inputs inputsFunc) (http.HandlerFunc, error) {
- methodType := method.Type()
- switch num := methodType.NumOut(); num {
- case 1:
- return valueResultHandler(o, t, method, inputs), nil
- case 2:
- if !methodType.Out(1).AssignableTo(reflect.TypeFor[error]()) {
- return nil, fmt.Errorf("the second result must be an error")
- }
- return valuesResultHandler(o, t, method, inputs), nil
- default:
- return nil, fmt.Errorf("method must either return (T) or (T, error)")
- }
-}
-
-func valueResultHandler(o *Options, t *template.Template, method reflect.Value, inputs inputsFunc) http.HandlerFunc {
- return func(res http.ResponseWriter, req *http.Request) {
- in := inputs(res, req)
- out := method.Call(in)
- o.execute(res, req, t, o.logger, out[0].Interface())
- }
-}
-
-func valuesResultHandler(o *Options, t *template.Template, method reflect.Value, inputs inputsFunc) http.HandlerFunc {
- return func(res http.ResponseWriter, req *http.Request) {
- in := inputs(res, req)
- out := method.Call(in)
- callRes, callErr := out[0], out[1]
- if !callErr.IsNil() {
- err := callErr.Interface().(error)
- o.error(res, req, t, o.logger, err)
- return
- }
- o.execute(res, req, t, o.logger, callRes.Interface())
- }
-}
-
-func serviceMethod(o *Options, method *ast.Ident) (reflect.Value, error) {
- if o.receiver == nil {
- return reflect.Value{}, fmt.Errorf("receiver is nil")
- }
- r := reflect.ValueOf(o.receiver)
- m := r.MethodByName(method.Name)
- if !m.IsValid() {
- return reflect.Value{}, fmt.Errorf("method %s not found on %s", method.Name, r.Type())
- }
- return m, nil
-}
-
-func generateInputsFunction(o *Options, t *template.Template, method reflect.Type, pat Pattern) (inputsFunc, error) {
- if method.NumIn() != len(pat.ArgIdents()) {
- return nil, fmt.Errorf("wrong number of arguments")
- }
- if len(pat.ArgIdents()) == 0 {
- return func(http.ResponseWriter, *http.Request) []reflect.Value {
- return nil
- }, nil
- }
- var args []string
- for i, argIdent := range pat.ArgIdents() {
- arg, err := typeCheckMethodParameters(method.In(i), argIdent)
- if err != nil {
- return nil, fmt.Errorf("method argument at index %d: %w", i, err)
- }
- args = append(args, arg)
- }
- return func(res http.ResponseWriter, req *http.Request) []reflect.Value {
- var in []reflect.Value
- for _, arg := range args {
- switch arg {
- case PatternScopeIdentifierHTTPResponse:
- in = append(in, reflect.ValueOf(res))
- case PatternScopeIdentifierHTTPRequest:
- in = append(in, reflect.ValueOf(req))
- case PatternScopeIdentifierContext:
- in = append(in, reflect.ValueOf(req.Context()))
- case PatternScopeIdentifierLogger:
- in = append(in, reflect.ValueOf(o.logger))
- case PatternScopeIdentifierTemplate:
- in = append(in, reflect.ValueOf(t))
- default:
- in = append(in, reflect.ValueOf(req.PathValue(arg)))
- }
- }
- return in
- }, nil
-}
-
-var argumentType = sync.OnceValue(func() func(argName string) (reflect.Type, error) {
- requestType := reflect.TypeFor[*http.Request]()
- contextType := reflect.TypeFor[context.Context]()
- responseType := reflect.TypeFor[http.ResponseWriter]()
- loggerType := reflect.TypeFor[*slog.Logger]()
- templateType := reflect.TypeFor[*template.Template]()
- stringType := reflect.TypeFor[string]()
- return func(argName string) (reflect.Type, error) {
- var argType reflect.Type
- switch argName {
- case PatternScopeIdentifierHTTPRequest:
- argType = requestType
- case PatternScopeIdentifierContext:
- argType = contextType
- case PatternScopeIdentifierHTTPResponse:
- argType = responseType
- case PatternScopeIdentifierLogger:
- argType = loggerType
- case PatternScopeIdentifierTemplate:
- argType = templateType
- default:
- argType = stringType
- }
- return argType, nil
- }
-})
-
-func typeCheckMethodParameters(paramType reflect.Type, arg *ast.Ident) (string, error) {
- argType, err := argumentType()(arg.Name)
- if err != nil {
- return arg.Name, err
- }
- if !argType.AssignableTo(paramType) {
- return arg.Name, fmt.Errorf("argument %s %s is not assignable to parameter type %s", arg.Name, argType, paramType)
- }
- return arg.Name, nil
-}
-
-func simpleTemplateHandler(ex ExecuteFunc[any], t *template.Template, logger *slog.Logger) http.HandlerFunc {
- return func(res http.ResponseWriter, req *http.Request) {
- ex(res, req, t, logger, req)
- }
-}
-
-type ExecuteFunc[T any] func(http.ResponseWriter, *http.Request, *template.Template, *slog.Logger, T)
-
-func defaultExecute(res http.ResponseWriter, req *http.Request, t *template.Template, logger *slog.Logger, data any) {
- var buf bytes.Buffer
- if err := t.Execute(&buf, data); err != nil {
- logger.Error("failed to render page", "method", req.Method, "path", req.URL.Path, "error", err)
- http.Error(res, "failed to render page", http.StatusInternalServerError)
- return
- }
- if _, err := buf.WriteTo(res); err != nil {
- logger.Error("failed to write full response", "method", req.Method, "path", req.URL.Path, "error", err)
- return
- }
-}
diff --git a/reflect_test.go b/reflect_test.go
deleted file mode 100644
index 33e98ba..0000000
--- a/reflect_test.go
+++ /dev/null
@@ -1,593 +0,0 @@
-package muxt_test
-
-import (
- "bytes"
- "context"
- "fmt"
- "html/template"
- "io"
- "log/slog"
- "net/http"
- "net/http/httptest"
- "strconv"
- "strings"
- "testing"
-
- "github.com/crhntr/dom/domtest"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "golang.org/x/net/html/atom"
-
- "github.com/crhntr/muxt"
- "github.com/crhntr/muxt/internal/example"
- "github.com/crhntr/muxt/internal/fake"
-)
-
-//go:generate counterfeiter -generate
-//counterfeiter:generate -o ./internal/fake/receiver.go --fake-name Receiver . receiver
-var _ receiver = (*fake.Receiver)(nil)
-
-//counterfeiter:generate -o ./internal/fake/response_writer.go --fake-name ResponseWriter net/http.ResponseWriter
-
-type (
- receiver interface {
- ListArticles(ctx context.Context) ([]example.Article, error)
- ToUpper(in ...rune) string
- Parse(string) []string
- GetComment(ctx context.Context, articleID, commentID int) (string, error)
- SomeString(ctx context.Context, x string) (string, error)
- TooManyResults() (int, int, int)
- NumAuthors() int
- CheckAuth(req *http.Request) (string, error)
- Handler(http.ResponseWriter, *http.Request) template.HTML
- ErrorHandler(http.ResponseWriter, *http.Request) (template.HTML, error)
- LogLines(*slog.Logger) int
- Template(*template.Template) template.HTML
- Type(any) string
- Tuple() (string, string)
- }
-)
-
-func TestRoutes(t *testing.T) {
- t.Run("GET index", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(
- /* language=gotemplate */
- `{{define "GET /" }}Hello, friend! {{end}}`,
- ))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts)
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
-
- res := rec.Result()
-
- assert.Equal(t, http.StatusOK, res.StatusCode)
- })
-
- t.Run("when a handler is registered", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(
- /* language=gotemplate */
- `{{define "GET /articles ListArticles(ctx)" }}{{range .}}{{.Title}} {{end}} {{end}}`,
- ))
- as := new(fake.Receiver)
- articles := []example.Article{
- {ID: 1, Title: "Hello"},
- {ID: 2, Title: "Goodbye"},
- }
- as.ListArticlesReturns(articles, nil)
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(as))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/articles", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
-
- assert.Equal(t, 1, as.ListArticlesCallCount())
-
- res := rec.Result()
-
- assert.Equal(t, http.StatusOK, res.StatusCode)
- fragment := domtest.DocumentFragmentResponse(t, res, atom.Body)
- listItems := fragment.QuerySelectorAll(`[data-id]`)
- assert.Equal(t, len(articles), listItems.Length())
- for i := 0; i < listItems.Length(); i++ {
- li := listItems.Item(i)
- assert.Equal(t, articles[i].Title, li.TextContent())
- assert.Equal(t, strconv.Itoa(articles[i].ID), li.GetAttribute("data-id"))
- }
- })
-
- t.Run("unexpected method", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(`{{define "CONNECT /articles" }}{{.}}{{end}}`))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts)
- require.ErrorContains(t, err, `CONNECT method not allowed`)
- })
-
- t.Run("no method", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(`{{define "/x/y" }}{{.Method}}{{end}}`))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts)
- require.NoError(t, err)
-
- for _, method := range []string{http.MethodGet, http.MethodPost} {
- req := httptest.NewRequest(method, "/x/y", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
- res := rec.Result()
-
- assert.Equal(t, http.StatusOK, res.StatusCode)
- body, _ := io.ReadAll(res.Body)
- assert.Equal(t, method, string(body))
- }
- })
-
- t.Run("selector must be an expression", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / var x int" }}{{.}}{{end}}`))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(new(fake.Receiver)))
- require.ErrorContains(t, err, "failed to parse handler expression")
- })
-
- t.Run("function must be an identifier", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / func().Method(request)" }}{{.}}{{end}}`))
- mux := http.NewServeMux()
- rec := new(fake.Receiver)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(rec))
- require.ErrorContains(t, err, `expected function identifier`)
- })
-
- t.Run("receiver is nil and a method is expected", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / Method(request)" }}{{.}}{{end}}`))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts)
- require.ErrorContains(t, err, "receiver is nil")
- })
-
- t.Run("method not found on basic type", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / Foo(request)" }}{{.}}{{end}}`))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(100))
- require.ErrorContains(t, err, `method Foo not found on int`)
- })
-
- t.Run("ellipsis not allowed", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /{name} ToUpper(name...)" }}{{.}}{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.ErrorContains(t, err, `unexpected ellipsis`)
- })
-
- t.Run("duplicate path param identifier", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /articles/{id}/comment/{id} GetComment(ctx, id, id)" }}{{.}}{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.ErrorContains(t, err, `path parameter id defined at least twice`)
- })
-
- t.Run("path param is not an identifier ", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /{key-id} SomeString(ctx, key-id)"}}KEY{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.ErrorContains(t, err, `path parameter name not permitted: "key-id" is not a Go identifier`)
- })
-
- for _, name := range []string{
- "request",
- "ctx",
- "response",
- "logger",
- "template",
- } {
- t.Run(name+" can not be used as a path parameter identifier", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(fmt.Sprintf(`{{define "GET /{%[1]s} Type(%[1]s)"}}{{.}}{{end}}`, name)))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.ErrorContains(t, err, fmt.Sprintf(`the name %s is not allowed as a path paramenter it is alredy in scop`, name))
- })
-
- t.Run(name+" can be used when no handler is defined", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(fmt.Sprintf(`{{define "GET /{%[1]s}"}}{{.}}{{end}}`, name)))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.NoError(t, err)
- })
- }
-
- t.Run("template execution fails", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Funcs(template.FuncMap{
- "errorNow": func() (string, error) { return "", fmt.Errorf("BANANA") },
- }).Parse(`{{define "GET / ListArticles(ctx)"}}{{ errorNow }}{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
- res := rec.Result()
- assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
- })
-
- t.Run("write fails", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /"}}{{printf "%d" 199}}{{end}}`))
- logBuffer := bytes.NewBuffer(nil)
- logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{
- Level: slog.LevelDebug,
- }))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts, muxt.WithStructuredLogger(logger))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(errorWriter{ResponseWriter: rec}, req)
- res := rec.Result()
- assert.Equal(t, http.StatusOK, res.StatusCode)
- assert.Contains(t, logBuffer.String(), "failed to write full response")
- })
-
- t.Run("too many results", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / TooManyResults()" }}{{.}}{{end}}"`))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(new(fake.Receiver)))
- require.ErrorContains(t, err, `method must either return (T) or (T, error)`)
- })
-
- t.Run("call fails", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /number-of-articles ListArticles(ctx)"}}{{len .}}{{end}}`))
- logBuffer := bytes.NewBuffer(nil)
- logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{
- Level: slog.LevelDebug,
- }))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- s.ListArticlesReturns(nil, fmt.Errorf("banana"))
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s).WithStructuredLogger(logger))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/number-of-articles", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
- res := rec.Result()
- assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
-
- body, _ := io.ReadAll(res.Body)
- assert.NotContains(t, string(body), "banana")
-
- assert.Contains(t, logBuffer.String(), "banana")
- })
-
- t.Run("not a function call", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /number-of-articles <-c"}}{{.}}{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.ErrorContains(t, err, "expected call")
- })
-
- t.Run("single return", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /number-of-authors NumAuthors()"}}{{.}}{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- s.NumAuthorsReturns(234)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/number-of-authors", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
- res := rec.Result()
- assert.Equal(t, http.StatusOK, res.StatusCode)
-
- body, _ := io.ReadAll(res.Body)
-
- assert.Equal(t, string(body), "234")
- })
-
- t.Run("request as a parameter", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /auth CheckAuth(request)"}}OK{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- s.NumAuthorsReturns(234)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/auth", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
- res := rec.Result()
- assert.Equal(t, http.StatusOK, res.StatusCode)
-
- body, _ := io.ReadAll(res.Body)
-
- assert.Equal(t, string(body), "OK")
- })
-
- t.Run("non identifier params", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /site-owner GetComment(ctx, 3, 1+2)"}}OK{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- s.NumAuthorsReturns(234)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.ErrorContains(t, err, `expected only argument expressions as arguments, argument at index 1 is: 3`)
- })
-
- t.Run("query param", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET /input/{in} Parse(in)"}}{{.}}{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- s.NumAuthorsReturns(234)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/input/peach", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
-
- v := s.ParseArgsForCall(0)
- assert.Equal(t, "peach", v)
- })
-
- t.Run("unknown identifier", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / Parse(enemy)"}}@{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- s.NumAuthorsReturns(234)
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.ErrorContains(t, err, `unknown argument enemy at index 0`)
- })
-
- t.Run("full handler func signature", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / Handler(response, request)"}}{{.}}{{end}}`))
- logBuffer := bytes.NewBuffer(nil)
- logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{
- Level: slog.LevelDebug,
- }))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
-
- s.HandlerStub = func(writer http.ResponseWriter, request *http.Request) template.HTML {
- writer.WriteHeader(http.StatusCreated)
-
- return "Progressive "
- }
-
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s).WithStructuredLogger(logger))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/input/peach", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
-
- res := rec.Result()
- assert.Equal(t, http.StatusCreated, res.StatusCode)
- })
-
- t.Run("method receives a template", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / Template(template)"}}{{.}}{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
-
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/input/peach", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
-
- res := rec.Result()
- assert.Equal(t, http.StatusOK, res.StatusCode)
-
- if assert.Equal(t, 1, s.TemplateCallCount()) {
- arg := s.TemplateArgsForCall(0)
- assert.Equal(t, "GET / Template(template)", arg.Name())
- }
- })
-
- t.Run("wrong parameter type", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "GET / Template(request)"}}{{.}}{{end}}`))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
-
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.ErrorContains(t, err, "method argument at index 0: argument request *http.Request is not assignable to parameter type *template.Template")
- })
-
- t.Run("handler uses a logger", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "POST /stdin LogLines(logger)"}}{{printf "lines: %d" .}}{{end}}`))
- logBuffer := bytes.NewBuffer(nil)
- logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{
- Level: slog.LevelDebug,
- }))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
-
- s.LogLinesStub = func(logger *slog.Logger) int {
- logger.Info("some message")
- return 42
- }
-
- err := muxt.Handlers(mux, ts, muxt.WithStructuredLogger(logger).WithReceiver(s))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodPost, "/stdin", strings.NewReader(""))
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
-
- res := rec.Result()
- assert.Equal(t, http.StatusOK, res.StatusCode)
-
- assert.Contains(t, logBuffer.String(), "some message")
- })
-
- t.Run("wrong number of arguments", func(t *testing.T) {
- ts := template.Must(template.New("simple path").Parse(`{{define "POST /stdin LogLines(ctx, logger)"}}{{printf "lines: %d" .}}{{end}}`))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(new(fake.Receiver)))
- require.ErrorContains(t, err, "wrong number of arguments")
- })
-
- t.Run("custom execute function", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(
- /* language=gotemplate */
- `{{define "GET /" }}Hello, friend! {{end}}`,
- ))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts, muxt.WithDataFunc(func(res http.ResponseWriter, req *http.Request, t *template.Template, logger *slog.Logger, data any) {
- res.WriteHeader(http.StatusBadRequest)
- }))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
-
- mux.ServeHTTP(rec, req)
-
- res := rec.Result()
-
- assert.Equal(t, http.StatusBadRequest, res.StatusCode)
- })
- t.Run("custom execute function", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(
- `{{define "GET / Tuple()" }}{{.}} {{end}}`,
- ))
- mux := http.NewServeMux()
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(new(fake.Receiver)))
- require.ErrorContains(t, err, "the second result must be an error")
- })
-
- t.Run("when the error handler is overwritten", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(
- `{{define "GET / ListArticles(ctx)" }}{{len .}} {{end}}`,
- ))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
- listErr := fmt.Errorf("banana")
- s.ListArticlesReturns(nil, listErr)
- const userFacingError = "🍌"
- err := muxt.Handlers(mux, ts, muxt.WithErrorFunc(func(res http.ResponseWriter, req *http.Request, ts *template.Template, logger *slog.Logger, err error) {
- assert.Equal(t, "GET / ListArticles(ctx)", ts.Name())
- assert.NotNil(t, logger)
- assert.Error(t, err)
- assert.Equal(t, err, listErr)
- res.WriteHeader(http.StatusBadRequest)
- _, _ = io.WriteString(res, userFacingError)
- }).WithReceiver(s))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- mux.ServeHTTP(rec, req)
- res := rec.Result()
- assert.Equal(t, http.StatusBadRequest, res.StatusCode)
- })
-
- t.Run("when the noop handler error func is configures", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(
- `{{define "GET / ErrorHandler(response, request)" }}{{.}}{{end}}`,
- ))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
-
- const body = `Excuse You
`
- s.ErrorHandlerStub = func(res http.ResponseWriter, _ *http.Request) (template.HTML, error) {
- res.WriteHeader(http.StatusBadRequest)
- _, _ = io.WriteString(res, body)
- return "", fmt.Errorf("banana")
- }
-
- logBuffer := bytes.NewBuffer(nil)
- logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
-
- err := muxt.Handlers(mux, ts, muxt.WithNoopErrorFunc().WithReceiver(s).WithStructuredLogger(logger))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- res := new(fake.ResponseWriter)
- mux.ServeHTTP(res, req)
-
- assert.Equal(t, 1, res.WriteHeaderCallCount())
- assert.Equal(t, http.StatusBadRequest, res.WriteHeaderArgsForCall(0))
- assert.Equal(t, body, string(res.WriteArgsForCall(0)))
- assert.Empty(t, logBuffer.String())
- })
-
- t.Run("when the 500 handler error func is configured", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(
- `{{define "GET / ErrorHandler(response, request)" }}{{.}}{{end}}`,
- ))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
-
- s.ErrorHandlerStub = func(res http.ResponseWriter, _ *http.Request) (template.HTML, error) {
- return "", fmt.Errorf("banana")
- }
-
- logBuffer := bytes.NewBuffer(nil)
- logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
-
- err := muxt.Handlers(mux, ts, muxt.With500ErrorFunc().WithReceiver(s).WithStructuredLogger(logger))
- require.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- res := new(fake.ResponseWriter)
- res.HeaderReturns(make(http.Header))
- mux.ServeHTTP(res, req)
-
- assert.Equal(t, 1, res.WriteHeaderCallCount())
- assert.Equal(t, http.StatusInternalServerError, res.WriteHeaderArgsForCall(0))
- assert.Equal(t, http.StatusText(http.StatusInternalServerError)+"\n", string(res.WriteArgsForCall(0)))
- assert.Contains(t, logBuffer.String(), "error=banana")
- })
-
- t.Run("when the path has an end of path delimiter", func(t *testing.T) {
- //
- ts := template.Must(template.New("simple path").Parse(
- `{{define "GET /{$} ListArticles(ctx)" }}{{len .}}{{end}}`,
- ))
- mux := http.NewServeMux()
- s := new(fake.Receiver)
-
- err := muxt.Handlers(mux, ts, muxt.WithReceiver(s))
- require.NoError(t, err)
- })
-}
-
-type errorWriter struct {
- http.ResponseWriter
-}
-
-func (w errorWriter) Write([]byte) (int, error) {
- return 0, fmt.Errorf("banna")
-}