diff --git a/README.md b/README.md index acb3b0a..4a511eb 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,35 @@ This is especially helpful when you are writing HTMX. -## Example +Given the following files -The "define" blocks in the following template register handlers with the server mux. +### main.go -The http method, http host, and path semantics match those of in the HTTP package. +```go +package main -This library extends this to add custom data handler invocations see "PATCH /fruits/{fruit}". It is configured to call EditRow on template parse time provided receiver. +import ( + "html/template" + "log" + "net/http" +) -When no handler method is specified in the "declare" string (as is the case with "GET /fruits/{fruit}/edit" in the example), the template receives the *http.Request. +//go:embed *.gohtml +var templateSource embed.FS +var templates = template.Must(template.ParseFS(templateSource, "*")) + +type Backend struct {} + +func main() { + mux := http.NewServeMux() + muxt := Routes(mux, templates) + + +} +``` + +### index.gohtml ```html @@ -36,7 +55,6 @@ When no handler method is specified in the "declare" string (as is the case with - {{- range . -}} {{- block "fruit row" . -}} {{ .Fruit }} @@ -45,9 +63,11 @@ When no handler method is specified in the "declare" string (as is the case with {{- end -}} + + {{- define "GET /{} List(ctx)" -}} + {{template "index.gohtml" .}} {{- end -}} - {{- define "GET /fruits/{fruit}/edit" -}} {{ .PathValue "fruit" }} diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go index 9b59293..0a6d62f 100644 --- a/cmd/muxt/generate.go +++ b/cmd/muxt/generate.go @@ -27,6 +27,7 @@ type Generate struct { templatesVariable string outputFilename string routesFunction string + receiverIdent string } func (g Generate) ImportReceiverMethods(tp, method string) (*ast.FuncType, []*ast.ImportSpec, bool) { @@ -43,9 +44,11 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string) flagSet.StringVar(&config.templatesVariable, "templates-variable", muxt.DefaultTemplatesVariableName, "templates variable name") flagSet.StringVar(&config.outputFilename, "output-file", "template_routes.go", "file name of generated output") flagSet.StringVar(&config.routesFunction, "routes-func", muxt.DefaultRoutesFunctionName, "file name of generated output") + flagSet.StringVar(&config.receiverIdent, "receiver", "", "static receiver type identifier") if err := flagSet.Parse(args); err != nil { return err } + _ = os.Remove(filepath.Join(workingDirectory, config.outputFilename)) list, err := packages.Load(&packages.Config{ Mode: packages.NeedFiles | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles, Dir: workingDirectory, @@ -72,7 +75,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string) return err } out := log.New(stdout, "", 0) - s, err := muxt.Generate(patterns, config.goPackage, config.templatesVariable, config.routesFunction, "", config.Package.Fset, config.Package.Syntax, config.Package.Syntax, out) + s, err := muxt.Generate(patterns, config.goPackage, config.templatesVariable, config.routesFunction, config.receiverIdent, config.Package.Fset, config.Package.Syntax, config.Package.Syntax, out) if err != nil { return err } diff --git a/cmd/muxt/main.go b/cmd/muxt/main.go index bb407b9..32275ed 100644 --- a/cmd/muxt/main.go +++ b/cmd/muxt/main.go @@ -28,7 +28,7 @@ func command(wd string, args []string, getEnv func(string) string, stdout, stder func handleError(err error) int { if err != nil { - _, _ = os.Stderr.WriteString(err.Error()) + _, _ = os.Stderr.WriteString(err.Error() + "\n") return 1 } return 0 diff --git a/example/index.gohtml b/example/index.gohtml new file mode 100644 index 0000000..3a34dfa --- /dev/null +++ b/example/index.gohtml @@ -0,0 +1,76 @@ + + +{{block "head" "example"}} + + + {{.}} + + + + + +{{end}} + +
+ + + + + + + + + + {{- define "fruit row" -}} + + + + + {{- end -}} + + {{range .}} + {{template "fruit row" .}} + {{end}} + + {{- define "GET /{$} List(ctx)" -}} + {{template "index.gohtml" .}} + {{- end -}} + + {{- define "GET /fruits/{fruit}/edit GetFormEditRow(fruit)" -}} + + + + + {{- end -}} + + {{- define "PATCH /fruits/{fruit} SubmitFormEditRow(request, fruit)" }} + {{- if .Error -}} + {{template "GET /fruits/{fruit}/edit GetFormEditRow(fruit)" .}} + {{- else -}} + {{template "fruit row" .Row}} + {{- end -}} + {{ end -}} + + +
FruitCount
{{ .Name }}{{ .Value }}
{{ .Row.Name }} +
+ + +
+

