Skip to content

Commit

Permalink
refactor: use only one config struct
Browse files Browse the repository at this point in the history
  • Loading branch information
crhntr committed Nov 30, 2024
1 parent 6c2e2a1 commit eae4bc5
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 93 deletions.
8 changes: 3 additions & 5 deletions cmd/generate-readme/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"text/template"

"github.com/crhntr/muxt"
"github.com/crhntr/muxt/internal/configuration"
)

Expand All @@ -17,11 +18,8 @@ var (
)

func main() {
var (
out bytes.Buffer
g configuration.Generate
)
gf := g.FlagSet()
var out bytes.Buffer
gf := configuration.RoutesFileConfigurationFlagSet(new(muxt.RoutesFileConfiguration))
gf.SetOutput(&out)
gf.Usage()
generateUsage := out.Bytes()
Expand Down
14 changes: 3 additions & 11 deletions cmd/muxt/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,18 @@ import (
const CodeGenerationComment = "// Code generated by muxt. DO NOT EDIT."

func generateCommand(args []string, workingDirectory string, getEnv func(string) string, stdout, stderr io.Writer) error {
g, err := configuration.NewGenerate(args, getEnv, stderr)
config, err := configuration.NewRoutesFileConfiguration(args, stderr)
if err != nil {
return err
}
s, err := muxt.TemplateRoutesFile(workingDirectory, log.New(stdout, "", 0), muxt.RoutesFileConfiguration{
Package: getEnv("GOPACKAGE"),
TemplatesVar: g.TemplatesVariable,
RoutesFunc: g.RoutesFunction,
ReceiverType: g.ReceiverIdent,
ReceiverPackage: g.ReceiverStaticTypePackage,
ReceiverInterface: g.ReceiverInterfaceIdent,
Output: g.OutputFilename,
})
s, err := muxt.TemplateRoutesFile(workingDirectory, log.New(stdout, "", 0), config)
if err != nil {
return err
}
var sb bytes.Buffer
writeCodeGenerationComment(&sb)
sb.WriteString(s)
return os.WriteFile(filepath.Join(workingDirectory, g.OutputFilename), sb.Bytes(), 0o644)
return os.WriteFile(filepath.Join(workingDirectory, config.OutputFileName), sb.Bytes(), 0o644)
}

func writeCodeGenerationComment(w io.StringWriter) {
Expand Down
48 changes: 16 additions & 32 deletions internal/configuration/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,54 +33,38 @@ This function also receives an argument with a type matching the name given by r
errIdentSuffix = " value must be a well-formed Go identifier"
)

type Generate struct {
GoFile string
GoLine string

TemplatesVariable string
OutputFilename string
RoutesFunction string
ReceiverIdent string
ReceiverStaticTypePackage string

ReceiverInterfaceIdent string
}

func NewGenerate(args []string, getEnv func(string) string, stderr io.Writer) (Generate, error) {
g := Generate{
GoFile: getEnv("GOFILE"),
GoLine: getEnv("GOLINE"),
}
flagSet := g.FlagSet()
func NewRoutesFileConfiguration(args []string, stderr io.Writer) (muxt.RoutesFileConfiguration, error) {
var g muxt.RoutesFileConfiguration
flagSet := RoutesFileConfigurationFlagSet(&g)
flagSet.SetOutput(stderr)
if err := flagSet.Parse(args); err != nil {
return g, err
}
if g.TemplatesVariable != "" && !token.IsIdentifier(g.TemplatesVariable) {
return Generate{}, fmt.Errorf(templatesVariable + errIdentSuffix)
return muxt.RoutesFileConfiguration{}, fmt.Errorf(templatesVariable + errIdentSuffix)
}
if g.RoutesFunction != "" && !token.IsIdentifier(g.RoutesFunction) {
return Generate{}, fmt.Errorf(routesFunc + errIdentSuffix)
return muxt.RoutesFileConfiguration{}, fmt.Errorf(routesFunc + errIdentSuffix)
}
if g.ReceiverIdent != "" && !token.IsIdentifier(g.ReceiverIdent) {
return Generate{}, fmt.Errorf(receiverStaticType + errIdentSuffix)
if g.ReceiverType != "" && !token.IsIdentifier(g.ReceiverType) {
return muxt.RoutesFileConfiguration{}, fmt.Errorf(receiverStaticType + errIdentSuffix)
}
if g.ReceiverInterfaceIdent != "" && !token.IsIdentifier(g.ReceiverInterfaceIdent) {
return Generate{}, fmt.Errorf(receiverInterfaceName + errIdentSuffix)
if g.ReceiverInterface != "" && !token.IsIdentifier(g.ReceiverInterface) {
return muxt.RoutesFileConfiguration{}, fmt.Errorf(receiverInterfaceName + errIdentSuffix)
}
if g.OutputFilename != "" && filepath.Ext(g.OutputFilename) != ".go" {
return Generate{}, fmt.Errorf("output filename must use .go extension")
if g.OutputFileName != "" && filepath.Ext(g.OutputFileName) != ".go" {
return muxt.RoutesFileConfiguration{}, fmt.Errorf("output filename must use .go extension")
}
return g, nil
}

func (g *Generate) FlagSet() *flag.FlagSet {
func RoutesFileConfigurationFlagSet(g *muxt.RoutesFileConfiguration) *flag.FlagSet {
flagSet := flag.NewFlagSet("generate", flag.ContinueOnError)
flagSet.StringVar(&g.OutputFilename, outputFlagName, muxt.DefaultOutputFileName, outputFlagNameHelp)
flagSet.StringVar(&g.OutputFileName, outputFlagName, muxt.DefaultOutputFileName, outputFlagNameHelp)
flagSet.StringVar(&g.TemplatesVariable, templatesVariable, muxt.DefaultTemplatesVariableName, templatesVariableHelp)
flagSet.StringVar(&g.RoutesFunction, routesFunc, muxt.DefaultRoutesFunctionName, routesFuncHelp)
flagSet.StringVar(&g.ReceiverIdent, receiverStaticType, "", receiverStaticTypeHelp)
flagSet.StringVar(&g.ReceiverStaticTypePackage, receiverStaticTypePackage, "", receiverStaticTypePackageHelp)
flagSet.StringVar(&g.ReceiverInterfaceIdent, receiverInterfaceName, muxt.DefaultReceiverInterfaceName, receiverInterfaceNameHelp)
flagSet.StringVar(&g.ReceiverType, receiverStaticType, "", receiverStaticTypeHelp)
flagSet.StringVar(&g.ReceiverPackage, receiverStaticTypePackage, "", receiverStaticTypePackageHelp)
flagSet.StringVar(&g.ReceiverInterface, receiverInterfaceName, muxt.DefaultReceiverInterfaceName, receiverInterfaceNameHelp)
return flagSet
}
20 changes: 10 additions & 10 deletions internal/configuration/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,33 @@ import (

func TestNewGenerate(t *testing.T) {
t.Run("unknown flag", func(t *testing.T) {
_, err := NewGenerate([]string{
_, err := NewRoutesFileConfiguration([]string{
"--unknown",
}, func(s string) string { return "" }, io.Discard)
}, io.Discard)
assert.ErrorContains(t, err, "flag provided but not defined")
})
t.Run(receiverStaticType+" flag value is an invalid identifier", func(t *testing.T) {
_, err := NewGenerate([]string{
_, err := NewRoutesFileConfiguration([]string{
"--" + receiverStaticType, "123",
}, func(s string) string { return "" }, io.Discard)
}, io.Discard)
assert.ErrorContains(t, err, errIdentSuffix)
})
t.Run(routesFunc+" flag value is an invalid identifier", func(t *testing.T) {
_, err := NewGenerate([]string{
_, err := NewRoutesFileConfiguration([]string{
"--" + routesFunc, "123",
}, func(s string) string { return "" }, io.Discard)
}, io.Discard)
assert.ErrorContains(t, err, errIdentSuffix)
})
t.Run(templatesVariable+" flag value is an invalid identifier", func(t *testing.T) {
_, err := NewGenerate([]string{
_, err := NewRoutesFileConfiguration([]string{
"--" + templatesVariable, "123",
}, func(s string) string { return "" }, io.Discard)
}, io.Discard)
assert.ErrorContains(t, err, errIdentSuffix)
})
t.Run(outputFlagName+" flag value is not a go file", func(t *testing.T) {
_, err := NewGenerate([]string{
_, err := NewRoutesFileConfiguration([]string{
"--" + outputFlagName, "output.txt",
}, func(s string) string { return "" }, io.Discard)
}, io.Discard)
assert.ErrorContains(t, err, "filename must use .go extension")
})
}
30 changes: 15 additions & 15 deletions routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,29 +55,29 @@ const (

type RoutesFileConfiguration struct {
executeFunc bool
Package,
PackageName,
PackagePath,
TemplatesVar,
RoutesFunc,
TemplatesVariable,
RoutesFunction,
ReceiverType,
ReceiverPackage,
ReceiverInterface,
Output string
OutputFileName string
}

func (config RoutesFileConfiguration) applyDefaults() RoutesFileConfiguration {
config.Package = cmp.Or(config.Package, defaultPackageName)
config.TemplatesVar = cmp.Or(config.TemplatesVar, DefaultTemplatesVariableName)
config.RoutesFunc = cmp.Or(config.RoutesFunc, DefaultRoutesFunctionName)
config.PackageName = cmp.Or(config.PackageName, defaultPackageName)
config.TemplatesVariable = cmp.Or(config.TemplatesVariable, DefaultTemplatesVariableName)
config.RoutesFunction = cmp.Or(config.RoutesFunction, DefaultRoutesFunctionName)
config.ReceiverInterface = cmp.Or(config.ReceiverInterface, DefaultReceiverInterfaceName)
config.executeFunc = true
return config
}

func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfiguration) (string, error) {
config = config.applyDefaults()
if !token.IsIdentifier(config.Package) {
return "", fmt.Errorf("package name %q is not an identifier", config.Package)
if !token.IsIdentifier(config.PackageName) {
return "", fmt.Errorf("package name %q is not an identifier", config.PackageName)
}
imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT})

Expand All @@ -104,7 +104,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
}
imports.SetOutputPackage(routesPkg.Types)
config.PackagePath = routesPkg.PkgPath
config.Package = routesPkg.Name
config.PackageName = routesPkg.Name
var receiver *types.Named
if config.ReceiverType != "" {
receiverPkgPath := cmp.Or(config.ReceiverPackage, config.PackagePath)
Expand All @@ -125,7 +125,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
receiver = types.NewNamed(types.NewTypeName(0, routesPkg.Types, "Receiver", nil), types.NewStruct(nil, nil), nil)
}

ts, err := source.Templates(wd, config.TemplatesVar, routesPkg)
ts, err := source.Templates(wd, config.TemplatesVariable, routesPkg)
if err != nil {
return "", err
}
Expand All @@ -138,7 +138,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
if p.Types.Path() == config.PackagePath {
if executeObj := p.Types.Scope().Lookup("execute"); executeObj != nil {
if _, ok := executeObj.(*types.Func); ok {
config.executeFunc = filepath.Base(p.Fset.Position(executeObj.Pos()).Filename) == config.Output
config.executeFunc = filepath.Base(p.Fset.Position(executeObj.Pos()).Filename) == config.OutputFileName
}
}
break
Expand All @@ -150,7 +150,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
}

routesFunc := &ast.FuncDecl{
Name: ast.NewIdent(config.RoutesFunc),
Name: ast.NewIdent(config.RoutesFunction),
Type: &ast.FuncType{
Params: &ast.FieldList{
List: []*ast.Field{
Expand Down Expand Up @@ -209,7 +209,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur

imports.SortImports()
file := &ast.File{
Name: ast.NewIdent(config.Package),
Name: ast.NewIdent(config.PackageName),
Decls: []ast.Decl{
// import
imports.GenDecl,
Expand All @@ -228,7 +228,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur
}

if config.executeFunc {
file.Decls = append(file.Decls, executeFuncDecl(imports, config.TemplatesVar))
file.Decls = append(file.Decls, executeFuncDecl(imports, config.TemplatesVariable))
}

return source.Format(file), nil
Expand Down
24 changes: 4 additions & 20 deletions routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1946,12 +1946,12 @@ var templates = template.Must(template.ParseFS(templatesDir, "template.gohtml"))
logger := log.New(io.Discard, "", 0)
out, err := muxt.TemplateRoutesFile(dir, logger, muxt.RoutesFileConfiguration{
ReceiverInterface: tt.Interface,
Package: tt.PackageName,
TemplatesVar: tt.TemplatesVar,
RoutesFunc: tt.RoutesFunc,
PackageName: tt.PackageName,
TemplatesVariable: tt.TemplatesVar,
RoutesFunction: tt.RoutesFunc,
PackagePath: "example.com",
ReceiverType: tt.Receiver,
Output: "template_routes.go",
OutputFileName: "template_routes.go",
})
if tt.ExpectedError == "" {
require.NoError(t, err)
Expand All @@ -1963,22 +1963,6 @@ var templates = template.Must(template.ParseFS(templatesDir, "template.gohtml"))
}
}

//func loadPackage(t *testing.T, in string) []*packages.Package {
// t.Helper()
// archive := txtar.Parse([]byte(in))
// archiveDir, err := txtar.FS(archive)
// require.NoError(t, err)
//
// dir := t.TempDir()
// require.NoError(t, os.CopyFS(dir, archiveDir))
// require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com\n"), 0644))
//
// packageList, err := source.Load(dir, "./...")
// require.NoError(t, err)
//
// return packageList
//}

const executeGo = `-- execute.go --
package main
Expand Down

0 comments on commit eae4bc5

Please sign in to comment.