diff --git a/cmd/muxt/testdata/generate/path_param_typed.txtar b/cmd/muxt/testdata/generate/path_param_typed.txtar new file mode 100644 index 0000000..9a97fab --- /dev/null +++ b/cmd/muxt/testdata/generate/path_param_typed.txtar @@ -0,0 +1,168 @@ +muxt generate --receiver-static-type=T + +cat template_routes.go + +exec go test -cover + +-- template.gohtml -- + +{{- define "GET /bool/{value} PassBool(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int/{value} PassInt(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int16/{value} PassInt16(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int32/{value} PassInt32(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int64/{value} PassInt64(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int8/{value} PassInt8(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint/{value} PassUint(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint16/{value} PassUint16(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint32/{value} PassUint32(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint64/{value} PassUint64(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint8/{value} PassUint8(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} + +-- go.mod -- +module server + +go 1.22 + +-- template.go -- +package server + +import ( + "embed" + "html/template" +) + +//go:embed *.gohtml +var formHTML embed.FS + +var templates = template.Must(template.ParseFS(formHTML, "*")) + +type T struct{} + +func (T) PassInt(in int) int { return in } +func (T) PassInt64(in int64) int64 { return in } +func (T) PassInt32(in int32) int32 { return in } +func (T) PassInt16(in int16) int16 { return in } +func (T) PassInt8(in int8) int8 { return in } +func (T) PassUint(in uint) uint { return in } +func (T) PassUint64(in uint64) uint64 { return in } +func (T) PassUint32(in uint32) uint32 { return in } +func (T) PassUint16(in uint16) uint16 { return in } +func (T) PassUint8(in uint8) uint8 { return in } +func (T) PassBool(in bool) bool { return in } +func (T) PassByte(in byte) byte { return in } +func (T) PassRune(in rune) rune { return in } + +-- template_test.go -- +package server + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func Test(t *testing.T) { + mux := http.NewServeMux() + + routes(mux, T{}) + + t.Run("int", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/int/123", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("int64", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/int64/52", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("int32", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/int32/51", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("int16", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/int16/50", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("int8", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/int8/50", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("uint", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/uint/12", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("uint64", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/uint64/11", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("uint32", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/uint32/11", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("uint16", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/uint16/7", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("uint8", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/uint8/5", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) + t.Run("bool", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/bool/true", nil) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != http.StatusOK { + t.Error("expected OK") + } + }) +} diff --git a/generate.go b/generate.go index a74c35e..7be8357 100644 --- a/generate.go +++ b/generate.go @@ -29,6 +29,7 @@ const ( httpRequestIdent = "Request" httpStatusCode200Ident = "StatusOK" httpStatusCode500Ident = "StatusInternalServerError" + httpStatusCode400Ident = "StatusBadRequest" httpHandleFuncIdent = "HandleFunc" contextPackageIdent = "context" @@ -69,6 +70,7 @@ func Generate(templateNames []TemplateName, packageName, templatesVariableName, continue } method = funcDecl.Type + break } if method == nil { me, im := pattern.funcType() @@ -143,13 +145,13 @@ func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncT return nil, nil, errWrongNumberOfArguments(def, method) } for pi, pt := range fieldListTypes(method.Params) { - if err := checkArgument(def.call.Args[pi], pt); err != nil { + if err := checkArgument(method, pi, def.call.Args[pi], pt); err != nil { return nil, nil, err } } } var imports []*ast.ImportSpec - for _, a := range def.call.Args { + for i, a := range def.call.Args { arg := a.(*ast.Ident) switch arg.Name { case TemplateNameScopeIdentifierHTTPRequest, TemplateNameScopeIdentifierHTTPResponse: @@ -160,8 +162,13 @@ func (def TemplateName) funcLit(templatesVariableIdent string, method *ast.FuncT call.Args = append(call.Args, ast.NewIdent(TemplateNameScopeIdentifierContext)) imports = append(imports, importSpec("context")) default: - lit.Body.List = append(lit.Body.List, httpPathValueAssignment(arg)) + statements, parseImports, err := httpPathValueAssignment(method, i, arg) + if err != nil { + return nil, nil, err + } + lit.Body.List = append(lit.Body.List, statements...) call.Args = append(call.Args, ast.NewIdent(arg.Name)) + imports = append(imports, parseImports...) } } @@ -258,28 +265,35 @@ func errWrongNumberOfArguments(def TemplateName, method *ast.FuncType) error { return fmt.Errorf("handler %s expects %d arguments but call %s has %d", source.Format(&ast.FuncDecl{Name: ast.NewIdent(def.fun.Name), Type: method}), method.Params.NumFields(), def.handler, len(def.call.Args)) } -func checkArgument(exp ast.Expr, tp ast.Expr) error { +func checkArgument(method *ast.FuncType, argIndex int, exp ast.Expr, argType ast.Expr) error { arg := exp.(*ast.Ident) switch arg.Name { case TemplateNameScopeIdentifierHTTPRequest: - if !matchSelectorIdents(tp, httpPackageIdent, httpRequestIdent, true) { - return fmt.Errorf("method expects type %s but %s is *%s.%s", source.Format(tp), arg.Name, httpPackageIdent, httpRequestIdent) + if !matchSelectorIdents(argType, httpPackageIdent, httpRequestIdent, true) { + return fmt.Errorf("method expects type %s but %s is *%s.%s", source.Format(argType), arg.Name, httpPackageIdent, httpRequestIdent) } return nil case TemplateNameScopeIdentifierHTTPResponse: - if !matchSelectorIdents(tp, httpPackageIdent, httpResponseWriterIdent, false) { - return fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(tp), arg.Name, httpPackageIdent, httpResponseWriterIdent) + if !matchSelectorIdents(argType, httpPackageIdent, httpResponseWriterIdent, false) { + return fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(argType), arg.Name, httpPackageIdent, httpResponseWriterIdent) } return nil case TemplateNameScopeIdentifierContext: - if !matchSelectorIdents(tp, contextPackageIdent, contextContextTypeIdent, false) { - return fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(tp), arg.Name, contextPackageIdent, contextContextTypeIdent) + if !matchSelectorIdents(argType, contextPackageIdent, contextContextTypeIdent, false) { + return fmt.Errorf("method expects type %s but %s is %s.%s", source.Format(argType), arg.Name, contextPackageIdent, contextContextTypeIdent) } return nil default: - ident, ok := tp.(*ast.Ident) - if !ok || ident.Name != stringTypeIdent { - return fmt.Errorf("method expects type %s but %s is a string", source.Format(tp), arg.Name) + for paramIndex, paramType := range source.IterateFieldTypes(method.Params.List) { + if argIndex != paramIndex { + continue + } + paramTypeIdent, paramOk := paramType.(*ast.Ident) + argTypeIdent, argOk := argType.(*ast.Ident) + if !argOk || !paramOk || argTypeIdent.Name != paramTypeIdent.Name { + return fmt.Errorf("method expects type %s but %s is a %s", source.Format(argType), arg.Name, paramTypeIdent.Name) + } + break } return nil } @@ -379,22 +393,448 @@ func contextAssignment() *ast.AssignStmt { } } -func httpPathValueAssignment(arg *ast.Ident) *ast.AssignStmt { - return &ast.AssignStmt{ - Tok: token.DEFINE, - Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, - Rhs: []ast.Expr{&ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), - Sel: ast.NewIdent(requestPathValue), - }, - Args: []ast.Expr{ - &ast.BasicLit{ - Kind: token.STRING, - Value: strconv.Quote(arg.Name), - }, +func httpPathValueAssignment(method *ast.FuncType, i int, arg *ast.Ident) ([]ast.Stmt, []*ast.ImportSpec, error) { + const parsedVarSuffix = "Parsed" + for typeIndex, typeExp := range source.IterateFieldTypes(method.Params.List) { + if typeIndex != i { + continue + } + paramTypeIdent, ok := typeExp.(*ast.Ident) + if !ok { + return nil, nil, fmt.Errorf("unsupported type: %s", source.Format(typeExp)) + } + switch paramTypeIdent.Name { + default: + return nil, nil, fmt.Errorf("method param type %s not supported", source.Format(typeExp)) + case "bool": + errVar := ast.NewIdent("err") + + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseBool"), + }, + Args: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }}, + }}, + } + + errCheck := paramParseError(errVar) + + return []ast.Stmt{assign, errCheck}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int": + errVar := ast.NewIdent("err") + + tmp := arg.Name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "64", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }}, + } + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int16": + errVar := ast.NewIdent("err") + + tmp := arg.Name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "16", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }}, + } + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int32": + errVar := ast.NewIdent("err") + + tmp := arg.Name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "32", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }}, + } + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int8": + errVar := ast.NewIdent("err") + + tmp := arg.Name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "8", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }}, + } + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "int64": + errVar := ast.NewIdent("err") + + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseInt"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "64", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + + return []ast.Stmt{assign, errCheck}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint": + errVar := ast.NewIdent("err") + + tmp := arg.Name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "64", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }}, + } + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint16": + errVar := ast.NewIdent("err") + + tmp := arg.Name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "16", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }}, + } + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint32": + errVar := ast.NewIdent("err") + + tmp := arg.Name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "32", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }}, + } + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint64": + + errVar := ast.NewIdent("err") + + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "64", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + + return []ast.Stmt{assign, errCheck}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "uint8": + errVar := ast.NewIdent("err") + + tmp := arg.Name + parsedVarSuffix + + parse := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(tmp), ast.NewIdent(errVar.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("strconv"), + Sel: ast.NewIdent("ParseUint"), + }, + Args: []ast.Expr{ + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }, + &ast.BasicLit{Value: "10", Kind: token.INT}, + &ast.BasicLit{Value: "8", Kind: token.INT}, + }, + }}, + } + + errCheck := paramParseError(errVar) + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: ast.NewIdent(paramTypeIdent.Name), + Args: []ast.Expr{ast.NewIdent(tmp)}, + }}, + } + + return []ast.Stmt{parse, errCheck, assign}, []*ast.ImportSpec{importSpec("strconv")}, nil + case "string": + assign := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(arg.Name)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(TemplateNameScopeIdentifierHTTPRequest), + Sel: ast.NewIdent(requestPathValue), + }, + Args: []ast.Expr{&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(arg.Name)}}, + }}, + } + + return []ast.Stmt{assign}, nil, nil + } + } + return nil, nil, fmt.Errorf("type for argumement %d not found", i) +} + +func paramParseError(errVar *ast.Ident) *ast.IfStmt { + return &ast.IfStmt{ + Cond: &ast.BinaryExpr{X: ast.NewIdent(errVar.Name), Op: token.NEQ, Y: ast.NewIdent("nil")}, + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ExprStmt{X: &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent(httpPackageIdent), + Sel: ast.NewIdent("Error"), + }, + Args: []ast.Expr{ + ast.NewIdent(httpResponseField().Names[0].Name), + &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: ast.NewIdent("err"), + Sel: ast.NewIdent("Error"), + }, + Args: []ast.Expr{}, + }, + httpStatusCode(httpStatusCode400Ident), + }, + }}, + &ast.ReturnStmt{}, }, - }}, + }, } } diff --git a/generate_internal_test.go b/generate_internal_test.go index 579b3bb..c7b1675 100644 --- a/generate_internal_test.go +++ b/generate_internal_test.go @@ -178,11 +178,11 @@ func TestTemplateName_HandlerFuncLit_err(t *testing.T) { In: "GET /{name} F(name)", Method: &ast.FuncType{ Params: &ast.FieldList{List: []*ast.Field{ - {Type: ast.NewIdent("int")}, + {Type: ast.NewIdent("float64")}, }}, Results: &ast.FieldList{List: []*ast.Field{{Type: ast.NewIdent("any")}}}, }, - ErrSub: "method expects type int but name is a string", + ErrSub: "method param type float64 not supported", }, { Name: "wrong argument type request ptr", diff --git a/generate_test.go b/generate_test.go index c79aa72..a9524d5 100644 --- a/generate_test.go +++ b/generate_test.go @@ -364,6 +364,196 @@ func execute(response http.ResponseWriter, request *http.Request, t *template.Te response.WriteHeader(code) _, _ = buf.WriteTo(response) } +`, + }, + { + Name: "when using param parsers", + Receiver: "T", + Templates: ` +{{- define "GET /bool/{value} PassBool(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int/{value} PassInt(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int16/{value} PassInt16(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int32/{value} PassInt32(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int64/{value} PassInt64(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /int8/{value} PassInt8(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint/{value} PassUint(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint16/{value} PassUint16(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint32/{value} PassUint32(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint64/{value} PassUint64(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +{{- define "GET /uint8/{value} PassUint8(value)" -}}

{{- printf "%[1]#v %[1]T" . -}}

{{- end -}} +`, + ReceiverPackage: ` +-- t.go -- +package main + +import ( + "embed" + "html/template" +) + +//go:embed *.gohtml +var formHTML embed.FS + +var templates = template.Must(template.ParseFS(formHTML, "*")) + +type T struct{} + +func (T) PassInt(in int) int { return in } +func (T) PassInt64(in int64) int64 { return in } +func (T) PassInt32(in int32) int32 { return in } +func (T) PassInt16(in int16) int16 { return in } +func (T) PassInt8(in int8) int8 { return in } +func (T) PassUint(in uint) uint { return in } +func (T) PassUint64(in uint64) uint64 { return in } +func (T) PassUint16(in uint16) uint16 { return in } +func (T) PassUint32(in uint32) uint32 { return in } +func (T) PassUint64(in uint16) uint16 { return in } +func (T) PassUint8(in uint8) uint8 { return in } +func (T) PassBool(in bool) bool { return in } +func (T) PassByte(in byte) byte { return in } +func (T) PassRune(in rune) rune { return in } +`, + ExpectedFile: `package main + +import ( + "net/http" + "strconv" + "bytes" + "html/template" +) + +type RoutesReceiver interface { + PassBool(in bool) bool + PassInt(in int) int + PassInt16(in int16) int16 + PassInt32(in int32) int32 + PassInt64(in int64) int64 + PassInt8(in int8) int8 + PassUint(in uint) uint + PassUint16(in uint16) uint16 + PassUint32(in uint32) uint32 + PassUint64(in uint64) uint64 + PassUint8(in uint8) uint8 +} + +func routes(mux *http.ServeMux, receiver RoutesReceiver) { + mux.HandleFunc("GET /bool/{value}", func(response http.ResponseWriter, request *http.Request) { + value, err := strconv.ParseBool(request.PathValue("value")) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + data := receiver.PassBool(value) + execute(response, request, templates.Lookup("GET /bool/{value} PassBool(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /int/{value}", func(response http.ResponseWriter, request *http.Request) { + valueParsed, err := strconv.ParseInt(request.PathValue("value"), 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + value := int(valueParsed) + data := receiver.PassInt(value) + execute(response, request, templates.Lookup("GET /int/{value} PassInt(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /int16/{value}", func(response http.ResponseWriter, request *http.Request) { + valueParsed, err := strconv.ParseInt(request.PathValue("value"), 10, 16) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + value := int16(valueParsed) + data := receiver.PassInt16(value) + execute(response, request, templates.Lookup("GET /int16/{value} PassInt16(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /int32/{value}", func(response http.ResponseWriter, request *http.Request) { + valueParsed, err := strconv.ParseInt(request.PathValue("value"), 10, 32) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + value := int32(valueParsed) + data := receiver.PassInt32(value) + execute(response, request, templates.Lookup("GET /int32/{value} PassInt32(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /int64/{value}", func(response http.ResponseWriter, request *http.Request) { + value, err := strconv.ParseInt(request.PathValue("value"), 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + data := receiver.PassInt64(value) + execute(response, request, templates.Lookup("GET /int64/{value} PassInt64(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /int8/{value}", func(response http.ResponseWriter, request *http.Request) { + valueParsed, err := strconv.ParseInt(request.PathValue("value"), 10, 8) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + value := int8(valueParsed) + data := receiver.PassInt8(value) + execute(response, request, templates.Lookup("GET /int8/{value} PassInt8(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /uint/{value}", func(response http.ResponseWriter, request *http.Request) { + valueParsed, err := strconv.ParseUint(request.PathValue("value"), 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + value := uint(valueParsed) + data := receiver.PassUint(value) + execute(response, request, templates.Lookup("GET /uint/{value} PassUint(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /uint16/{value}", func(response http.ResponseWriter, request *http.Request) { + valueParsed, err := strconv.ParseUint(request.PathValue("value"), 10, 16) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + value := uint16(valueParsed) + data := receiver.PassUint16(value) + execute(response, request, templates.Lookup("GET /uint16/{value} PassUint16(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /uint32/{value}", func(response http.ResponseWriter, request *http.Request) { + valueParsed, err := strconv.ParseUint(request.PathValue("value"), 10, 32) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + value := uint32(valueParsed) + data := receiver.PassUint32(value) + execute(response, request, templates.Lookup("GET /uint32/{value} PassUint32(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /uint64/{value}", func(response http.ResponseWriter, request *http.Request) { + value, err := strconv.ParseUint(request.PathValue("value"), 10, 64) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + data := receiver.PassUint64(value) + execute(response, request, templates.Lookup("GET /uint64/{value} PassUint64(value)"), http.StatusOK, data) + }) + mux.HandleFunc("GET /uint8/{value}", func(response http.ResponseWriter, request *http.Request) { + valueParsed, err := strconv.ParseUint(request.PathValue("value"), 10, 8) + if err != nil { + http.Error(response, err.Error(), http.StatusBadRequest) + return + } + value := uint8(valueParsed) + data := receiver.PassUint8(value) + execute(response, request, templates.Lookup("GET /uint8/{value} PassUint8(value)"), http.StatusOK, data) + }) +} +func execute(response http.ResponseWriter, request *http.Request, t *template.Template, code int, data any) { + buf := bytes.NewBuffer(nil) + if err := t.Execute(buf, data); err != nil { + http.Error(response, err.Error(), http.StatusInternalServerError) + return + } + response.WriteHeader(code) + _, _ = buf.WriteTo(response) +} `, }, } { diff --git a/internal/source/ast.go b/internal/source/ast.go index 0596510..d1e5e3c 100644 --- a/internal/source/ast.go +++ b/internal/source/ast.go @@ -120,3 +120,23 @@ func evaluateStringLiteralExpression(wd string, set *token.FileSet, exp ast.Expr } return strconv.Unquote(arg.Value) } + +func IterateFieldTypes(list []*ast.Field) func(func(int, ast.Expr) bool) { + return func(yield func(int, ast.Expr) bool) { + i := 0 + for _, field := range list { + if len(field.Names) == 0 { + if !yield(i, field.Type) { + return + } + } else { + for range field.Names { + if !yield(i, field.Type) { + return + } + } + } + i++ + } + } +}