{{.Error}}

+
+
+ + + +{{define "GET /help"}} + + +{{template "head" "Help"}} + +
+ Hello, help! +
+ + +{{end}} \ No newline at end of file diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..6fe186b --- /dev/null +++ b/example/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "context" + "embed" + "fmt" + "html/template" + "log" + "net/http" + "strconv" +) + +//go:embed *.gohtml +var templateSource embed.FS + +var templates = template.Must(template.ParseFS(templateSource, "*")) + +type Backend struct { + data []Row +} + +type EditRowPage struct { + Row Row + Error error +} + +func (b *Backend) SubmitFormEditRow(request *http.Request, fruit string) EditRowPage { + count, err := strconv.Atoi(request.FormValue("count")) + if err != nil { + return EditRowPage{Error: err, Row: Row{Name: fruit}} + } + for i := range b.data { + if b.data[i].Name == fruit { + b.data[i].Value = count + return EditRowPage{Error: nil, Row: b.data[i]} + } + } + return EditRowPage{Error: fmt.Errorf("fruit not found")} +} + +func (b *Backend) GetFormEditRow(fruit string) EditRowPage { + for i := range b.data { + if b.data[i].Name == fruit { + return EditRowPage{Error: nil, Row: b.data[i]} + } + } + return EditRowPage{Error: fmt.Errorf("fruit not found")} +} + +type Row struct { + Name string + Value int +} + +func (b *Backend) List(_ context.Context) []Row { return b.data } + +//go:generate muxt generate --receiver Backend + +func main() { + backend := &Backend{ + data: []Row{ + {Name: "Peach", Value: 10}, + {Name: "Plum", Value: 20}, + {Name: "Pineapple", Value: 2}, + }, + } + mux := http.NewServeMux() + Routes(mux, backend) + log.Fatal(http.ListenAndServe(":8080", mux)) +} diff --git a/example/template_routes.go b/example/template_routes.go new file mode 100644 index 0000000..a7ba808 --- /dev/null +++ b/example/template_routes.go @@ -0,0 +1,46 @@ +// Code generated by muxt. DO NOT EDIT. + +package main + +import ( + "context" + "net/http" + "bytes" + "html/template" +) + +type RoutesReceiver interface { + SubmitFormEditRow(request *http.Request, fruit string) EditRowPage + GetFormEditRow(fruit string) EditRowPage + List(_ context.Context) []Row +} + +func Routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("PATCH /fruits/{fruit}", func(response http.ResponseWriter, request *http.Request) { + fruit := request.PathValue("fruit") + data := receiver.SubmitFormEditRow(request, fruit) + execute(response, request, templates.Lookup("PATCH /fruits/{fruit} SubmitFormEditRow(request, fruit)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /fruits/{fruit}/edit", func(response http.ResponseWriter, request *http.Request) { + fruit := request.PathValue("fruit") + data := receiver.GetFormEditRow(fruit) + execute(response, request, templates.Lookup("GET /fruits/{fruit}/edit GetFormEditRow(fruit)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /help", func(response http.ResponseWriter, request *http.Request) { + execute(response, request, templates.Lookup("GET /help"), http.StatusOK, request) + }) + mux.HandleFunc("GET /{$}", func(response http.ResponseWriter, request *http.Request) { + ctx := request.Context() + data := receiver.List(ctx) + execute(response, request, templates.Lookup("GET /{$} List(ctx)"), http.StatusOK, data) + }) +} +func execute(response http.ResponseWriter, request *http.Request, t *template.Template, code int, data any) { + buf := bytes.NewBuffer(nil) + if err := t.Execute(buf, data); err != nil { + http.Error(response, err.Error(), http.StatusOK) + return + } + response.WriteHeader(code) + _, _ = buf.WriteTo(response) +} diff --git a/generate.go b/generate.go index 35630b4..134e8f9 100644 --- a/generate.go +++ b/generate.go @@ -77,11 +77,11 @@ func Generate(patterns []Pattern, packageName, templatesVariableName, routesFunc Type: method, }) } - handlerFunc, handerSignatureImports, err := pattern.funcLit(templatesVariableName, method) + handlerFunc, methodImports, err := pattern.funcLit(templatesVariableName, method) if err != nil { return "", err } - imports = source.SortImports(append(imports, handerSignatureImports...)) + imports = source.SortImports(append(imports, methodImports...)) routes.Body.List = append(routes.Body.List, pattern.callHandleFunc(handlerFunc)) log.Printf("%s has route for %s", routesFunctionName, pattern.String()) } @@ -163,12 +163,12 @@ func (def Pattern) funcLit(templatesVariableIdent string, method *ast.FuncType) } } + const dataVarIdent = "data" if method != nil && len(method.Results.List) > 1 { - dataVar := ast.NewIdent(dataVarIdent) errVar := ast.NewIdent("err") lit.Body.List = append(lit.Body.List, - &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVar.Name), ast.NewIdent(errVar.Name)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, + &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent), ast.NewIdent(errVar.Name)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}, &ast.IfStmt{ Cond: &ast.BinaryExpr{X: ast.NewIdent(errVar.Name), Op: token.NEQ, Y: ast.NewIdent("nil")}, Body: &ast.BlockStmt{ @@ -194,11 +194,10 @@ func (def Pattern) funcLit(templatesVariableIdent string, method *ast.FuncType) }, }, ) + } else { + lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}) } - - data := ast.NewIdent(dataVarIdent) - lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(data.Name)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}) - lit.Body.List = append(lit.Body.List, def.executeCall(ast.NewIdent(templatesVariableIdent), httpStatusCode(httpStatusCode200Ident), data)) + lit.Body.List = append(lit.Body.List, def.executeCall(ast.NewIdent(templatesVariableIdent), httpStatusCode(httpStatusCode200Ident), ast.NewIdent(dataVarIdent))) return lit, imports, nil } @@ -422,7 +421,16 @@ func (def Pattern) httpRequestReceiverTemplateHandlerFunc(templatesVariableName } func (def Pattern) matchReceiver(funcDecl *ast.FuncDecl, receiverTypeIdent string) bool { - return funcDecl.Name.Name == def.fun.Name && funcDecl.Recv != nil && len(funcDecl.Recv.List) == 1 && funcDecl.Recv.List[0].Type.(*ast.Ident).Name == receiverTypeIdent + if funcDecl == nil || funcDecl.Name == nil || funcDecl.Name.Name != def.fun.Name || funcDecl.Recv == nil && len(funcDecl.Recv.List) < 1 { + return false + } + exp := funcDecl.Recv.List[0].Type + if star, ok := exp.(*ast.StarExpr); ok { + exp = star.X + } + ident, ok := exp.(*ast.Ident) + return ok && ident.Name == receiverTypeIdent + } func executeFuncDecl() *ast.FuncDecl { diff --git a/reflect.go b/reflect.go deleted file mode 100644 index 9b61299..0000000 --- a/reflect.go +++ /dev/null @@ -1,277 +0,0 @@ -package muxt - -import ( - "bytes" - "context" - "fmt" - "go/ast" - "html/template" - "io" - "log/slog" - "net/http" - "reflect" - "sync" -) - -type Options struct { - logger *slog.Logger - receiver any - execute ExecuteFunc[any] - error ExecuteFunc[error] -} - -func newOptions() Options { - return Options{ - logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ - Level: slog.LevelError, - })), - receiver: nil, - execute: defaultExecute, - error: internalServerErrorErrorFunc, - } -} - -func WithStructuredLogger(log *slog.Logger) Options { return newOptions().WithStructuredLogger(log) } -func WithReceiver(r any) Options { return newOptions().WithReceiver(r) } -func WithDataFunc(ex ExecuteFunc[any]) Options { return newOptions().WithDataFunc(ex) } -func WithErrorFunc(ex ExecuteFunc[error]) Options { return newOptions().WithErrorFunc(ex) } -func WithNoopErrorFunc() Options { return newOptions().WithNoopErrorFunc() } -func With500ErrorFunc() Options { return newOptions().With500ErrorFunc() } - -func (o Options) WithStructuredLogger(log *slog.Logger) Options { - o.logger = log - return o -} - -func (o Options) WithReceiver(r any) Options { - o.receiver = r - return o -} - -func (o Options) WithDataFunc(ex ExecuteFunc[any]) Options { - o.execute = ex - return o -} - -func (o Options) WithErrorFunc(ex ExecuteFunc[error]) Options { - o.error = ex - return o -} - -func (o Options) WithNoopErrorFunc() Options { - o.error = noopErrorFunc - return o -} - -func (o Options) With500ErrorFunc() Options { - o.error = internalServerErrorErrorFunc - return o -} - -func noopErrorFunc(http.ResponseWriter, *http.Request, *template.Template, *slog.Logger, error) {} - -func internalServerErrorErrorFunc(res http.ResponseWriter, _ *http.Request, t *template.Template, logger *slog.Logger, err error) { - logger.Error("handler error", "error", err, "template", t.Name()) - http.Error(res, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) -} - -func applyOptions(options []Options) *Options { - result := newOptions() - for _, o := range options { - if o.logger != nil { - result.logger = o.logger - } - if o.receiver != nil { - result.receiver = o.receiver - } - if o.execute != nil { - result.execute = o.execute - } - if o.error != nil { - result.error = o.error - } - } - return &result -} - -func Handlers(mux *http.ServeMux, ts *template.Template, options ...Options) error { - o := applyOptions(options) - patterns, err := TemplatePatterns(ts) - if err != nil { - return err - } - for _, pat := range patterns { - t := ts.Lookup(pat.String()) - if pat.Handler == "" { - mux.HandleFunc(pat.Route, simpleTemplateHandler(o.execute, t, o.logger)) - continue - } - h, err := newReflectHandlerFunc(o, t, pat) - if err != nil { - return fmt.Errorf("failed to create handler for %q: %w", pat.String(), err) - } - mux.HandleFunc(pat.Route, h) - } - return nil -} - -func newReflectHandlerFunc(o *Options, t *template.Template, pat Pattern) (http.HandlerFunc, error) { - m, err := serviceMethod(o, pat.FunIdent()) - if err != nil { - return nil, err - } - inputs, err := generateInputsFunction(o, t, m.Type(), pat) - if err != nil { - return nil, err - } - return generateOutputsFunction(o, t, m, inputs) -} - -type inputsFunc = func(res http.ResponseWriter, req *http.Request) []reflect.Value - -func generateOutputsFunction(o *Options, t *template.Template, method reflect.Value, inputs inputsFunc) (http.HandlerFunc, error) { - methodType := method.Type() - switch num := methodType.NumOut(); num { - case 1: - return valueResultHandler(o, t, method, inputs), nil - case 2: - if !methodType.Out(1).AssignableTo(reflect.TypeFor[error]()) { - return nil, fmt.Errorf("the second result must be an error") - } - return valuesResultHandler(o, t, method, inputs), nil - default: - return nil, fmt.Errorf("method must either return (T) or (T, error)") - } -} - -func valueResultHandler(o *Options, t *template.Template, method reflect.Value, inputs inputsFunc) http.HandlerFunc { - return func(res http.ResponseWriter, req *http.Request) { - in := inputs(res, req) - out := method.Call(in) - o.execute(res, req, t, o.logger, out[0].Interface()) - } -} - -func valuesResultHandler(o *Options, t *template.Template, method reflect.Value, inputs inputsFunc) http.HandlerFunc { - return func(res http.ResponseWriter, req *http.Request) { - in := inputs(res, req) - out := method.Call(in) - callRes, callErr := out[0], out[1] - if !callErr.IsNil() { - err := callErr.Interface().(error) - o.error(res, req, t, o.logger, err) - return - } - o.execute(res, req, t, o.logger, callRes.Interface()) - } -} - -func serviceMethod(o *Options, method *ast.Ident) (reflect.Value, error) { - if o.receiver == nil { - return reflect.Value{}, fmt.Errorf("receiver is nil") - } - r := reflect.ValueOf(o.receiver) - m := r.MethodByName(method.Name) - if !m.IsValid() { - return reflect.Value{}, fmt.Errorf("method %s not found on %s", method.Name, r.Type()) - } - return m, nil -} - -func generateInputsFunction(o *Options, t *template.Template, method reflect.Type, pat Pattern) (inputsFunc, error) { - if method.NumIn() != len(pat.ArgIdents()) { - return nil, fmt.Errorf("wrong number of arguments") - } - if len(pat.ArgIdents()) == 0 { - return func(http.ResponseWriter, *http.Request) []reflect.Value { - return nil - }, nil - } - var args []string - for i, argIdent := range pat.ArgIdents() { - arg, err := typeCheckMethodParameters(method.In(i), argIdent) - if err != nil { - return nil, fmt.Errorf("method argument at index %d: %w", i, err) - } - args = append(args, arg) - } - return func(res http.ResponseWriter, req *http.Request) []reflect.Value { - var in []reflect.Value - for _, arg := range args { - switch arg { - case PatternScopeIdentifierHTTPResponse: - in = append(in, reflect.ValueOf(res)) - case PatternScopeIdentifierHTTPRequest: - in = append(in, reflect.ValueOf(req)) - case PatternScopeIdentifierContext: - in = append(in, reflect.ValueOf(req.Context())) - case PatternScopeIdentifierLogger: - in = append(in, reflect.ValueOf(o.logger)) - case PatternScopeIdentifierTemplate: - in = append(in, reflect.ValueOf(t)) - default: - in = append(in, reflect.ValueOf(req.PathValue(arg))) - } - } - return in - }, nil -} - -var argumentType = sync.OnceValue(func() func(argName string) (reflect.Type, error) { - requestType := reflect.TypeFor[*http.Request]() - contextType := reflect.TypeFor[context.Context]() - responseType := reflect.TypeFor[http.ResponseWriter]() - loggerType := reflect.TypeFor[*slog.Logger]() - templateType := reflect.TypeFor[*template.Template]() - stringType := reflect.TypeFor[string]() - return func(argName string) (reflect.Type, error) { - var argType reflect.Type - switch argName { - case PatternScopeIdentifierHTTPRequest: - argType = requestType - case PatternScopeIdentifierContext: - argType = contextType - case PatternScopeIdentifierHTTPResponse: - argType = responseType - case PatternScopeIdentifierLogger: - argType = loggerType - case PatternScopeIdentifierTemplate: - argType = templateType - default: - argType = stringType - } - return argType, nil - } -}) - -func typeCheckMethodParameters(paramType reflect.Type, arg *ast.Ident) (string, error) { - argType, err := argumentType()(arg.Name) - if err != nil { - return arg.Name, err - } - if !argType.AssignableTo(paramType) { - return arg.Name, fmt.Errorf("argument %s %s is not assignable to parameter type %s", arg.Name, argType, paramType) - } - return arg.Name, nil -} - -func simpleTemplateHandler(ex ExecuteFunc[any], t *template.Template, logger *slog.Logger) http.HandlerFunc { - return func(res http.ResponseWriter, req *http.Request) { - ex(res, req, t, logger, req) - } -} - -type ExecuteFunc[T any] func(http.ResponseWriter, *http.Request, *template.Template, *slog.Logger, T) - -func defaultExecute(res http.ResponseWriter, req *http.Request, t *template.Template, logger *slog.Logger, data any) { - var buf bytes.Buffer - if err := t.Execute(&buf, data); err != nil { - logger.Error("failed to render page", "method", req.Method, "path", req.URL.Path, "error", err) - http.Error(res, "failed to render page", http.StatusInternalServerError) - return - } - if _, err := buf.WriteTo(res); err != nil { - logger.Error("failed to write full response", "method", req.Method, "path", req.URL.Path, "error", err) - return - } -} diff --git a/reflect_test.go b/reflect_test.go deleted file mode 100644 index 33e98ba..0000000 --- a/reflect_test.go +++ /dev/null @@ -1,593 +0,0 @@ -package muxt_test - -import ( - "bytes" - "context" - "fmt" - "html/template" - "io" - "log/slog" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "testing" - - "github.com/crhntr/dom/domtest" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/html/atom" - - "github.com/crhntr/muxt" - "github.com/crhntr/muxt/internal/example" - "github.com/crhntr/muxt/internal/fake" -) - -//go:generate counterfeiter -generate -//counterfeiter:generate -o ./internal/fake/receiver.go --fake-name Receiver . receiver -var _ receiver = (*fake.Receiver)(nil) - -//counterfeiter:generate -o ./internal/fake/response_writer.go --fake-name ResponseWriter net/http.ResponseWriter - -type ( - receiver interface { - ListArticles(ctx context.Context) ([]example.Article, error) - ToUpper(in ...rune) string - Parse(string) []string - GetComment(ctx context.Context, articleID, commentID int) (string, error) - SomeString(ctx context.Context, x string) (string, error) - TooManyResults() (int, int, int) - NumAuthors() int - CheckAuth(req *http.Request) (string, error) - Handler(http.ResponseWriter, *http.Request) template.HTML - ErrorHandler(http.ResponseWriter, *http.Request) (template.HTML, error) - LogLines(*slog.Logger) int - Template(*template.Template) template.HTML - Type(any) string - Tuple() (string, string) - } -) - -func TestRoutes(t *testing.T) { - t.Run("GET index", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse( - /* language=gotemplate */ - `{{define "GET /" }}

