Skip to content

Commit

Permalink
feat: support encoding.TestUnmarshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Dec 27, 2024
1 parent 77b47f0 commit 20da3f5
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
56 changes: 55 additions & 1 deletion routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT})

patterns := []string{
wd, "net/http",
wd, "net/http", "encoding",
}

if config.ReceiverPackage != "" {
Expand Down Expand Up @@ -571,6 +571,60 @@ func generateParseValueFromStringStatements(imports *source.Imports, tmp string,
if tp.Obj().Pkg().Path() == "time" && tp.Obj().Name() == "Time" {
return parseBlock(tmp, imports.TimeParseCall(time.DateOnly, str), validations, errCheck, assignment), nil
}
if encPkg, ok := imports.Types("encoding"); ok {
if textUnmarshaler := encPkg.Scope().Lookup("TextUnmarshaler").Type().Underlying().(*types.Interface); types.Implements(types.NewPointer(tp), textUnmarshaler) {
tp, _ := astTypeExpression(imports, valueType)
return []ast.Stmt{
&ast.DeclStmt{
Decl: &ast.GenDecl{
Tok: token.VAR,
Specs: []ast.Spec{
&ast.ValueSpec{
Names: []*ast.Ident{ast.NewIdent(tmp)},
Type: tp,
},
},
},
},
&ast.IfStmt{
Init: &ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(errIdent)},
Tok: token.DEFINE,
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(tmp),
Sel: ast.NewIdent("UnmarshalText"),
},
Args: []ast.Expr{&ast.CallExpr{
Fun: &ast.ArrayType{
Elt: ast.NewIdent("byte"),
},
Args: []ast.Expr{str},
}},
}},
},
Cond: &ast.BinaryExpr{
X: ast.NewIdent(errIdent),
Op: token.NEQ,
Y: ast.NewIdent("nil"),
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
errCheck(&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent(errIdent),
Sel: ast.NewIdent("Error"),
},
Args: []ast.Expr{},
}),
new(ast.ReturnStmt),
},
},
},
assignment(ast.NewIdent(tmp)),
}, nil
}
}
}
tp, _ := astTypeExpression(imports, valueType)
return nil, fmt.Errorf("unsupported type: %s", source.Format(tp))
Expand Down
44 changes: 44 additions & 0 deletions routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1927,6 +1927,50 @@ func routes(mux *http.ServeMux, receiver RoutesReceiver) {
execute(response, request, true, "GET / F()", http.StatusOK, data)
})
}
`,
},
{
Name: "use text encoding",
Templates: `{{define "GET /{id} F(id)"}}{{.}}{{end}}`,
Receiver: "Server",
ReceiverPackage: `-- f.go --
package main
type ID int
func (id *ID) UnmarshalText(text []byte) error {
n, err := strconv.ParseUint(string(text), 2, 64)
if err != nil {
return err
}
*id = n
return nil
}
type Server struct{}
func (Server) F(id ID) int { return int(id) }
` + executeGo,
ExpectedFile: `package main
import "net/http"
type RoutesReceiver interface {
F(id ID) int
}
func routes(mux *http.ServeMux, receiver RoutesReceiver) {
mux.HandleFunc("GET /{id}", func(response http.ResponseWriter, request *http.Request) {
var idParsed ID
if err := idParsed.UnmarshalText([]byte(request.PathValue("id"))); err != nil {
http.Error(response, err.Error(), http.StatusBadRequest)
return
}
id := idParsed
data := receiver.F(id)
execute(response, request, true, "GET /{id} F(id)", http.StatusOK, data)
})
}
`,
},
} {
Expand Down

0 comments on commit 20da3f5

Please sign in to comment.