Skip to content

Commit

Permalink
fix: not adding execute function when expected
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Aug 30, 2024
1 parent d241b9c commit c9f976e
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cmd/muxt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
return err
}
out := log.New(stdout, "", 0)
s, err := muxt.Generate(templateNames, ts, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out)
s, err := muxt.Generate(templateNames, ts, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.outputFilename, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out)
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions example/template_routes.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"html/template"
"log"
"net/http"
"path/filepath"
"reflect"
"slices"
"strconv"
Expand Down Expand Up @@ -41,7 +42,7 @@ const (
receiverInterfaceIdent = "RoutesReceiver"
)

func Generate(templateNames []TemplateName, _ *template.Template, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent string, _ *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) {
func Generate(templateNames []TemplateName, _ *template.Template, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent, output string, fileSet *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) {
packageName = cmp.Or(packageName, defaultPackageName)
templatesVariableName = cmp.Or(templatesVariableName, DefaultTemplatesVariableName)
routesFunctionName = cmp.Or(routesFunctionName, DefaultRoutesFunctionName)
Expand Down Expand Up @@ -99,7 +100,11 @@ func Generate(templateNames []TemplateName, _ *template.Template, packageName, t
hasExecuteFunc := false
for _, fn := range source.IterateFunctions(templatesPackage) {
if fn.Recv == nil && fn.Name.Name == executeIdentName {
hasExecuteFunc = true
p := fileSet.Position(fn.Pos())
if filepath.Base(p.Filename) != output {
hasExecuteFunc = true
}
break
}
}
if !hasExecuteFunc {
Expand Down
52 changes: 49 additions & 3 deletions generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) {
`,
},
{
Name: "execute function defined",
Name: "execute function defined in receiver file",
Templates: `{{define "GET /age/{username} F(username)"}}Hello, {{.}}!{{end}}`,
ReceiverPackage: `s
-- receiver.go --
Expand All @@ -215,7 +215,7 @@ type T struct{}
func (*T) F(username string) int { return 30 }
func execute(response http.ResponseWriter, request *http.Request, t *template.Template, code int, data any) {
func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {
response.WriteHeader(code)
_ = templates.ExecuteTemplate(response, name, data)
}
Expand All @@ -236,6 +236,52 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) {
execute(response, request, true, "GET /age/{username} F(username)", http.StatusOK, data)
})
}
`,
},
{
Name: "execute function already defined in output file",
Templates: ``,
ReceiverPackage: `
-- receiver.go --
package main
-- template_routes.go --
package main
import(
"html/template"
"net/html"
)
func routes(mux *http.ServeMux, receiver RoutesReceiver) {}
func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {
response.WriteHeader(code)
_ = templates.ExecuteTemplate(response, name, data)
}
`,
ExpectedFile: `package main
import (
"net/http"
"bytes"
)
type RoutesReceiver interface {
}
func routes(mux *http.ServeMux, receiver RoutesReceiver) {
}
func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {
buf := bytes.NewBuffer(nil)
if err := templates.ExecuteTemplate(buf, name, data); err != nil {
http.Error(response, err.Error(), http.StatusInternalServerError)
return
}
if writeHeader {
response.WriteHeader(code)
}
_, _ = buf.WriteTo(response)
}
`,
},
{
Expand Down Expand Up @@ -1388,7 +1434,7 @@ func execute(response http.ResponseWriter, request *http.Request, writeHeader bo
logs := log.New(io.Discard, "", 0)
set := token.NewFileSet()
goFiles := methodFuncTypeLoader(t, set, tt.ReceiverPackage)
out, err := muxt.Generate(templateNames, ts, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, set, goFiles, goFiles, logs)
out, err := muxt.Generate(templateNames, ts, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, muxt.DefaultOutputFileName, set, goFiles, goFiles, logs)
if tt.ExpectedError == "" {
assert.NoError(t, err)
assert.Equal(t, tt.ExpectedFile, out)
Expand Down

0 comments on commit c9f976e

Please sign in to comment.