diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go index 4b8c2e9..e56d2a1 100644 --- a/cmd/muxt/generate.go +++ b/cmd/muxt/generate.go @@ -100,7 +100,7 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string) return err } out := log.New(stdout, "", 0) - s, err := muxt.Generate(templateNames, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out) + s, err := muxt.Generate(templateNames, ts, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.Package.Fset, g.Package.Syntax, g.Package.Syntax, out) if err != nil { return err } diff --git a/cmd/muxt/testdata/generate/form.txtar b/cmd/muxt/testdata/generate/form.txtar new file mode 100644 index 0000000..f7200ee --- /dev/null +++ b/cmd/muxt/testdata/generate/form.txtar @@ -0,0 +1,93 @@ +muxt generate --receiver-static-type=T + +cat template_routes.go + +exec go test -cover + +-- template.gohtml -- +{{define "POST / Method(form)" }}{{end}} + +-- go.mod -- +module server + +go 1.22 + +-- template.go -- +package server + +import ( + "embed" + "html/template" +) + +//go:embed *.gohtml +var formHTML embed.FS + +var templates = template.Must(template.ParseFS(formHTML, "*")) + +type Form struct { + Count []int `json:"count"` + Str string `input:"some-string" json:"str"` +} + +type T struct { + spy func(Form) Form +} + +func (t T) Method(form Form) Form { + return t.spy(form) +} +-- template_test.go -- +package server + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "slices" + "strings" + "testing" +) + +func Test(t *testing.T) { + mux := http.NewServeMux() + + var service T + + service.spy = func(form Form) Form { + if exp := []int{7, 14, 21, 29}; !slices.Equal(exp, form.Count) { + t.Errorf("exp %v, got %v", exp, form.Count) + } + if exp := "apple"; form.Str != exp { + t.Errorf("exp %v, got %v", exp, form.Str) + } + return form + } + + routes(mux, service) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(url.Values{ + "some-string": []string{"apple"}, + "Count": []string{"7", "14", "21", "29"}, + }.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + res := rec.Result() + + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + + if exp := ``; string(body) != exp { + t.Errorf("exp %v, got %v", exp, string(body)) + } +} diff --git a/generate.go b/generate.go index b304d48..f19d9dc 100644 --- a/generate.go +++ b/generate.go @@ -5,7 +5,10 @@ import ( "fmt" "go/ast" "go/token" + "html/template" "log" + "net/http" + "reflect" "slices" "strconv" "strings" @@ -21,15 +24,11 @@ const ( muxVarIdent = "mux" requestPathValue = "PathValue" - templatesLookup = "Lookup" httpRequestContextMethod = "Context" httpPackageIdent = "http" httpResponseWriterIdent = "ResponseWriter" httpServeMuxIdent = "ServeMux" httpRequestIdent = "Request" - httpStatusCode200Ident = "StatusOK" - httpStatusCode500Ident = "StatusInternalServerError" - httpStatusCode400Ident = "StatusBadRequest" httpHandleFuncIdent = "HandleFunc" contextPackageIdent = "context" @@ -42,7 +41,7 @@ const ( receiverInterfaceIdent = "RoutesReceiver" ) -func Generate(templateNames []TemplateName, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent string, _ *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) { +func Generate(templateNames []TemplateName, _ *template.Template, packageName, templatesVariableName, routesFunctionName, receiverTypeIdent string, _ *token.FileSet, receiverPackage, templatesPackage []*ast.File, log *log.Logger) (string, error) { packageName = cmp.Or(packageName, defaultPackageName) templatesVariableName = cmp.Or(templatesVariableName, DefaultTemplatesVariableName) routesFunctionName = cmp.Or(routesFunctionName, DefaultRoutesFunctionName) @@ -60,33 +59,33 @@ func Generate(templateNames []TemplateName, packageName, templatesVariableName, imports := []*ast.ImportSpec{ importSpec("net/" + httpPackageIdent), } - for _, pattern := range templateNames { + for _, name := range templateNames { var method *ast.FuncType - if pattern.fun != nil { + if name.fun != nil { for _, funcDecl := range source.IterateFunctions(receiverPackage) { - if !pattern.matchReceiver(funcDecl, receiverTypeIdent) { + if !name.matchReceiver(funcDecl, receiverTypeIdent) { continue } method = funcDecl.Type break } if method == nil { - me, im := pattern.funcType() + me, im := name.funcType() method = me imports = append(imports, im...) } receiverInterface.Methods.List = append(receiverInterface.Methods.List, &ast.Field{ - Names: []*ast.Ident{ast.NewIdent(pattern.fun.Name)}, + Names: []*ast.Ident{ast.NewIdent(name.fun.Name)}, Type: method, }) } - handlerFunc, methodImports, err := pattern.funcLit(method) + handlerFunc, methodImports, err := name.funcLit(method, receiverPackage) if err != nil { return "", err } imports = sortImports(append(imports, methodImports...)) - routes.Body.List = append(routes.Body.List, pattern.callHandleFunc(handlerFunc)) - log.Printf("%s has route for %s", routesFunctionName, pattern.String()) + routes.Body.List = append(routes.Body.List, name.callHandleFunc(handlerFunc)) + log.Printf("%s has route for %s", routesFunctionName, name.String()) } importGen := &ast.GenDecl{ Tok: token.IMPORT, @@ -129,8 +128,8 @@ func (def TemplateName) callHandleFunc(handlerFuncLit *ast.FuncLit) *ast.ExprStm }} } -func (def TemplateName) funcLit(method *ast.FuncType) (*ast.FuncLit, []*ast.ImportSpec, error) { - if def.handler == "" { +func (def TemplateName) funcLit(method *ast.FuncType, files []*ast.File) (*ast.FuncLit, []*ast.ImportSpec, error) { + if method == nil { return def.httpRequestReceiverTemplateHandlerFunc(), nil, nil } lit := &ast.FuncLit{ @@ -138,16 +137,19 @@ func (def TemplateName) funcLit(method *ast.FuncType) (*ast.FuncLit, []*ast.Impo Body: &ast.BlockStmt{}, } call := &ast.CallExpr{Fun: callReceiverMethod(def.fun)} - if method != nil { - if method.Params.NumFields() != len(def.call.Args) { - return nil, nil, errWrongNumberOfArguments(def, method) + if method.Params.NumFields() != len(def.call.Args) { + return nil, nil, errWrongNumberOfArguments(def, method) + } + var formStruct *ast.StructType + for pi, pt := range fieldListTypes(method.Params) { + if err := checkArgument(method, pi, def.call.Args[pi], pt, files); err != nil { + return nil, nil, err } - for pi, pt := range fieldListTypes(method.Params) { - if err := checkArgument(method, pi, def.call.Args[pi], pt); err != nil { - return nil, nil, err - } + if s, ok := findFormStruct(pt, files); ok { + formStruct = s } } + const errVarIdent = "err" var ( imports []*ast.ImportSpec writeHeader = true @@ -165,8 +167,121 @@ func (def TemplateName) funcLit(method *ast.FuncType) (*ast.FuncLit, []*ast.Impo lit.Body.List = append(lit.Body.List, contextAssignment()) call.Args = append(call.Args, ast.NewIdent(TemplateNameScopeIdentifierContext)) imports = append(imports, importSpec("context")) + case TemplateNameScopeIdentifierForm: + _, tp, _ := source.FieldIndex(method.Params.List, i) + lit.Body.List = append(lit.Body.List, + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent("ParseForm"), + }, + Args: []ast.Expr{}, + }}, + formDeclaration(arg.Name, tp)) + if formStruct != nil { + for _, field := range formStruct.Fields.List { + for _, name := range field.Names { + fieldExpr := &ast.SelectorExpr{ + X: ast.NewIdent(arg.Name), + Sel: ast.NewIdent(name.Name), + } + errCheck := source.ErrorCheckReturn(errVarIdent, &ast.ExprStmt{X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(httpPackageIdent), + Sel: ast.NewIdent("Error"), + }, + Args: []ast.Expr{ + ast.NewIdent(httpResponseField().Names[0].Name), + &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ast.NewIdent("err"), Sel: ast.NewIdent("Error")}, + Args: []ast.Expr{}, + }, + source.HTTPStatusCode(httpPackageIdent, http.StatusBadRequest), + }, + }}, &ast.ReturnStmt{}) + + assignment := singleAssignment(token.ASSIGN, fieldExpr) + if fieldType, ok := field.Type.(*ast.ArrayType); ok { + const valVar = "val" + assignment = appendAssignment(token.ASSIGN, &ast.SelectorExpr{ + X: ast.NewIdent(arg.Name), + Sel: ast.NewIdent(name.Name), + }) + statements, parseImports, err := parseStringStatements(name.Name, errVarIdent, ast.NewIdent(valVar), fieldType.Elt, errCheck, assignment) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", name.Name, err) + } + + forLoop := &ast.RangeStmt{ + Key: ast.NewIdent("_"), + Value: ast.NewIdent(valVar), + Tok: token.DEFINE, + X: &ast.IndexExpr{ + X: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent("Form"), + }, + Index: &ast.BasicLit{ + Kind: token.STRING, + Value: strconv.Quote(formInputName(field, name)), + }, + }, + Body: &ast.BlockStmt{ + List: statements, + }, + } + + lit.Body.List = append(lit.Body.List, forLoop) + imports = append(imports, parseImports...) + } else { + str := &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent("FormValue"), + }, + Args: []ast.Expr{ + &ast.BasicLit{ + Kind: token.STRING, + Value: strconv.Quote(formInputName(field, name)), + }, + }, + } + statements, parseImports, err := parseStringStatements(name.Name, errVarIdent, str, field.Type, errCheck, assignment) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate parse statements for form field %s: %w", name.Name, err) + } + lit.Body.List = append(lit.Body.List, statements...) + imports = append(imports, parseImports...) + } + } + } + } else { + imports = append(imports, importSpec("net/url")) + } + call.Args = append(call.Args, ast.NewIdent(arg.Name)) default: - statements, parseImports, err := httpPathValueAssignment(method, i, arg) + errCheck := source.ErrorCheckReturn(errVarIdent, &ast.ExprStmt{X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(httpPackageIdent), + Sel: ast.NewIdent("Error"), + }, + Args: []ast.Expr{ + ast.NewIdent(httpResponseField().Names[0].Name), + &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: ast.NewIdent("err"), Sel: ast.NewIdent("Error")}, + Args: []ast.Expr{}, + }, + source.HTTPStatusCode(httpPackageIdent, http.StatusBadRequest), + }, + }}, &ast.ReturnStmt{}) + src := &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + } + statements, parseImports, err := httpPathValueAssignment(method, i, arg, errVarIdent, src, token.DEFINE, errCheck) if err != nil { return nil, nil, err } @@ -177,7 +292,7 @@ func (def TemplateName) funcLit(method *ast.FuncType) (*ast.FuncLit, []*ast.Impo } const dataVarIdent = "data" - if method != nil && len(method.Results.List) > 1 { + if len(method.Results.List) > 1 { errVar := ast.NewIdent("err") lit.Body.List = append(lit.Body.List, @@ -200,7 +315,7 @@ func (def TemplateName) funcLit(method *ast.FuncType) (*ast.FuncLit, []*ast.Impo }, Args: []ast.Expr{}, }, - httpStatusCode(httpStatusCode500Ident), + source.HTTPStatusCode(httpPackageIdent, http.StatusInternalServerError), }, }}, &ast.ReturnStmt{}, @@ -211,10 +326,22 @@ func (def TemplateName) funcLit(method *ast.FuncType) (*ast.FuncLit, []*ast.Impo } else { lit.Body.List = append(lit.Body.List, &ast.AssignStmt{Lhs: []ast.Expr{ast.NewIdent(dataVarIdent)}, Tok: token.DEFINE, Rhs: []ast.Expr{call}}) } - lit.Body.List = append(lit.Body.List, def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(dataVarIdent), writeHeader)) + lit.Body.List = append(lit.Body.List, def.executeCall(source.HTTPStatusCode(httpPackageIdent, http.StatusOK), ast.NewIdent(dataVarIdent), writeHeader)) return lit, imports, nil } +func formInputName(field *ast.Field, name *ast.Ident) string { + if field.Tag != nil { + v, _ := strconv.Unquote(field.Tag.Value) + tags := reflect.StructTag(v) + n, hasInputTag := tags.Lookup("input") + if hasInputTag { + return n + } + } + return name.Name +} + func (def TemplateName) funcType() (*ast.FuncType, []*ast.ImportSpec) { method := &ast.FuncType{ Params: &ast.FieldList{}, @@ -233,6 +360,8 @@ func (def TemplateName) funcType() (*ast.FuncType, []*ast.ImportSpec) { case TemplateNameScopeIdentifierContext: method.Params.List = append(method.Params.List, contextContextField()) imports = append(imports, importSpec(contextPackageIdent)) + case TemplateNameScopeIdentifierForm: + method.Params.List = append(method.Params.List, urlValuesField(arg.Name)) default: method.Params.List = append(method.Params.List, pathValueField(arg.Name)) } @@ -269,7 +398,7 @@ func errWrongNumberOfArguments(def TemplateName, method *ast.FuncType) error { return fmt.Errorf("handler %s expects %d arguments but call %s has %d", source.Format(&ast.FuncDecl{Name: ast.NewIdent(def.fun.Name), Type: method}), method.Params.NumFields(), def.handler, len(def.call.Args)) } -func checkArgument(method *ast.FuncType, argIndex int, exp ast.Expr, argType ast.Expr) error { +func checkArgument(method *ast.FuncType, argIndex int, exp ast.Expr, argType ast.Expr, files []*ast.File) error { // TODO: rewrite to "cannot use 32 (untyped int constant) as string value in argument to strings.ToUpper" arg := exp.(*ast.Ident) switch arg.Name { @@ -288,6 +417,15 @@ func checkArgument(method *ast.FuncType, argIndex int, exp ast.Expr, argType ast return fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(argType), arg.Name, contextPackageIdent, contextContextTypeIdent) } return nil + case TemplateNameScopeIdentifierForm: + if matchSelectorIdents(argType, "url", "Values", false) { + return nil + } + _, ok := findFormStruct(argType, files) + if !ok { + return fmt.Errorf("method expects form to have type url.Values or T (where T is some struct type)") + } + return nil default: for paramIndex, paramType := range source.IterateFieldTypes(method.Params.List) { if argIndex != paramIndex { @@ -304,6 +442,27 @@ func checkArgument(method *ast.FuncType, argIndex int, exp ast.Expr, argType ast } } +func findFormStruct(argType ast.Expr, files []*ast.File) (*ast.StructType, bool) { + if argTypeIdent, ok := argType.(*ast.Ident); ok { + for _, file := range files { + for _, d := range file.Decls { + decl, ok := d.(*ast.GenDecl) + if !ok || decl.Tok != token.TYPE { + continue + } + for _, s := range decl.Specs { + spec := s.(*ast.TypeSpec) + structType, isStruct := spec.Type.(*ast.StructType) + if isStruct && spec.Name.Name == argTypeIdent.Name { + return structType, true + } + } + } + } + } + return nil, false +} + func matchSelectorIdents(expr ast.Expr, pkg, name string, star bool) bool { if star { st, ok := expr.(*ast.StarExpr) @@ -352,6 +511,13 @@ func routesFuncType(receiverType ast.Expr) *ast.FuncType { }} } +func urlValuesField(ident string) *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(ident)}, + Type: &ast.SelectorExpr{X: ast.NewIdent("url"), Sel: ast.NewIdent("Values")}, + } +} + func httpRequestField() *ast.Field { return &ast.Field{ Names: []*ast.Ident{ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest)}, @@ -371,13 +537,6 @@ func contextContextType() *ast.SelectorExpr { return &ast.SelectorExpr{X: ast.NewIdent(contextPackageIdent), Sel: ast.NewIdent(contextContextTypeIdent)} } -func httpStatusCode(name string) *ast.SelectorExpr { - return &ast.SelectorExpr{ - X: ast.NewIdent(httpPackageIdent), - Sel: ast.NewIdent(name), - } -} - func contextAssignment() *ast.AssignStmt { return &ast.AssignStmt{ Tok: token.DEFINE, @@ -391,448 +550,303 @@ func contextAssignment() *ast.AssignStmt { } } -func httpPathValueAssignment(method *ast.FuncType, i int, arg *ast.Ident) ([]ast.Stmt, []*ast.ImportSpec, error) { - const parsedVarSuffix = "Parsed" - for typeIndex, typeExp := range source.IterateFieldTypes(method.Params.List) { - if typeIndex != i { - continue - } - paramTypeIdent, ok := typeExp.(*ast.Ident) - if !ok { - return nil, nil, fmt.Errorf("unsupported type: %s", source.Format(typeExp)) - } - switch paramTypeIdent.Name { - default: - return nil, nil, fmt.Errorf("method param type %s not supported", source.Format(typeExp)) - case "bool": - errVar := ast.NewIdent("err") - - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseBool"), - }, - Args: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }}, - }}, - } - - errCheck := paramParseError(errVar) - - return []ast.Stmt{assign, errCheck}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "int": - errVar := ast.NewIdent("err") - - tmp := arg.Name + parsedVarSuffix - - parse := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseInt"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "64", Kind: token.INT}, - }, - }}, - } - - errCheck := paramParseError(errVar) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent(paramTypeIdent.Name), - Args: []ast.Expr{ast.NewIdent(tmp)}, - }}, - } - - return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "int16": - errVar := ast.NewIdent("err") - - tmp := arg.Name + parsedVarSuffix - - parse := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseInt"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "16", Kind: token.INT}, - }, - }}, - } - - errCheck := paramParseError(errVar) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent(paramTypeIdent.Name), - Args: []ast.Expr{ast.NewIdent(tmp)}, - }}, - } - - return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "int32": - errVar := ast.NewIdent("err") - - tmp := arg.Name + parsedVarSuffix - - parse := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseInt"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "32", Kind: token.INT}, - }, - }}, - } - - errCheck := paramParseError(errVar) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent(paramTypeIdent.Name), - Args: []ast.Expr{ast.NewIdent(tmp)}, - }}, - } - - return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "int8": - errVar := ast.NewIdent("err") - - tmp := arg.Name + parsedVarSuffix - - parse := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseInt"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "8", Kind: token.INT}, - }, - }}, - } - - errCheck := paramParseError(errVar) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent(paramTypeIdent.Name), - Args: []ast.Expr{ast.NewIdent(tmp)}, - }}, - } - - return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "int64": - errVar := ast.NewIdent("err") - - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseInt"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, +func formDeclaration(ident string, typeExp ast.Expr) *ast.DeclStmt { + if matchSelectorIdents(typeExp, "url", "Values", false) { + return &ast.DeclStmt{ + Decl: &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{ + &ast.ValueSpec{ + Names: []*ast.Ident{ast.NewIdent(ident)}, + Type: typeExp, + Values: []ast.Expr{ + &ast.SelectorExpr{X: ast.NewIdent(httpResponseField().Names[0].Name), Sel: ast.NewIdent("Form")}, }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "64", Kind: token.INT}, }, - }}, - } + }, + }, + } + } - errCheck := paramParseError(errVar) + return &ast.DeclStmt{ + Decl: &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{ + &ast.ValueSpec{ + Names: []*ast.Ident{ast.NewIdent(ident)}, + Type: typeExp, + }, + }, + }, + } +} - return []ast.Stmt{assign, errCheck}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "uint": - errVar := ast.NewIdent("err") +func httpPathValueAssignment(method *ast.FuncType, i int, arg *ast.Ident, errVarIdent string, str ast.Expr, assignTok token.Token, errCheck *ast.IfStmt) ([]ast.Stmt, []*ast.ImportSpec, error) { + for typeIndex, typeExp := range source.IterateFieldTypes(method.Params.List) { + if typeIndex != i { + continue + } + assignment := singleAssignment(assignTok, ast.NewIdent(arg.Name)) + return parseStringStatements(arg.Name, errVarIdent, str, typeExp, errCheck, assignment) + } + return nil, nil, fmt.Errorf("type for argumement %d not found", i) +} - tmp := arg.Name + parsedVarSuffix +func singleAssignment(assignTok token.Token, result ast.Expr) func(exp ast.Expr) ast.Stmt { + return func(exp ast.Expr) ast.Stmt { + return &ast.AssignStmt{ + Lhs: []ast.Expr{result}, + Tok: assignTok, + Rhs: []ast.Expr{exp}, + } + } +} - parse := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseUint"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "64", Kind: token.INT}, - }, - }}, - } +func appendAssignment(assignTok token.Token, result ast.Expr) func(exp ast.Expr) ast.Stmt { + return func(exp ast.Expr) ast.Stmt { + return &ast.AssignStmt{ + Lhs: []ast.Expr{result}, + Tok: assignTok, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent("append"), + Args: []ast.Expr{result, exp}, + }}, + } + } +} - errCheck := paramParseError(errVar) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent(paramTypeIdent.Name), - Args: []ast.Expr{ast.NewIdent(tmp)}, - }}, - } +func parseStringStatements(name string, errVarIdent string, str, typeExp ast.Expr, errCheck *ast.IfStmt, assignment func(ast.Expr) ast.Stmt) ([]ast.Stmt, []*ast.ImportSpec, error) { + const parsedVarSuffix = "Parsed" + paramTypeIdent, ok := typeExp.(*ast.Ident) + if !ok { + return nil, nil, fmt.Errorf("unsupported type: %s", source.Format(typeExp)) + } + base10 := source.Int(10) + switch paramTypeIdent.Name { + default: + return nil, nil, fmt.Errorf("method param type %s not supported", source.Format(typeExp)) + case "bool": + tmp := name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseBool"), + }, + Args: []ast.Expr{str}, + }}, + } - return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "uint16": - errVar := ast.NewIdent("err") + assign := assignment(ast.NewIdent(tmp)) - tmp := arg.Name + parsedVarSuffix + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int": + tmp := name + parsedVarSuffix - parse := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseUint"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "16", Kind: token.INT}, - }, - }}, - } + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("Atoi"), + }, + Args: []ast.Expr{str}, + }}, + } - errCheck := paramParseError(errVar) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent(paramTypeIdent.Name), - Args: []ast.Expr{ast.NewIdent(tmp)}, - }}, - } + assign := assignment(ast.NewIdent(tmp)) - return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "uint32": - errVar := ast.NewIdent("err") + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int16": + tmp := name + parsedVarSuffix - tmp := arg.Name + parsedVarSuffix + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{str, base10, source.Int(16)}, + }}, + } - parse := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseUint"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "32", Kind: token.INT}, - }, - }}, - } + assign := assignment(&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }) + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int32": + tmp := name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{str, base10, source.Int(32)}, + }}, + } - errCheck := paramParseError(errVar) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent(paramTypeIdent.Name), - Args: []ast.Expr{ast.NewIdent(tmp)}, - }}, - } + assign := assignment(&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }) + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int8": + tmp := name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{str, base10, source.Int(8)}, + }}, + } - return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "uint64": + assign := assignment(&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }) + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int64": + tmp := name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{str, base10, source.Int(64)}, + }}, + } - errVar := ast.NewIdent("err") + assign := assignment(ast.NewIdent(tmp)) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseUint"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "64", Kind: token.INT}, - }, - }}, - } + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint": + tmp := name + parsedVarSuffix - errCheck := paramParseError(errVar) + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{str, base10, source.Int(64)}, + }}, + } - return []ast.Stmt{assign, errCheck}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "uint8": - errVar := ast.NewIdent("err") + assign := assignment(&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }) + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint16": + tmp := name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{str, base10, source.Int(16)}, + }}, + } - tmp := arg.Name + parsedVarSuffix + assign := assignment(&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }) + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint32": + tmp := name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{str, base10, source.Int(32)}, + }}, + } - parse := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("strconv"), - Sel: ast.NewIdent("ParseUint"), - }, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }, - &ast.BasicLit{Value: "10", Kind: token.INT}, - &ast.BasicLit{Value: "8", Kind: token.INT}, - }, - }}, - } + assign := assignment(&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }) + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint64": + tmp := name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{str, base10, source.Int(64)}, + }}, + } - errCheck := paramParseError(errVar) - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: ast.NewIdent(paramTypeIdent.Name), - Args: []ast.Expr{ast.NewIdent(tmp)}, - }}, - } + assign := assignment(ast.NewIdent(tmp)) - return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil - case "string": - assign := &ast.AssignStmt{ - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, - }}, - } + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint8": + tmp := name + parsedVarSuffix - return []ast.Stmt{assign}, nil, nil + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVarIdent)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{str, base10, source.Int(8)}, + }}, } - } - return nil, nil, fmt.Errorf("type for argumement %d not found", i) -} -func paramParseError(errVar *ast.Ident) *ast.IfStmt { - return &ast.IfStmt{ - Cond: &ast.BinaryExpr{X: ast.NewIdent(errVar.Name), Op: token.NEQ, Y: ast.NewIdent("nil")}, - Body: &ast.BlockStmt{ - List: []ast.Stmt{ - &ast.ExprStmt{X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(httpPackageIdent), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{ - ast.NewIdent(httpResponseField().Names[0].Name), - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent("err"), - Sel: ast.NewIdent("Error"), - }, - Args: []ast.Expr{}, - }, - httpStatusCode(httpStatusCode400Ident), - }, - }}, - &ast.ReturnStmt{}, - }, - }, + assign := assignment(&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }) + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "string": + assign := assignment(str) + return []ast.Stmt{assign}, nil, nil } } @@ -853,7 +867,7 @@ func (def TemplateName) executeCall(status, data ast.Expr, writeHeader bool) *as func (def TemplateName) httpRequestReceiverTemplateHandlerFunc() *ast.FuncLit { return &ast.FuncLit{ Type: httpHandlerFuncType(), - Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(httpStatusCode(httpStatusCode200Ident), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, + Body: &ast.BlockStmt{List: []ast.Stmt{def.executeCall(source.HTTPStatusCode(httpPackageIdent, http.StatusOK), ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), true)}}, } } @@ -932,7 +946,7 @@ func executeFuncDecl(templatesVariableIdent string) *ast.FuncDecl { }, Args: []ast.Expr{}, }, - httpStatusCode(httpStatusCode500Ident), + source.HTTPStatusCode(httpPackageIdent, http.StatusInternalServerError), }, }}, &ast.ReturnStmt{}, diff --git a/generate_internal_test.go b/generate_internal_test.go deleted file mode 100644 index 4c32560..0000000 --- a/generate_internal_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package muxt - -import ( - "go/ast" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/crhntr/muxt/internal/source" -) - -func TestTemplateName_funcLit(t *testing.T) { - for _, tt := range []struct { - Name string - In string - Out string - Imports []string - Method *ast.FuncType - }{ - { - Name: "get", - In: "GET /", - Out: `func(response http.ResponseWriter, request *http.Request) { - execute(response, request, true, "GET /", http.StatusOK, request) -}`, - }, - { - Name: "call F", - In: "GET / F()", - Out: `func(response http.ResponseWriter, request *http.Request) { - data := receiver.F() - execute(response, request, true, "GET / F()", http.StatusOK, data) -}`, - }, - { - Name: "call F with argument request", - In: "GET / F(request)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{{Type: httpRequestField().Type}}}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - Out: `func(response http.ResponseWriter, request *http.Request) { - data := receiver.F(request) - execute(response, request, true, "GET / F(request)", http.StatusOK, data) -}`, - }, - { - Name: "call F with argument response", - In: "GET / F(response)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{{Type: httpResponseField().Type, Names: []*ast.Ident{{Name: "res"}}}}}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - Out: `func(response http.ResponseWriter, request *http.Request) { - data := receiver.F(response) - execute(response, request, false, "GET / F(response)", http.StatusOK, data) -}`, - }, - { - Name: "call F with argument context", - In: "GET / F(ctx)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{{Type: contextContextField().Type, Names: []*ast.Ident{{Name: "reqCtx"}}}}}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - Out: `func(response http.ResponseWriter, request *http.Request) { - ctx := request.Context() - data := receiver.F(ctx) - execute(response, request, true, "GET / F(ctx)", http.StatusOK, data) -}`, - }, - { - Name: "call F with argument path param", - In: "GET /{param} F(param)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("string")}}}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - Out: `func(response http.ResponseWriter, request *http.Request) { - param := request.PathValue("param") - data := receiver.F(param) - execute(response, request, true, "GET /{param} F(param)", http.StatusOK, data) -}`, - }, - { - Name: "call F with multiple arguments", - In: "GET /{userName} F(ctx, userName)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{ - {Type: contextContextField().Type, Names: []*ast.Ident{{Name: "ctx"}}}, - {Type: ast.NewIdent("string"), Names: []*ast.Ident{{Name: "n"}}}, - }}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - Out: `func(response http.ResponseWriter, request *http.Request) { - ctx := request.Context() - userName := request.PathValue("userName") - data := receiver.F(ctx, userName) - execute(response, request, true, "GET /{userName} F(ctx, userName)", http.StatusOK, data) -}`, - }, - } { - t.Run(tt.Name, func(t *testing.T) { - pat, err, ok := NewTemplateName(tt.In) - require.True(t, ok) - require.NoError(t, err) - out, _, err := pat.funcLit(tt.Method) - require.NoError(t, err) - assert.Equal(t, tt.Out, source.Format(out)) - }) - } -} - -func TestTemplateName_HandlerFuncLit_err(t *testing.T) { - for _, tt := range []struct { - Name string - In string - ErrSub string - Method *ast.FuncType - }{ - { - Name: "missing arguments", - In: "GET / F()", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("string")}}}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - ErrSub: "handler func F(string) any expects 1 arguments but call F() has 0", - }, - { - Name: "extra arguments", - In: "GET /{name} F(ctx, name)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{ - {Type: &ast.SelectorExpr{X: ast.NewIdent(contextPackageIdent), Sel: ast.NewIdent(contextContextTypeIdent)}}, - }}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - ErrSub: "handler func F(context.Context) any expects 1 arguments but call F(ctx, name) has 2", - }, - { - Name: "wrong argument type request", - In: "GET / F(request)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{ - {Type: ast.NewIdent("string")}, - }}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - ErrSub: "method expects type string but request is *http.Request", - }, - { - Name: "wrong argument type ctx", - In: "GET / F(ctx)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{ - {Type: ast.NewIdent("string")}, - }}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - ErrSub: "method expects type string but ctx is context.Context", - }, - { - Name: "wrong argument type response", - In: "GET / F(response)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{ - {Type: ast.NewIdent("string")}, - }}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - ErrSub: "method expects type string but response is http.ResponseWriter", - }, - { - Name: "wrong argument type path value", - In: "GET /{name} F(name)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{ - {Type: ast.NewIdent("float64")}, - }}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - ErrSub: "method param type float64 not supported", - }, - { - Name: "wrong argument type request ptr", - In: "GET / F(request)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{ - {Type: &ast.StarExpr{X: ast.NewIdent("T")}}, - }}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - ErrSub: "method expects type *T but request is *http.Request", - }, - { - Name: "wrong argument type in field list", - In: "GET /post/{postID}/comment/{commentID} F(ctx, request, commentID)", - Method: &ast.FuncType{ - Params: &ast.FieldList{List: []*ast.Field{ - {Type: contextContextField().Type, Names: []*ast.Ident{{Name: "ctx"}}}, - {Names: []*ast.Ident{ast.NewIdent("postID"), ast.NewIdent("commentID")}, Type: ast.NewIdent("string")}, - }}, - Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, - }, - ErrSub: "method expects type string but request is *http.Request", - }, - } { - t.Run(tt.Name, func(t *testing.T) { - pat, err, ok := NewTemplateName(tt.In) - require.True(t, ok) - require.NoError(t, err) - _, _, err = pat.funcLit(tt.Method) - assert.ErrorContains(t, err, tt.ErrSub) - }) - } -} diff --git a/generate_test.go b/generate_test.go index d34e5e3..de7c021 100644 --- a/generate_test.go +++ b/generate_test.go @@ -26,7 +26,6 @@ func TestGenerate(t *testing.T) { TemplatesVar string RoutesFunc string Imports []string - Method *ast.FuncType ExpectedError string ExpectedFile string @@ -185,14 +184,12 @@ package main type T struct{} func (*T) F(username string) int { return 30 } -`, + +` + executeGo, Receiver: "T", ExpectedFile: `package main -import ( - "net/http" - "bytes" -) +import "net/http" type RoutesReceiver interface { F(username string) int @@ -205,17 +202,6 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { execute(response, request, true, "GET /age/{username} F(username)", http.StatusOK, data) }) } -func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) { - buf := bytes.NewBuffer(nil) - if err := templates.ExecuteTemplate(buf, name, data); err != nil { - http.Error(response, err.Error(), http.StatusInternalServerError) - return - } - if writeHeader { - response.WriteHeader(code) - } - _, _ = buf.WriteTo(response) -} `, }, { @@ -259,17 +245,17 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { -- receiver.go -- package main +import "net/http" + type T struct{} func (T) F(username string) (int, error) { return 30, error } -`, + +` + executeGo, Receiver: "T", ExpectedFile: `package main -import ( - "net/http" - "bytes" -) +import "net/http" type RoutesReceiver interface { F(username string) (int, error) @@ -286,17 +272,6 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { execute(response, request, true, "GET /age/{username} F(username)", http.StatusOK, data) }) } -func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) { - buf := bytes.NewBuffer(nil) - if err := templates.ExecuteTemplate(buf, name, data); err != nil { - http.Error(response, err.Error(), http.StatusInternalServerError) - return - } - if writeHeader { - response.WriteHeader(code) - } - _, _ = buf.WriteTo(response) -} `, }, { @@ -309,7 +284,8 @@ package main type T struct{} func (T) F(ctx context.Context) int { return 30 } -`, + +` + executeGo, Receiver: "T", ExpectedError: "method expects type context.Context but request is *http.Request", }, @@ -338,14 +314,14 @@ func (T0) F(ctx context.Context) int { return 30 } func (T) F1(ctx context.Context, username string) int { return 30 } func (T) F(ctx context.Context, username string) int { return 30 } -`, + +` + executeGo, Receiver: "T", ExpectedFile: `package main import ( "context" "net/http" - "bytes" ) type RoutesReceiver interface { @@ -360,17 +336,6 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { execute(response, request, true, "GET /age/{username} F(ctx, username)", http.StatusOK, data) }) } -func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) { - buf := bytes.NewBuffer(nil) - if err := templates.ExecuteTemplate(buf, name, data); err != nil { - http.Error(response, err.Error(), http.StatusInternalServerError) - return - } - if writeHeader { - response.WriteHeader(code) - } - _, _ = buf.WriteTo(response) -} `, }, { @@ -418,13 +383,12 @@ func (T) PassUint8(in uint8) uint8 { return in } func (T) PassBool(in bool) bool { return in } func (T) PassByte(in byte) byte { return in } func (T) PassRune(in rune) rune { return in } -`, +` + executeGo, ExpectedFile: `package main import ( "net/http" "strconv" - "bytes" ) type RoutesReceiver interface { @@ -443,21 +407,22 @@ type RoutesReceiver interface { func routes(mux *http.ServeMux, receiver RoutesReceiver) { mux.HandleFunc("GET /bool/{value}", func(response http.ResponseWriter, request *http.Request) { - value, err := strconv.ParseBool(request.PathValue("value")) + valueParsed, err := strconv.ParseBool(request.PathValue("value")) if err != nil { http.Error(response, err.Error(), http.StatusBadRequest) return } + value := valueParsed data := receiver.PassBool(value) execute(response, request, true, "GET /bool/{value} PassBool(value)", http.StatusOK, data) }) mux.HandleFunc("GET /int/{value}", func(response http.ResponseWriter, request *http.Request) { - valueParsed, err := strconv.ParseInt(request.PathValue("value"), 10, 64) + valueParsed, err := strconv.Atoi(request.PathValue("value")) if err != nil { http.Error(response, err.Error(), http.StatusBadRequest) return } - value := int(valueParsed) + value := valueParsed data := receiver.PassInt(value) execute(response, request, true, "GET /int/{value} PassInt(value)", http.StatusOK, data) }) @@ -482,11 +447,12 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { execute(response, request, true, "GET /int32/{value} PassInt32(value)", http.StatusOK, data) }) mux.HandleFunc("GET /int64/{value}", func(response http.ResponseWriter, request *http.Request) { - value, err := strconv.ParseInt(request.PathValue("value"), 10, 64) + valueParsed, err := strconv.ParseInt(request.PathValue("value"), 10, 64) if err != nil { http.Error(response, err.Error(), http.StatusBadRequest) return } + value := valueParsed data := receiver.PassInt64(value) execute(response, request, true, "GET /int64/{value} PassInt64(value)", http.StatusOK, data) }) @@ -531,11 +497,12 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { execute(response, request, true, "GET /uint32/{value} PassUint32(value)", http.StatusOK, data) }) mux.HandleFunc("GET /uint64/{value}", func(response http.ResponseWriter, request *http.Request) { - value, err := strconv.ParseUint(request.PathValue("value"), 10, 64) + valueParsed, err := strconv.ParseUint(request.PathValue("value"), 10, 64) if err != nil { http.Error(response, err.Error(), http.StatusBadRequest) return } + value := valueParsed data := receiver.PassUint64(value) execute(response, request, true, "GET /uint64/{value} PassUint64(value)", http.StatusOK, data) }) @@ -550,46 +517,904 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) { execute(response, request, true, "GET /uint8/{value} PassUint8(value)", http.StatusOK, data) }) } -func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) { - buf := bytes.NewBuffer(nil) - if err := templates.ExecuteTemplate(buf, name, data); err != nil { - http.Error(response, err.Error(), http.StatusInternalServerError) - return - } - if writeHeader { - response.WriteHeader(code) +`, + }, + { + Name: "form has no fields", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +type T struct{} + +type In struct{} + +func (T) F(form In) any { return nil } + +` + executeGo, + Receiver: "T", + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { + F(form In) any +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + request.ParseForm() + var form In + data := receiver.F(form) + execute(response, request, true, "GET / F(form)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "F is not defined and a form field is passed", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +type T struct{} + +` + executeGo, + Receiver: "T", + ExpectedFile: `package main + +import ( + "net/http" + "net/url" +) + +type RoutesReceiver interface { + F(form url.Values) any +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + request.ParseForm() + var form url.Values = response.Form + data := receiver.F(form) + execute(response, request, true, "GET / F(form)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "F is defined and form type is a struct", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +type ( + T struct{} + In struct{ + field string } - _, _ = buf.WriteTo(response) +) + +func (T) F(form In) int { return 0 } + +` + executeGo, + Receiver: "T", + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { + F(form In) int +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + request.ParseForm() + var form In + form.field = request.FormValue("field") + data := receiver.F(form) + execute(response, request, true, "GET / F(form)", http.StatusOK, data) + }) } `, }, - } { - t.Run(tt.Name, func(t *testing.T) { - ts := template.Must(template.New(tt.Name).Parse(tt.Templates)) - templateNames, err := muxt.TemplateNames(ts) - require.NoError(t, err) - logs := log.New(io.Discard, "", 0) - set := token.NewFileSet() - goFiles := methodFuncTypeLoader(t, set, tt.ReceiverPackage) - out, err := muxt.Generate(templateNames, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, set, goFiles, goFiles, logs) - if tt.ExpectedError == "" { - assert.NoError(t, err) - assert.Equal(t, tt.ExpectedFile, out) - } else { - assert.ErrorContains(t, err, tt.ExpectedError) - } - }) + { + Name: "F is defined and form field has an input tag", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +type ( + T struct{} + In struct{ + field string ` + "`input:\"some-field\"`" + ` } +) + +func (T) F(form In) int { return 0 } +` + executeGo, + Receiver: "T", + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { + F(form In) int } -func methodFuncTypeLoader(t *testing.T, set *token.FileSet, in string) []*ast.File { - t.Helper() - archive := txtar.Parse([]byte(in)) - var files []*ast.File - for _, file := range archive.Files { - f, err := parser.ParseFile(set, file.Name, file.Data, parser.AllErrors) - require.NoError(t, err) - files = append(files, f) +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + request.ParseForm() + var form In + form.field = request.FormValue("some-field") + data := receiver.F(form) + execute(response, request, true, "GET / F(form)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "F is defined and form has two string fields", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +import "net/http" + +type ( + T struct{} + In struct{ + fieldInt int + fieldInt64 int64 + fieldInt32 int32 + fieldInt16 int16 + fieldInt8 int8 + fieldUint uint + fieldUint64 uint64 + fieldUint16 uint16 + fieldUint32 uint32 + fieldUint16 uint16 + fieldUint8 uint8 + fieldBool bool } - return files +) + +func (T) F(form In) int { return 0 } + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + Receiver: "T", + ExpectedFile: `package main + +import ( + "net/http" + "strconv" +) + +type RoutesReceiver interface { + F(form In) int +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + request.ParseForm() + var form In + fieldIntParsed, err := strconv.Atoi(request.FormValue("fieldInt")) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt = fieldIntParsed + fieldInt64Parsed, err := strconv.ParseInt(request.FormValue("fieldInt64"), 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt64 = fieldInt64Parsed + fieldInt32Parsed, err := strconv.ParseInt(request.FormValue("fieldInt32"), 10, 32) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt32 = int32(fieldInt32Parsed) + fieldInt16Parsed, err := strconv.ParseInt(request.FormValue("fieldInt16"), 10, 16) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt16 = int16(fieldInt16Parsed) + fieldInt8Parsed, err := strconv.ParseInt(request.FormValue("fieldInt8"), 10, 8) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt8 = int8(fieldInt8Parsed) + fieldUintParsed, err := strconv.ParseUint(request.FormValue("fieldUint"), 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint = uint(fieldUintParsed) + fieldUint64Parsed, err := strconv.ParseUint(request.FormValue("fieldUint64"), 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint64 = fieldUint64Parsed + fieldUint16Parsed, err := strconv.ParseUint(request.FormValue("fieldUint16"), 10, 16) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint16 = uint16(fieldUint16Parsed) + fieldUint32Parsed, err := strconv.ParseUint(request.FormValue("fieldUint32"), 10, 32) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint32 = uint32(fieldUint32Parsed) + fieldUint16Parsed, err := strconv.ParseUint(request.FormValue("fieldUint16"), 10, 16) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint16 = uint16(fieldUint16Parsed) + fieldUint8Parsed, err := strconv.ParseUint(request.FormValue("fieldUint8"), 10, 8) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint8 = uint8(fieldUint8Parsed) + fieldBoolParsed, err := strconv.ParseBool(request.FormValue("fieldBool")) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldBool = fieldBoolParsed + data := receiver.F(form) + execute(response, request, true, "GET / F(form)", http.StatusOK, data) + }) } +`, + }, + { + Name: "F is defined and form has two two names for a single field", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +import "net/http" + +type ( + T struct{} + In struct{ + field1, field2 string + } +) + +func (T) F(form In) int { return 0 } + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + Receiver: "T", + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { + F(form In) int +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + request.ParseForm() + var form In + form.field1 = request.FormValue("field1") + form.field2 = request.FormValue("field2") + data := receiver.F(form) + execute(response, request, true, "GET / F(form)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "F is defined and form slice field", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +import "net/http" + +type ( + T struct{} + In struct{ + field []string + } +) + +func (T) F(form In) int { return 0 } + +` + executeGo, + Receiver: "T", + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { + F(form In) int +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + request.ParseForm() + var form In + for _, val := range request.Form["field"] { + form.field = append(form.field, val) + } + data := receiver.F(form) + execute(response, request, true, "GET / F(form)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "F is defined and form has typed slice fields", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +type ( + T struct{} + In struct{ + fieldInt []int + fieldInt64 []int64 + fieldInt32 []int32 + fieldInt16 []int16 + fieldInt8 []int8 + fieldUint []uint + fieldUint64 []uint64 + fieldUint16 []uint16 + fieldUint32 []uint32 + fieldUint16 []uint16 + fieldUint8 []uint8 + fieldBool []bool + } +) + +func (T) F(form In) int { return 0 } + +` + executeGo, + Receiver: "T", + ExpectedFile: `package main + +import ( + "net/http" + "strconv" +) + +type RoutesReceiver interface { + F(form In) int +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + request.ParseForm() + var form In + for _, val := range request.Form["fieldInt"] { + fieldIntParsed, err := strconv.Atoi(val) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt = append(form.fieldInt, fieldIntParsed) + } + for _, val := range request.Form["fieldInt64"] { + fieldInt64Parsed, err := strconv.ParseInt(val, 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt64 = append(form.fieldInt64, fieldInt64Parsed) + } + for _, val := range request.Form["fieldInt32"] { + fieldInt32Parsed, err := strconv.ParseInt(val, 10, 32) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt32 = append(form.fieldInt32, int32(fieldInt32Parsed)) + } + for _, val := range request.Form["fieldInt16"] { + fieldInt16Parsed, err := strconv.ParseInt(val, 10, 16) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt16 = append(form.fieldInt16, int16(fieldInt16Parsed)) + } + for _, val := range request.Form["fieldInt8"] { + fieldInt8Parsed, err := strconv.ParseInt(val, 10, 8) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldInt8 = append(form.fieldInt8, int8(fieldInt8Parsed)) + } + for _, val := range request.Form["fieldUint"] { + fieldUintParsed, err := strconv.ParseUint(val, 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint = append(form.fieldUint, uint(fieldUintParsed)) + } + for _, val := range request.Form["fieldUint64"] { + fieldUint64Parsed, err := strconv.ParseUint(val, 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint64 = append(form.fieldUint64, fieldUint64Parsed) + } + for _, val := range request.Form["fieldUint16"] { + fieldUint16Parsed, err := strconv.ParseUint(val, 10, 16) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint16 = append(form.fieldUint16, uint16(fieldUint16Parsed)) + } + for _, val := range request.Form["fieldUint32"] { + fieldUint32Parsed, err := strconv.ParseUint(val, 10, 32) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint32 = append(form.fieldUint32, uint32(fieldUint32Parsed)) + } + for _, val := range request.Form["fieldUint16"] { + fieldUint16Parsed, err := strconv.ParseUint(val, 10, 16) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint16 = append(form.fieldUint16, uint16(fieldUint16Parsed)) + } + for _, val := range request.Form["fieldUint8"] { + fieldUint8Parsed, err := strconv.ParseUint(val, 10, 8) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldUint8 = append(form.fieldUint8, uint8(fieldUint8Parsed)) + } + for _, val := range request.Form["fieldBool"] { + fieldBoolParsed, err := strconv.ParseBool(val) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + form.fieldBool = append(form.fieldBool, fieldBoolParsed) + } + data := receiver.F(form) + execute(response, request, true, "GET / F(form)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "F is defined and form has unsupported field type", + Templates: `{{define "GET / F(form)"}}Hello, {{.}}!{{end}}`, + ReceiverPackage: ` +-- in.go -- +package main + +type ( + T struct{} + In struct{ + ts time.Time + } +) + +func (T) F(form In) int { return 0 } + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + Receiver: "T", + ExpectedError: "failed to generate parse statements for form field ts: unsupported type: time.Time", + }, + { + Name: "call F", + Templates: `{{define "GET / F()"}}Hello, world!{{end}}`, + Receiver: "T", + PackageName: "main", + ReceiverPackage: `-- in.go -- +package main + +type T struct{} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { + F() any +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + data := receiver.F() + execute(response, request, true, "GET / F()", http.StatusOK, data) + }) +} +`, + }, + { + Name: "no handler", + Templates: `{{define "GET /"}}Hello, world!{{end}}`, + Receiver: "T", + PackageName: "main", + ReceiverPackage: `-- in.go -- +package main + +type T struct{} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + execute(response, request, true, "GET /", http.StatusOK, request) + }) +} +`, + }, + { + Name: "no handler", + Templates: `{{define "GET /"}}Hello, world!{{end}}`, + Receiver: "T", + PackageName: "main", + ReceiverPackage: `-- in.go -- +package main + +type T struct{} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + execute(response, request, true, "GET /", http.StatusOK, request) + }) +} +`, + }, + { + Name: "call F with argument response", + Templates: `{{define "GET / F(response)"}}{{end}}`, + ReceiverPackage: `-- in.go -- +package main + +type T struct{} + +func (T) F(http.ResponseWriter) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { + F(response http.ResponseWriter) any +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + data := receiver.F(response) + execute(response, request, false, "GET / F(response)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "call F with argument context", + Templates: `{{define "GET / F(ctx)"}}{{end}}`, + ReceiverPackage: `-- in.go -- +package main + +type T struct{} + +func (T) F(ctx context.Context) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedFile: `package main + +import ( + "context" + "net/http" +) + +type RoutesReceiver interface { + F(ctx context.Context) any +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /", func(response http.ResponseWriter, request *http.Request) { + ctx := request.Context() + data := receiver.F(ctx) + execute(response, request, true, "GET / F(ctx)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "call F with argument path param", + Templates: `{{define "GET /{param} F(param)"}}{{end}}`, + ReceiverPackage: `-- in.go -- +package main + +type T struct{} + +func (T) F(param string) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedFile: `package main + +import "net/http" + +type RoutesReceiver interface { + F(param string) any +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /{param}", func(response http.ResponseWriter, request *http.Request) { + param := request.PathValue("param") + data := receiver.F(param) + execute(response, request, true, "GET /{param} F(param)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "call F with multiple arguments", + Templates: `{{define "GET /{userName} F(ctx, userName)"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +import "context" + +type T struct{} + +func (T) F(ctx context.Context, userName string) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedFile: `package main + +import ( + "context" + "net/http" +) + +type RoutesReceiver interface { + F(ctx context.Context, userName string) any +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /{userName}", func(response http.ResponseWriter, request *http.Request) { + ctx := request.Context() + userName := request.PathValue("userName") + data := receiver.F(ctx, userName) + execute(response, request, true, "GET /{userName} F(ctx, userName)", http.StatusOK, data) + }) +} +`, + }, + { + Name: "missing arguments", + Templates: `{{define "GET / F()"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +type T struct{} + +func (T) F(string) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + + ExpectedError: "handler func F(string) any expects 1 arguments but call F() has 0", + }, + { + Name: "extra arguments", + Templates: `{{define "GET /{name} F(ctx, name)"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +import ( + "context" + "net/html" +) + +type T struct{} + +func (T) F(context.Context) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedError: "handler func F(context.Context) any expects 1 arguments but call F(ctx, name) has 2", + }, + { + Name: "wrong argument type request", + Templates: `{{define "GET / F(request)"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +import ( + "context" + "net/html" +) + +type T struct{} + +func (T) F(string) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedError: "method expects type string but request is *http.Request", + }, + { + Name: "wrong argument type ctx", + Templates: `{{define "GET / F(ctx)"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +import "net/html" + +type T struct{} + +func (T) F(string) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedError: "method expects type string but ctx is context.Context", + }, + { + Name: "wrong argument type response", + Templates: `{{define "GET / F(response)"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +import "net/html" + +type T struct{} + +func (T) F(string) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedError: "method expects type string but response is http.ResponseWriter", + }, + { + Name: "wrong argument type path value", + Templates: `{{define "GET /{name} F(name)"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +import "net/html" + +type T struct{} + +func (T) F(float64) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedError: "method param type float64 not supported", + }, + { + Name: "wrong argument type request ptr", + Templates: `{{define "GET / F(request)"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +import "net/html" + +type T struct{} + +func (T) F(*T) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedError: "method expects type *T but request is *http.Request", + }, + { + Name: "wrong argument type in field list", + Templates: `{{define "GET /post/{postID}/comment/{commentID} F(ctx, request, commentID)"}}{{end}}`, + Receiver: "T", + ReceiverPackage: `-- in.go -- +package main + +import ( + "context" + "net/html" +) + +type T struct{} + +func (T) F(context.Context, string, string) any {return nil} + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +`, + ExpectedError: "method expects type string but request is *http.Request", + }, + } { + t.Run(tt.Name, func(t *testing.T) { + ts := template.Must(template.New(tt.Name).Parse(tt.Templates)) + templateNames, err := muxt.TemplateNames(ts) + require.NoError(t, err) + logs := log.New(io.Discard, "", 0) + set := token.NewFileSet() + goFiles := methodFuncTypeLoader(t, set, tt.ReceiverPackage) + out, err := muxt.Generate(templateNames, ts, tt.PackageName, tt.TemplatesVar, tt.RoutesFunc, tt.Receiver, set, goFiles, goFiles, logs) + if tt.ExpectedError == "" { + assert.NoError(t, err) + assert.Equal(t, tt.ExpectedFile, out) + } else { + assert.ErrorContains(t, err, tt.ExpectedError) + } + }) + } +} + +func methodFuncTypeLoader(t *testing.T, set *token.FileSet, in string) []*ast.File { + t.Helper() + archive := txtar.Parse([]byte(in)) + var files []*ast.File + for _, file := range archive.Files { + f, err := parser.ParseFile(set, file.Name, file.Data, parser.AllErrors) + require.NoError(t, err) + files = append(files, f) + } + return files +} + +const executeGo = `-- execute.go -- +package main + +import "net/http" + +func execute(response http.ResponseWriter, request *http.Request, writeHeader bool, name string, code int, data any) {} +` diff --git a/internal/source/ast.go b/internal/source/ast.go deleted file mode 100644 index 133f266..0000000 --- a/internal/source/ast.go +++ /dev/null @@ -1,143 +0,0 @@ -package source - -import ( - "fmt" - "go/ast" - "go/printer" - "go/token" - "strconv" - "strings" -) - -func IterateGenDecl(files []*ast.File, tok token.Token) func(func(*ast.File, *ast.GenDecl) bool) { - return func(yield func(*ast.File, *ast.GenDecl) bool) { - for _, file := range files { - for _, decl := range file.Decls { - d, ok := decl.(*ast.GenDecl) - if !ok || d.Tok != tok { - continue - } - if !yield(file, d) { - return - } - } - } - } -} - -func IterateValueSpecs(files []*ast.File) func(func(*ast.File, *ast.ValueSpec) bool) { - return func(yield func(*ast.File, *ast.ValueSpec) bool) { - for file, decl := range IterateGenDecl(files, token.VAR) { - for _, s := range decl.Specs { - if !yield(file, s.(*ast.ValueSpec)) { - return - } - } - } - } -} - -//func IterateTypes(files []*ast.File) func(func(*ast.File, *ast.TypeSpec) bool) { -// return func(yield func(*ast.File, *ast.TypeSpec) bool) { -// for _, file := range files { -// for _, decl := range file.Decls { -// spec, ok := decl.(*ast.GenDecl) -// if !ok || spec.Tok != token.TYPE { -// continue -// } -// for _, s := range spec.Specs { -// t, ok := s.(*ast.TypeSpec) -// if !ok { -// continue -// } -// if !yield(file, t) { -// return -// } -// } -// } -// } -// } -//} - -func IterateFunctions(files []*ast.File) func(func(*ast.File, *ast.FuncDecl) bool) { - return func(yield func(*ast.File, *ast.FuncDecl) bool) { - for _, file := range files { - for _, decl := range file.Decls { - fn, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - if !yield(file, fn) { - return - } - } - } - } -} - -//func IterateImports(files []*ast.File) func(func(*ast.File, *ast.ImportSpec) bool) { -// return func(yield func(*ast.File, *ast.ImportSpec) bool) { -// for _, file := range files { -// for _, decl := range file.Decls { -// genDecl, ok := decl.(*ast.GenDecl) -// if !ok || genDecl.Tok != token.IMPORT { -// continue -// } -// for _, s := range genDecl.Specs { -// if !yield(file, s.(*ast.ImportSpec)) { -// return -// } -// } -// } -// } -// } -//} - -func Format(node ast.Node) string { - var buf strings.Builder - if err := printer.Fprint(&buf, token.NewFileSet(), node); err != nil { - return fmt.Sprintf("formatting error: %v", err) - } - return buf.String() -} - -func evaluateStringLiteralExpressionList(wd string, set *token.FileSet, list []ast.Expr) ([]string, error) { - result := make([]string, 0, len(list)) - for _, a := range list { - s, err := evaluateStringLiteralExpression(wd, set, a) - if err != nil { - return result, err - } - result = append(result, s) - } - return result, nil -} - -func evaluateStringLiteralExpression(wd string, set *token.FileSet, exp ast.Expr) (string, error) { - arg, ok := exp.(*ast.BasicLit) - if !ok || arg.Kind != token.STRING { - return "", contextError(wd, set, exp.Pos(), fmt.Errorf("expected string literal got %s", Format(exp))) - } - return strconv.Unquote(arg.Value) -} - -func IterateFieldTypes(list []*ast.Field) func(func(int, ast.Expr) bool) { - return func(yield func(int, ast.Expr) bool) { - i := 0 - for _, field := range list { - if len(field.Names) == 0 { - if !yield(i, field.Type) { - return - } - i++ - } else { - for range field.Names { - if !yield(i, field.Type) { - return - } - i++ - } - } - } - } -} diff --git a/internal/source/go.go b/internal/source/go.go new file mode 100644 index 0000000..ef78702 --- /dev/null +++ b/internal/source/go.go @@ -0,0 +1,247 @@ +package source + +import ( + "fmt" + "go/ast" + "go/printer" + "go/token" + "net/http" + "strconv" + "strings" +) + +func IterateGenDecl(files []*ast.File, tok token.Token) func(func(*ast.File, *ast.GenDecl) bool) { + return func(yield func(*ast.File, *ast.GenDecl) bool) { + for _, file := range files { + for _, decl := range file.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != tok { + continue + } + if !yield(file, d) { + return + } + } + } + } +} + +func IterateValueSpecs(files []*ast.File) func(func(*ast.File, *ast.ValueSpec) bool) { + return func(yield func(*ast.File, *ast.ValueSpec) bool) { + for file, decl := range IterateGenDecl(files, token.VAR) { + for _, s := range decl.Specs { + if !yield(file, s.(*ast.ValueSpec)) { + return + } + } + } + } +} + +//func IterateTypes(files []*ast.File) func(func(*ast.File, *ast.TypeSpec) bool) { +// return func(yield func(*ast.File, *ast.TypeSpec) bool) { +// for _, file := range files { +// for _, decl := range file.Decls { +// spec, ok := decl.(*ast.GenDecl) +// if !ok || spec.Tok != token.TYPE { +// continue +// } +// for _, s := range spec.Specs { +// t, ok := s.(*ast.TypeSpec) +// if !ok { +// continue +// } +// if !yield(file, t) { +// return +// } +// } +// } +// } +// } +//} + +func IterateFunctions(files []*ast.File) func(func(*ast.File, *ast.FuncDecl) bool) { + return func(yield func(*ast.File, *ast.FuncDecl) bool) { + for _, file := range files { + for _, decl := range file.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if !yield(file, fn) { + return + } + } + } + } +} + +//func IterateImports(files []*ast.File) func(func(*ast.File, *ast.ImportSpec) bool) { +// return func(yield func(*ast.File, *ast.ImportSpec) bool) { +// for _, file := range files { +// for _, decl := range file.Decls { +// genDecl, ok := decl.(*ast.GenDecl) +// if !ok || genDecl.Tok != token.IMPORT { +// continue +// } +// for _, s := range genDecl.Specs { +// if !yield(file, s.(*ast.ImportSpec)) { +// return +// } +// } +// } +// } +// } +//} + +func Format(node ast.Node) string { + var buf strings.Builder + if err := printer.Fprint(&buf, token.NewFileSet(), node); err != nil { + return fmt.Sprintf("formatting error: %v", err) + } + return buf.String() +} + +func evaluateStringLiteralExpressionList(wd string, set *token.FileSet, list []ast.Expr) ([]string, error) { + result := make([]string, 0, len(list)) + for _, a := range list { + s, err := evaluateStringLiteralExpression(wd, set, a) + if err != nil { + return result, err + } + result = append(result, s) + } + return result, nil +} + +func evaluateStringLiteralExpression(wd string, set *token.FileSet, exp ast.Expr) (string, error) { + arg, ok := exp.(*ast.BasicLit) + if !ok || arg.Kind != token.STRING { + return "", contextError(wd, set, exp.Pos(), fmt.Errorf("expected string literal got %s", Format(exp))) + } + return strconv.Unquote(arg.Value) +} + +func IterateFieldTypes(list []*ast.Field) func(func(int, ast.Expr) bool) { + return func(yield func(int, ast.Expr) bool) { + i := 0 + for _, field := range list { + if len(field.Names) == 0 { + if !yield(i, field.Type) { + return + } + i++ + } else { + for range field.Names { + if !yield(i, field.Type) { + return + } + i++ + } + } + } + } +} + +var httpCodes = map[int]string{ + http.StatusContinue: "StatusContinue", + http.StatusSwitchingProtocols: "StatusSwitchingProtocols", + http.StatusProcessing: "StatusProcessing", + http.StatusEarlyHints: "StatusEarlyHints", + + http.StatusOK: "StatusOK", + http.StatusCreated: "StatusCreated", + http.StatusAccepted: "StatusAccepted", + http.StatusNonAuthoritativeInfo: "StatusNonAuthoritativeInfo", + http.StatusNoContent: "StatusNoContent", + http.StatusResetContent: "StatusResetContent", + http.StatusPartialContent: "StatusPartialContent", + http.StatusMultiStatus: "StatusMultiStatus", + http.StatusAlreadyReported: "StatusAlreadyReported", + http.StatusIMUsed: "StatusIMUsed", + + http.StatusMultipleChoices: "StatusMultipleChoices", + http.StatusMovedPermanently: "StatusMovedPermanently", + http.StatusFound: "StatusFound", + http.StatusSeeOther: "StatusSeeOther", + http.StatusNotModified: "StatusNotModified", + http.StatusUseProxy: "StatusUseProxy", + http.StatusTemporaryRedirect: "StatusTemporaryRedirect", + http.StatusPermanentRedirect: "StatusPermanentRedirect", + + http.StatusBadRequest: "StatusBadRequest", + http.StatusUnauthorized: "StatusUnauthorized", + http.StatusPaymentRequired: "StatusPaymentRequired", + http.StatusForbidden: "StatusForbidden", + http.StatusNotFound: "StatusNotFound", + http.StatusMethodNotAllowed: "StatusMethodNotAllowed", + http.StatusNotAcceptable: "StatusNotAcceptable", + http.StatusProxyAuthRequired: "StatusProxyAuthRequired", + http.StatusRequestTimeout: "StatusRequestTimeout", + http.StatusConflict: "StatusConflict", + http.StatusGone: "StatusGone", + http.StatusLengthRequired: "StatusLengthRequired", + http.StatusPreconditionFailed: "StatusPreconditionFailed", + http.StatusRequestEntityTooLarge: "StatusRequestEntityTooLarge", + http.StatusRequestURITooLong: "StatusRequestURITooLong", + http.StatusUnsupportedMediaType: "StatusUnsupportedMediaType", + http.StatusRequestedRangeNotSatisfiable: "StatusRequestedRangeNotSatisfiable", + http.StatusExpectationFailed: "StatusExpectationFailed", + http.StatusTeapot: "StatusTeapot", + http.StatusMisdirectedRequest: "StatusMisdirectedRequest", + http.StatusUnprocessableEntity: "StatusUnprocessableEntity", + http.StatusLocked: "StatusLocked", + http.StatusFailedDependency: "StatusFailedDependency", + http.StatusTooEarly: "StatusTooEarly", + http.StatusUpgradeRequired: "StatusUpgradeRequired", + http.StatusPreconditionRequired: "StatusPreconditionRequired", + http.StatusTooManyRequests: "StatusTooManyRequests", + http.StatusRequestHeaderFieldsTooLarge: "StatusRequestHeaderFieldsTooLarge", + http.StatusUnavailableForLegalReasons: "StatusUnavailableForLegalReasons", + + http.StatusInternalServerError: "StatusInternalServerError", + http.StatusNotImplemented: "StatusNotImplemented", + http.StatusBadGateway: "StatusBadGateway", + http.StatusServiceUnavailable: "StatusServiceUnavailable", + http.StatusGatewayTimeout: "StatusGatewayTimeout", + http.StatusHTTPVersionNotSupported: "StatusHTTPVersionNotSupported", + http.StatusVariantAlsoNegotiates: "StatusVariantAlsoNegotiates", + http.StatusInsufficientStorage: "StatusInsufficientStorage", + http.StatusLoopDetected: "StatusLoopDetected", + http.StatusNotExtended: "StatusNotExtended", + http.StatusNetworkAuthenticationRequired: "StatusNetworkAuthenticationRequired", +} + +func HTTPStatusCode(pkg string, n int) ast.Expr { + ident, ok := httpCodes[n] + if !ok { + return &ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(n)} + } + return &ast.SelectorExpr{ + X: ast.NewIdent(pkg), + Sel: ast.NewIdent(ident), + } +} + +func Int(n int) *ast.BasicLit { return &ast.BasicLit{Value: strconv.Itoa(n), Kind: token.INT} } + +func ErrorCheckReturn(errVarIdent string, body ...ast.Stmt) *ast.IfStmt { + return &ast.IfStmt{ + Cond: &ast.BinaryExpr{X: ast.NewIdent(errVarIdent), Op: token.NEQ, Y: ast.NewIdent("nil")}, + Body: &ast.BlockStmt{List: body}, + } +} + +func FieldIndex(fields []*ast.Field, i int) (*ast.Ident, ast.Expr, bool) { + n := 0 + for _, field := range fields { + for _, name := range field.Names { + if n != i { + n++ + continue + } + return name, field.Type, true + } + } + return nil, nil, false +} diff --git a/internal/source/ast_test.go b/internal/source/go_test.go similarity index 100% rename from internal/source/ast_test.go rename to internal/source/go_test.go diff --git a/name.go b/name.go index 3437d38..a92d9b9 100644 --- a/name.go +++ b/name.go @@ -173,8 +173,7 @@ const ( TemplateNameScopeIdentifierHTTPRequest = "request" TemplateNameScopeIdentifierHTTPResponse = "response" TemplateNameScopeIdentifierContext = "ctx" - TemplateNameScopeIdentifierTemplate = "template" - TemplateNameScopeIdentifierLogger = "logger" + TemplateNameScopeIdentifierForm = "form" ) func patternScope() []string { @@ -182,7 +181,6 @@ func patternScope() []string { TemplateNameScopeIdentifierHTTPRequest, TemplateNameScopeIdentifierHTTPResponse, TemplateNameScopeIdentifierContext, - TemplateNameScopeIdentifierTemplate, - TemplateNameScopeIdentifierLogger, + TemplateNameScopeIdentifierForm, } }