From eae4bc56d9d067ad1cfa96cdefeabfe93a1d9442 Mon Sep 17 00:00:00 2001 From: Chrstopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Fri, 29 Nov 2024 20:08:36 -0800 Subject: [PATCH] refactor: use only one config struct --- cmd/generate-readme/main.go | 8 ++--- cmd/muxt/generate.go | 14 ++------ internal/configuration/generate.go | 48 +++++++++---------------- internal/configuration/generate_test.go | 20 +++++------ routes.go | 30 ++++++++-------- routes_test.go | 24 +++---------- 6 files changed, 51 insertions(+), 93 deletions(-) diff --git a/cmd/generate-readme/main.go b/cmd/generate-readme/main.go index 42f0049..b771559 100644 --- a/cmd/generate-readme/main.go +++ b/cmd/generate-readme/main.go @@ -7,6 +7,7 @@ import ( "os" "text/template" + "github.com/crhntr/muxt" "github.com/crhntr/muxt/internal/configuration" ) @@ -17,11 +18,8 @@ var ( ) func main() { - var ( - out bytes.Buffer - g configuration.Generate - ) - gf := g.FlagSet() + var out bytes.Buffer + gf := configuration.RoutesFileConfigurationFlagSet(new(muxt.RoutesFileConfiguration)) gf.SetOutput(&out) gf.Usage() generateUsage := out.Bytes() diff --git a/cmd/muxt/generate.go b/cmd/muxt/generate.go index e580b20..08ac09a 100644 --- a/cmd/muxt/generate.go +++ b/cmd/muxt/generate.go @@ -14,26 +14,18 @@ import ( const CodeGenerationComment = "// Code generated by muxt. DO NOT EDIT." func generateCommand(args []string, workingDirectory string, getEnv func(string) string, stdout, stderr io.Writer) error { - g, err := configuration.NewGenerate(args, getEnv, stderr) + config, err := configuration.NewRoutesFileConfiguration(args, stderr) if err != nil { return err } - s, err := muxt.TemplateRoutesFile(workingDirectory, log.New(stdout, "", 0), muxt.RoutesFileConfiguration{ - Package: getEnv("GOPACKAGE"), - TemplatesVar: g.TemplatesVariable, - RoutesFunc: g.RoutesFunction, - ReceiverType: g.ReceiverIdent, - ReceiverPackage: g.ReceiverStaticTypePackage, - ReceiverInterface: g.ReceiverInterfaceIdent, - Output: g.OutputFilename, - }) + s, err := muxt.TemplateRoutesFile(workingDirectory, log.New(stdout, "", 0), config) if err != nil { return err } var sb bytes.Buffer writeCodeGenerationComment(&sb) sb.WriteString(s) - return os.WriteFile(filepath.Join(workingDirectory, g.OutputFilename), sb.Bytes(), 0o644) + return os.WriteFile(filepath.Join(workingDirectory, config.OutputFileName), sb.Bytes(), 0o644) } func writeCodeGenerationComment(w io.StringWriter) { diff --git a/internal/configuration/generate.go b/internal/configuration/generate.go index a982164..1d94225 100644 --- a/internal/configuration/generate.go +++ b/internal/configuration/generate.go @@ -33,54 +33,38 @@ This function also receives an argument with a type matching the name given by r errIdentSuffix = " value must be a well-formed Go identifier" ) -type Generate struct { - GoFile string - GoLine string - - TemplatesVariable string - OutputFilename string - RoutesFunction string - ReceiverIdent string - ReceiverStaticTypePackage string - - ReceiverInterfaceIdent string -} - -func NewGenerate(args []string, getEnv func(string) string, stderr io.Writer) (Generate, error) { - g := Generate{ - GoFile: getEnv("GOFILE"), - GoLine: getEnv("GOLINE"), - } - flagSet := g.FlagSet() +func NewRoutesFileConfiguration(args []string, stderr io.Writer) (muxt.RoutesFileConfiguration, error) { + var g muxt.RoutesFileConfiguration + flagSet := RoutesFileConfigurationFlagSet(&g) flagSet.SetOutput(stderr) if err := flagSet.Parse(args); err != nil { return g, err } if g.TemplatesVariable != "" && !token.IsIdentifier(g.TemplatesVariable) { - return Generate{}, fmt.Errorf(templatesVariable + errIdentSuffix) + return muxt.RoutesFileConfiguration{}, fmt.Errorf(templatesVariable + errIdentSuffix) } if g.RoutesFunction != "" && !token.IsIdentifier(g.RoutesFunction) { - return Generate{}, fmt.Errorf(routesFunc + errIdentSuffix) + return muxt.RoutesFileConfiguration{}, fmt.Errorf(routesFunc + errIdentSuffix) } - if g.ReceiverIdent != "" && !token.IsIdentifier(g.ReceiverIdent) { - return Generate{}, fmt.Errorf(receiverStaticType + errIdentSuffix) + if g.ReceiverType != "" && !token.IsIdentifier(g.ReceiverType) { + return muxt.RoutesFileConfiguration{}, fmt.Errorf(receiverStaticType + errIdentSuffix) } - if g.ReceiverInterfaceIdent != "" && !token.IsIdentifier(g.ReceiverInterfaceIdent) { - return Generate{}, fmt.Errorf(receiverInterfaceName + errIdentSuffix) + if g.ReceiverInterface != "" && !token.IsIdentifier(g.ReceiverInterface) { + return muxt.RoutesFileConfiguration{}, fmt.Errorf(receiverInterfaceName + errIdentSuffix) } - if g.OutputFilename != "" && filepath.Ext(g.OutputFilename) != ".go" { - return Generate{}, fmt.Errorf("output filename must use .go extension") + if g.OutputFileName != "" && filepath.Ext(g.OutputFileName) != ".go" { + return muxt.RoutesFileConfiguration{}, fmt.Errorf("output filename must use .go extension") } return g, nil } -func (g *Generate) FlagSet() *flag.FlagSet { +func RoutesFileConfigurationFlagSet(g *muxt.RoutesFileConfiguration) *flag.FlagSet { flagSet := flag.NewFlagSet("generate", flag.ContinueOnError) - flagSet.StringVar(&g.OutputFilename, outputFlagName, muxt.DefaultOutputFileName, outputFlagNameHelp) + flagSet.StringVar(&g.OutputFileName, outputFlagName, muxt.DefaultOutputFileName, outputFlagNameHelp) flagSet.StringVar(&g.TemplatesVariable, templatesVariable, muxt.DefaultTemplatesVariableName, templatesVariableHelp) flagSet.StringVar(&g.RoutesFunction, routesFunc, muxt.DefaultRoutesFunctionName, routesFuncHelp) - flagSet.StringVar(&g.ReceiverIdent, receiverStaticType, "", receiverStaticTypeHelp) - flagSet.StringVar(&g.ReceiverStaticTypePackage, receiverStaticTypePackage, "", receiverStaticTypePackageHelp) - flagSet.StringVar(&g.ReceiverInterfaceIdent, receiverInterfaceName, muxt.DefaultReceiverInterfaceName, receiverInterfaceNameHelp) + flagSet.StringVar(&g.ReceiverType, receiverStaticType, "", receiverStaticTypeHelp) + flagSet.StringVar(&g.ReceiverPackage, receiverStaticTypePackage, "", receiverStaticTypePackageHelp) + flagSet.StringVar(&g.ReceiverInterface, receiverInterfaceName, muxt.DefaultReceiverInterfaceName, receiverInterfaceNameHelp) return flagSet } diff --git a/internal/configuration/generate_test.go b/internal/configuration/generate_test.go index 94a07b4..1624866 100644 --- a/internal/configuration/generate_test.go +++ b/internal/configuration/generate_test.go @@ -9,33 +9,33 @@ import ( func TestNewGenerate(t *testing.T) { t.Run("unknown flag", func(t *testing.T) { - _, err := NewGenerate([]string{ + _, err := NewRoutesFileConfiguration([]string{ "--unknown", - }, func(s string) string { return "" }, io.Discard) + }, io.Discard) assert.ErrorContains(t, err, "flag provided but not defined") }) t.Run(receiverStaticType+" flag value is an invalid identifier", func(t *testing.T) { - _, err := NewGenerate([]string{ + _, err := NewRoutesFileConfiguration([]string{ "--" + receiverStaticType, "123", - }, func(s string) string { return "" }, io.Discard) + }, io.Discard) assert.ErrorContains(t, err, errIdentSuffix) }) t.Run(routesFunc+" flag value is an invalid identifier", func(t *testing.T) { - _, err := NewGenerate([]string{ + _, err := NewRoutesFileConfiguration([]string{ "--" + routesFunc, "123", - }, func(s string) string { return "" }, io.Discard) + }, io.Discard) assert.ErrorContains(t, err, errIdentSuffix) }) t.Run(templatesVariable+" flag value is an invalid identifier", func(t *testing.T) { - _, err := NewGenerate([]string{ + _, err := NewRoutesFileConfiguration([]string{ "--" + templatesVariable, "123", - }, func(s string) string { return "" }, io.Discard) + }, io.Discard) assert.ErrorContains(t, err, errIdentSuffix) }) t.Run(outputFlagName+" flag value is not a go file", func(t *testing.T) { - _, err := NewGenerate([]string{ + _, err := NewRoutesFileConfiguration([]string{ "--" + outputFlagName, "output.txt", - }, func(s string) string { return "" }, io.Discard) + }, io.Discard) assert.ErrorContains(t, err, "filename must use .go extension") }) } diff --git a/routes.go b/routes.go index 1a2112b..c2b4992 100644 --- a/routes.go +++ b/routes.go @@ -55,20 +55,20 @@ const ( type RoutesFileConfiguration struct { executeFunc bool - Package, + PackageName, PackagePath, - TemplatesVar, - RoutesFunc, + TemplatesVariable, + RoutesFunction, ReceiverType, ReceiverPackage, ReceiverInterface, - Output string + OutputFileName string } func (config RoutesFileConfiguration) applyDefaults() RoutesFileConfiguration { - config.Package = cmp.Or(config.Package, defaultPackageName) - config.TemplatesVar = cmp.Or(config.TemplatesVar, DefaultTemplatesVariableName) - config.RoutesFunc = cmp.Or(config.RoutesFunc, DefaultRoutesFunctionName) + config.PackageName = cmp.Or(config.PackageName, defaultPackageName) + config.TemplatesVariable = cmp.Or(config.TemplatesVariable, DefaultTemplatesVariableName) + config.RoutesFunction = cmp.Or(config.RoutesFunction, DefaultRoutesFunctionName) config.ReceiverInterface = cmp.Or(config.ReceiverInterface, DefaultReceiverInterfaceName) config.executeFunc = true return config @@ -76,8 +76,8 @@ func (config RoutesFileConfiguration) applyDefaults() RoutesFileConfiguration { func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfiguration) (string, error) { config = config.applyDefaults() - if !token.IsIdentifier(config.Package) { - return "", fmt.Errorf("package name %q is not an identifier", config.Package) + if !token.IsIdentifier(config.PackageName) { + return "", fmt.Errorf("package name %q is not an identifier", config.PackageName) } imports := source.NewImports(&ast.GenDecl{Tok: token.IMPORT}) @@ -104,7 +104,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur } imports.SetOutputPackage(routesPkg.Types) config.PackagePath = routesPkg.PkgPath - config.Package = routesPkg.Name + config.PackageName = routesPkg.Name var receiver *types.Named if config.ReceiverType != "" { receiverPkgPath := cmp.Or(config.ReceiverPackage, config.PackagePath) @@ -125,7 +125,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur receiver = types.NewNamed(types.NewTypeName(0, routesPkg.Types, "Receiver", nil), types.NewStruct(nil, nil), nil) } - ts, err := source.Templates(wd, config.TemplatesVar, routesPkg) + ts, err := source.Templates(wd, config.TemplatesVariable, routesPkg) if err != nil { return "", err } @@ -138,7 +138,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur if p.Types.Path() == config.PackagePath { if executeObj := p.Types.Scope().Lookup("execute"); executeObj != nil { if _, ok := executeObj.(*types.Func); ok { - config.executeFunc = filepath.Base(p.Fset.Position(executeObj.Pos()).Filename) == config.Output + config.executeFunc = filepath.Base(p.Fset.Position(executeObj.Pos()).Filename) == config.OutputFileName } } break @@ -150,7 +150,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur } routesFunc := &ast.FuncDecl{ - Name: ast.NewIdent(config.RoutesFunc), + Name: ast.NewIdent(config.RoutesFunction), Type: &ast.FuncType{ Params: &ast.FieldList{ List: []*ast.Field{ @@ -209,7 +209,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur imports.SortImports() file := &ast.File{ - Name: ast.NewIdent(config.Package), + Name: ast.NewIdent(config.PackageName), Decls: []ast.Decl{ // import imports.GenDecl, @@ -228,7 +228,7 @@ func TemplateRoutesFile(wd string, logger *log.Logger, config RoutesFileConfigur } if config.executeFunc { - file.Decls = append(file.Decls, executeFuncDecl(imports, config.TemplatesVar)) + file.Decls = append(file.Decls, executeFuncDecl(imports, config.TemplatesVariable)) } return source.Format(file), nil diff --git a/routes_test.go b/routes_test.go index be50965..1671aa2 100644 --- a/routes_test.go +++ b/routes_test.go @@ -1946,12 +1946,12 @@ var templates = template.Must(template.ParseFS(templatesDir, "template.gohtml")) logger := log.New(io.Discard, "", 0) out, err := muxt.TemplateRoutesFile(dir, logger, muxt.RoutesFileConfiguration{ ReceiverInterface: tt.Interface, - Package: tt.PackageName, - TemplatesVar: tt.TemplatesVar, - RoutesFunc: tt.RoutesFunc, + PackageName: tt.PackageName, + TemplatesVariable: tt.TemplatesVar, + RoutesFunction: tt.RoutesFunc, PackagePath: "example.com", ReceiverType: tt.Receiver, - Output: "template_routes.go", + OutputFileName: "template_routes.go", }) if tt.ExpectedError == "" { require.NoError(t, err) @@ -1963,22 +1963,6 @@ var templates = template.Must(template.ParseFS(templatesDir, "template.gohtml")) } } -//func loadPackage(t *testing.T, in string) []*packages.Package { -// t.Helper() -// archive := txtar.Parse([]byte(in)) -// archiveDir, err := txtar.FS(archive) -// require.NoError(t, err) -// -// dir := t.TempDir() -// require.NoError(t, os.CopyFS(dir, archiveDir)) -// require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com\n"), 0644)) -// -// packageList, err := source.Load(dir, "./...") -// require.NoError(t, err) -// -// return packageList -//} - const executeGo = `-- execute.go -- package main