Hello, friend!

{{end}}`, - )) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - - res := rec.Result() - - assert.Equal(t, http.StatusOK, res.StatusCode) - }) - - t.Run("when a handler is registered", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse( - /* language=gotemplate */ - `{{define "GET /articles ListArticles(ctx)" }}{{end}}`, - )) - as := new(fake.Receiver) - articles := []example.Article{ - {ID: 1, Title: "Hello"}, - {ID: 2, Title: "Goodbye"}, - } - as.ListArticlesReturns(articles, nil) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts, muxt.WithReceiver(as)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/articles", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - - assert.Equal(t, 1, as.ListArticlesCallCount()) - - res := rec.Result() - - assert.Equal(t, http.StatusOK, res.StatusCode) - fragment := domtest.DocumentFragmentResponse(t, res, atom.Body) - listItems := fragment.QuerySelectorAll(`[data-id]`) - assert.Equal(t, len(articles), listItems.Length()) - for i := 0; i < listItems.Length(); i++ { - li := listItems.Item(i) - assert.Equal(t, articles[i].Title, li.TextContent()) - assert.Equal(t, strconv.Itoa(articles[i].ID), li.GetAttribute("data-id")) - } - }) - - t.Run("unexpected method", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse(`{{define "CONNECT /articles" }}{{.}}{{end}}`)) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts) - require.ErrorContains(t, err, `CONNECT method not allowed`) - }) - - t.Run("no method", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse(`{{define "/x/y" }}{{.Method}}{{end}}`)) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts) - require.NoError(t, err) - - for _, method := range []string{http.MethodGet, http.MethodPost} { - req := httptest.NewRequest(method, "/x/y", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - res := rec.Result() - - assert.Equal(t, http.StatusOK, res.StatusCode) - body, _ := io.ReadAll(res.Body) - assert.Equal(t, method, string(body)) - } - }) - - t.Run("selector must be an expression", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / var x int" }}{{.}}{{end}}`)) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts, muxt.WithReceiver(new(fake.Receiver))) - require.ErrorContains(t, err, "failed to parse handler expression") - }) - - t.Run("function must be an identifier", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / func().Method(request)" }}{{.}}{{end}}`)) - mux := http.NewServeMux() - rec := new(fake.Receiver) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(rec)) - require.ErrorContains(t, err, `expected function identifier`) - }) - - t.Run("receiver is nil and a method is expected", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / Method(request)" }}{{.}}{{end}}`)) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts) - require.ErrorContains(t, err, "receiver is nil") - }) - - t.Run("method not found on basic type", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / Foo(request)" }}{{.}}{{end}}`)) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts, muxt.WithReceiver(100)) - require.ErrorContains(t, err, `method Foo not found on int`) - }) - - t.Run("ellipsis not allowed", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /{name} ToUpper(name...)" }}{{.}}{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.ErrorContains(t, err, `unexpected ellipsis`) - }) - - t.Run("duplicate path param identifier", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /articles/{id}/comment/{id} GetComment(ctx, id, id)" }}{{.}}{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.ErrorContains(t, err, `path parameter id defined at least twice`) - }) - - t.Run("path param is not an identifier ", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /{key-id} SomeString(ctx, key-id)"}}KEY{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.ErrorContains(t, err, `path parameter name not permitted: "key-id" is not a Go identifier`) - }) - - for _, name := range []string{ - "request", - "ctx", - "response", - "logger", - "template", - } { - t.Run(name+" can not be used as a path parameter identifier", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(fmt.Sprintf(`{{define "GET /{%[1]s} Type(%[1]s)"}}{{.}}{{end}}`, name))) - mux := http.NewServeMux() - s := new(fake.Receiver) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.ErrorContains(t, err, fmt.Sprintf(`the name %s is not allowed as a path paramenter it is alredy in scop`, name)) - }) - - t.Run(name+" can be used when no handler is defined", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(fmt.Sprintf(`{{define "GET /{%[1]s}"}}{{.}}{{end}}`, name))) - mux := http.NewServeMux() - s := new(fake.Receiver) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.NoError(t, err) - }) - } - - t.Run("template execution fails", func(t *testing.T) { - ts := template.Must(template.New("simple path").Funcs(template.FuncMap{ - "errorNow": func() (string, error) { return "", fmt.Errorf("BANANA") }, - }).Parse(`{{define "GET / ListArticles(ctx)"}}{{ errorNow }}{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - res := rec.Result() - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - }) - - t.Run("write fails", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /"}}{{printf "%d" 199}}{{end}}`)) - logBuffer := bytes.NewBuffer(nil) - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{ - Level: slog.LevelDebug, - })) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts, muxt.WithStructuredLogger(logger)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(errorWriter{ResponseWriter: rec}, req) - res := rec.Result() - assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Contains(t, logBuffer.String(), "failed to write full response") - }) - - t.Run("too many results", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / TooManyResults()" }}{{.}}{{end}}"`)) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts, muxt.WithReceiver(new(fake.Receiver))) - require.ErrorContains(t, err, `method must either return (T) or (T, error)`) - }) - - t.Run("call fails", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /number-of-articles ListArticles(ctx)"}}{{len .}}{{end}}`)) - logBuffer := bytes.NewBuffer(nil) - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{ - Level: slog.LevelDebug, - })) - mux := http.NewServeMux() - s := new(fake.Receiver) - s.ListArticlesReturns(nil, fmt.Errorf("banana")) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s).WithStructuredLogger(logger)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/number-of-articles", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - res := rec.Result() - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - body, _ := io.ReadAll(res.Body) - assert.NotContains(t, string(body), "banana") - - assert.Contains(t, logBuffer.String(), "banana") - }) - - t.Run("not a function call", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /number-of-articles <-c"}}{{.}}{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.ErrorContains(t, err, "expected call") - }) - - t.Run("single return", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /number-of-authors NumAuthors()"}}{{.}}{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - s.NumAuthorsReturns(234) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/number-of-authors", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - res := rec.Result() - assert.Equal(t, http.StatusOK, res.StatusCode) - - body, _ := io.ReadAll(res.Body) - - assert.Equal(t, string(body), "234") - }) - - t.Run("request as a parameter", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /auth CheckAuth(request)"}}OK{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - s.NumAuthorsReturns(234) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/auth", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - res := rec.Result() - assert.Equal(t, http.StatusOK, res.StatusCode) - - body, _ := io.ReadAll(res.Body) - - assert.Equal(t, string(body), "OK") - }) - - t.Run("non identifier params", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /site-owner GetComment(ctx, 3, 1+2)"}}OK{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - s.NumAuthorsReturns(234) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.ErrorContains(t, err, `expected only argument expressions as arguments, argument at index 1 is: 3`) - }) - - t.Run("query param", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET /input/{in} Parse(in)"}}{{.}}{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - s.NumAuthorsReturns(234) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/input/peach", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - - v := s.ParseArgsForCall(0) - assert.Equal(t, "peach", v) - }) - - t.Run("unknown identifier", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / Parse(enemy)"}}@{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - s.NumAuthorsReturns(234) - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.ErrorContains(t, err, `unknown argument enemy at index 0`) - }) - - t.Run("full handler func signature", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / Handler(response, request)"}}{{.}}{{end}}`)) - logBuffer := bytes.NewBuffer(nil) - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{ - Level: slog.LevelDebug, - })) - mux := http.NewServeMux() - s := new(fake.Receiver) - - s.HandlerStub = func(writer http.ResponseWriter, request *http.Request) template.HTML { - writer.WriteHeader(http.StatusCreated) - - return "Progressive" - } - - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s).WithStructuredLogger(logger)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/input/peach", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - - res := rec.Result() - assert.Equal(t, http.StatusCreated, res.StatusCode) - }) - - t.Run("method receives a template", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / Template(template)"}}{{.}}{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/input/peach", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - - res := rec.Result() - assert.Equal(t, http.StatusOK, res.StatusCode) - - if assert.Equal(t, 1, s.TemplateCallCount()) { - arg := s.TemplateArgsForCall(0) - assert.Equal(t, "GET / Template(template)", arg.Name()) - } - }) - - t.Run("wrong parameter type", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "GET / Template(request)"}}{{.}}{{end}}`)) - mux := http.NewServeMux() - s := new(fake.Receiver) - - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.ErrorContains(t, err, "method argument at index 0: argument request *http.Request is not assignable to parameter type *template.Template") - }) - - t.Run("handler uses a logger", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "POST /stdin LogLines(logger)"}}{{printf "lines: %d" .}}{{end}}`)) - logBuffer := bytes.NewBuffer(nil) - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{ - Level: slog.LevelDebug, - })) - mux := http.NewServeMux() - s := new(fake.Receiver) - - s.LogLinesStub = func(logger *slog.Logger) int { - logger.Info("some message") - return 42 - } - - err := muxt.Handlers(mux, ts, muxt.WithStructuredLogger(logger).WithReceiver(s)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodPost, "/stdin", strings.NewReader("")) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - - res := rec.Result() - assert.Equal(t, http.StatusOK, res.StatusCode) - - assert.Contains(t, logBuffer.String(), "some message") - }) - - t.Run("wrong number of arguments", func(t *testing.T) { - ts := template.Must(template.New("simple path").Parse(`{{define "POST /stdin LogLines(ctx, logger)"}}{{printf "lines: %d" .}}{{end}}`)) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts, muxt.WithReceiver(new(fake.Receiver))) - require.ErrorContains(t, err, "wrong number of arguments") - }) - - t.Run("custom execute function", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse( - /* language=gotemplate */ - `{{define "GET /" }}

