Skip to content

Commit

Permalink
refactor: use types package instead of ast
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Nov 14, 2024
1 parent 6d44e72 commit 79990a3
Show file tree
Hide file tree
Showing 9 changed files with 945 additions and 744 deletions.
14 changes: 9 additions & 5 deletions cmd/muxt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,15 @@ func generateCommand(args []string, workingDirectory string, getEnv func(string)
if err != nil {
return err
}
out := log.New(stdout, "", 0)
s, err := muxt.Routes(templates, g.goPackage, g.templatesVariable, g.routesFunction, g.receiverIdent, g.receiverInterfaceIdent, g.outputFilename, packageList, out)
if err != nil {
return err
}
s, err := muxt.TemplateRoutesFile(workingDirectory, templates, log.New(stdout, "", 0), muxt.RoutesFileConfiguration{
Package: g.goPackage,
PackagePath: g.Package.PkgPath,
TemplatesVar: g.templatesVariable,
RoutesFunc: g.routesFunction,
ReceiverType: g.receiverIdent,
ReceiverInterface: g.receiverInterfaceIdent,
Output: g.outputFilename,
})
var sb bytes.Buffer
sb.WriteString(CodeGenerationComment)
if v, ok := cliVersion(); ok {
Expand Down
3 changes: 2 additions & 1 deletion internal/source/html.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"go/ast"
"go/token"
"go/types"
"regexp"
"slices"
"strings"
Expand All @@ -17,7 +18,7 @@ type ValidationGenerator interface {
GenerateValidation(imports *Imports, variable ast.Expr, handleError func(string) ast.Stmt) ast.Stmt
}

func ParseInputValidations(name string, input spec.Element, tp ast.Expr) ([]ValidationGenerator, error) {
func ParseInputValidations(name string, input spec.Element, tp types.Type) ([]ValidationGenerator, error) {
if tag := strings.ToLower(input.TagName()); tag != atom.Input.String() {
return nil, fmt.Errorf("expected element to have tag <input> got <%s>", tag)
}
Expand Down
122 changes: 103 additions & 19 deletions internal/source/imports.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package source

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"go/types"
"log"
"path"
"slices"
Expand All @@ -12,6 +15,10 @@ import (

type Imports struct {
*ast.GenDecl
fileSet *token.FileSet
types map[string]*types.Package
files map[string]*ast.File
outputPackage string
}

func NewImports(decl *ast.GenDecl) *Imports {
Expand All @@ -20,7 +27,82 @@ func NewImports(decl *ast.GenDecl) *Imports {
log.Panicf("expected decl to have token.IMPORT Tok got %s", got)
}
}
return &Imports{GenDecl: decl}
return &Imports{GenDecl: decl, types: make(map[string]*types.Package), files: make(map[string]*ast.File)}
}

func (imports *Imports) AddPackages(p *types.Package) {
recursivelyRegisterPackages(imports.types, p)
}

func (imports *Imports) FileSet() *token.FileSet {
if imports.fileSet == nil {
imports.fileSet = token.NewFileSet()
}
return imports.fileSet
}

func (imports *Imports) SetOutputPackage(pkgPath string) {
imports.outputPackage = pkgPath
}

func (imports *Imports) OutputPackage() string {
return imports.outputPackage
}

func (imports *Imports) SyntaxFile(pos token.Pos) (*ast.File, *token.FileSet, error) {
position := imports.FileSet().Position(pos)
fSet := token.NewFileSet()
file, err := parser.ParseFile(fSet, position.Filename, nil, parser.AllErrors|parser.ParseComments|parser.SkipObjectResolution)
return file, fSet, err
}

func (imports *Imports) FieldTag(pos token.Pos) (*ast.Field, error) {
file, fileSet, err := imports.SyntaxFile(pos)
if err != nil {
return nil, err
}
position := imports.fileSet.Position(pos)
for _, d := range file.Decls {
switch decl := d.(type) {
case *ast.GenDecl:
for _, s := range decl.Specs {
switch spec := s.(type) {
case *ast.TypeSpec:
tp, ok := spec.Type.(*ast.StructType)
if !ok {
continue
}

for _, field := range tp.Fields.List {
for _, name := range field.Names {
p := fileSet.Position(name.Pos())
if p != position {
continue
}
return field, nil
}
}
}
}
}

}
return nil, fmt.Errorf("failed to find field")
}

func (imports *Imports) Types(pkgPath string) (*types.Package, bool) {
p, ok := imports.types[pkgPath]
return p, ok
}

func recursivelyRegisterPackages(set map[string]*types.Package, pkg *types.Package) {
if pkg == nil {
return
}
set[pkg.Path()] = pkg
for _, p := range pkg.Imports() {
recursivelyRegisterPackages(set, p)
}
}

func (imports *Imports) Add(pkgIdent, pkgPath string) string {
Expand All @@ -31,27 +113,29 @@ func (imports *Imports) Add(pkgIdent, pkgPath string) string {
if pkgIdent == "" {
pkgIdent = path.Base(pkgPath)
}
for _, s := range imports.GenDecl.Specs {
spec := s.(*ast.ImportSpec)
pp, _ := strconv.Unquote(spec.Path.Value)
if pp == pkgPath {
if spec.Name != nil && spec.Name.Name != "" && spec.Name.Name != pkgIdent {
return spec.Name.Name
if pkgPath != imports.outputPackage {
for _, s := range imports.GenDecl.Specs {
spec := s.(*ast.ImportSpec)
pp, _ := strconv.Unquote(spec.Path.Value)
if pp == pkgPath {
if spec.Name != nil && spec.Name.Name != "" && spec.Name.Name != pkgIdent {
return spec.Name.Name
}
return path.Base(pp)
}
return path.Base(pp)
}
var pi *ast.Ident
if path.Base(pkgPath) != pkgIdent {
pi = Ident(pkgIdent)
}
imports.GenDecl.Specs = append(imports.GenDecl.Specs, &ast.ImportSpec{
Path: String(pkgPath),
Name: pi,
})
slices.SortFunc(imports.GenDecl.Specs, func(a, b ast.Spec) int {
return strings.Compare(a.(*ast.ImportSpec).Path.Value, b.(*ast.ImportSpec).Path.Value)
})
}
var pi *ast.Ident
if path.Base(pkgPath) != pkgIdent {
pi = Ident(pkgIdent)
}
imports.GenDecl.Specs = append(imports.GenDecl.Specs, &ast.ImportSpec{
Path: String(pkgPath),
Name: pi,
})
slices.SortFunc(imports.GenDecl.Specs, func(a, b ast.Spec) int {
return strings.Compare(a.(*ast.ImportSpec).Path.Value, b.(*ast.ImportSpec).Path.Value)
})
return pkgIdent
}

Expand Down
3 changes: 2 additions & 1 deletion internal/source/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"go/ast"
"go/token"
"go/types"
"net/http"
"regexp"
"slices"
Expand Down Expand Up @@ -71,7 +72,7 @@ func GenerateParseValueFromStringStatements(imports *Imports, tmp string, str, t
return nil, fmt.Errorf("unsupported type: %s", Format(typeExp))
}

func GenerateValidations(imports *Imports, variable, variableType ast.Expr, inputQuery, inputName, responseIdent string, fragment spec.DocumentFragment) ([]ast.Stmt, error, bool) {
func GenerateValidations(imports *Imports, variable ast.Expr, variableType types.Type, inputQuery, inputName, responseIdent string, fragment spec.DocumentFragment) ([]ast.Stmt, error, bool) {
input := fragment.QuerySelector(inputQuery)
if input == nil {
return nil, nil, false
Expand Down
Loading

0 comments on commit 79990a3

Please sign in to comment.