Hello, friend!

{{end}}`, - )) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts, muxt.WithDataFunc(func(res http.ResponseWriter, req *http.Request, t *template.Template, logger *slog.Logger, data any) { - res.WriteHeader(http.StatusBadRequest) - })) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - mux.ServeHTTP(rec, req) - - res := rec.Result() - - assert.Equal(t, http.StatusBadRequest, res.StatusCode) - }) - t.Run("custom execute function", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse( - `{{define "GET / Tuple()" }}

{{.}}

{{end}}`, - )) - mux := http.NewServeMux() - err := muxt.Handlers(mux, ts, muxt.WithReceiver(new(fake.Receiver))) - require.ErrorContains(t, err, "the second result must be an error") - }) - - t.Run("when the error handler is overwritten", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse( - `{{define "GET / ListArticles(ctx)" }}

{{len .}}

{{end}}`, - )) - mux := http.NewServeMux() - s := new(fake.Receiver) - listErr := fmt.Errorf("banana") - s.ListArticlesReturns(nil, listErr) - const userFacingError = "🍌" - err := muxt.Handlers(mux, ts, muxt.WithErrorFunc(func(res http.ResponseWriter, req *http.Request, ts *template.Template, logger *slog.Logger, err error) { - assert.Equal(t, "GET / ListArticles(ctx)", ts.Name()) - assert.NotNil(t, logger) - assert.Error(t, err) - assert.Equal(t, err, listErr) - res.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(res, userFacingError) - }).WithReceiver(s)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - mux.ServeHTTP(rec, req) - res := rec.Result() - assert.Equal(t, http.StatusBadRequest, res.StatusCode) - }) - - t.Run("when the noop handler error func is configures", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse( - `{{define "GET / ErrorHandler(response, request)" }}{{.}}{{end}}`, - )) - mux := http.NewServeMux() - s := new(fake.Receiver) - - const body = `

Excuse You

` - s.ErrorHandlerStub = func(res http.ResponseWriter, _ *http.Request) (template.HTML, error) { - res.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(res, body) - return "", fmt.Errorf("banana") - } - - logBuffer := bytes.NewBuffer(nil) - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - - err := muxt.Handlers(mux, ts, muxt.WithNoopErrorFunc().WithReceiver(s).WithStructuredLogger(logger)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := new(fake.ResponseWriter) - mux.ServeHTTP(res, req) - - assert.Equal(t, 1, res.WriteHeaderCallCount()) - assert.Equal(t, http.StatusBadRequest, res.WriteHeaderArgsForCall(0)) - assert.Equal(t, body, string(res.WriteArgsForCall(0))) - assert.Empty(t, logBuffer.String()) - }) - - t.Run("when the 500 handler error func is configured", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse( - `{{define "GET / ErrorHandler(response, request)" }}{{.}}{{end}}`, - )) - mux := http.NewServeMux() - s := new(fake.Receiver) - - s.ErrorHandlerStub = func(res http.ResponseWriter, _ *http.Request) (template.HTML, error) { - return "", fmt.Errorf("banana") - } - - logBuffer := bytes.NewBuffer(nil) - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - - err := muxt.Handlers(mux, ts, muxt.With500ErrorFunc().WithReceiver(s).WithStructuredLogger(logger)) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := new(fake.ResponseWriter) - res.HeaderReturns(make(http.Header)) - mux.ServeHTTP(res, req) - - assert.Equal(t, 1, res.WriteHeaderCallCount()) - assert.Equal(t, http.StatusInternalServerError, res.WriteHeaderArgsForCall(0)) - assert.Equal(t, http.StatusText(http.StatusInternalServerError)+"\n", string(res.WriteArgsForCall(0))) - assert.Contains(t, logBuffer.String(), "error=banana") - }) - - t.Run("when the path has an end of path delimiter", func(t *testing.T) { - // - ts := template.Must(template.New("simple path").Parse( - `{{define "GET /{$} ListArticles(ctx)" }}{{len .}}{{end}}`, - )) - mux := http.NewServeMux() - s := new(fake.Receiver) - - err := muxt.Handlers(mux, ts, muxt.WithReceiver(s)) - require.NoError(t, err) - }) -} - -type errorWriter struct { - http.ResponseWriter -} - -func (w errorWriter) Write([]byte) (int, error) { - return 0, fmt.Errorf("banna") -}