From 9a36b0e885c813a7f10697618aa5673a5a1aa422 Mon Sep 17 00:00:00 2001 From: Luiz Aoqui Date: Thu, 10 Oct 2024 18:44:05 -0400 Subject: [PATCH] cmd: add --log-level CLI flag Allow specifying the log level directly as the global CLI flag `--log-level`. --- go.mod | 3 + pkg/command/command.go | 74 +++++++++------- pkg/command/command_test.go | 169 ++++++++++++++++++++++++++++++++++++ pkg/command/testing.go | 115 ++++++++++++++++++++++++ 4 files changed, 330 insertions(+), 31 deletions(-) create mode 100644 pkg/command/command_test.go create mode 100644 pkg/command/testing.go diff --git a/go.mod b/go.mod index 3714ef5..26279e8 100644 --- a/go.mod +++ b/go.mod @@ -15,10 +15,12 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -28,6 +30,7 @@ require ( github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/rs/zerolog v1.33.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/pkg/command/command.go b/pkg/command/command.go index 833c57b..be53d6a 100644 --- a/pkg/command/command.go +++ b/pkg/command/command.go @@ -45,6 +45,15 @@ type Command[T config.Config] struct { newConfig config.New[T] config T setupCommands []SetupCommand[T] + + format printer.Format + debug bool + logLevel types.Level + + // The following io.Writer values should be used when outputting text. They + // default to os.Stdout and os.Stderr but may be changed during tests. + stdout io.Writer + stderr io.Writer } var ( @@ -72,33 +81,32 @@ func New[T config.Config](cli string, short string, long string, noargs bool, ve version: version, newConfig: newConfig, setupCommands: setupCommands, + stdout: os.Stdout, + stderr: os.Stderr, } } func (c *Command[T]) Execute(ctx context.Context, commandType Type) int { - var format printer.Format - var debug bool - devEnv := fmt.Sprintf("%s_DISABLE_DEV_WARNING", strings.ToUpper(replacer.Replace(c.cli))) devWarning := fmt.Sprintf("!! WARNING: You are using a self-compiled binary which is not officially supported.\n!! To dismiss this warning, set %s=true\n\n", devEnv) if _, ok := os.LookupEnv(devEnv); !ok { if c.version.GitCommit() == "" || c.version.GoVersion() == "" || c.version.BuildDate() == "" || c.version.Version() == "" || c.version.Platform() == "" { - _, _ = fmt.Fprintf(os.Stderr, devWarning) + _, _ = fmt.Fprintf(c.stderr, devWarning) } } - err := c.runCmd(ctx, &format, &debug, commandType) + err := c.runCmd(ctx, commandType) if err == nil { return 0 } // print any user specific messages first - switch format { + switch c.format { case printer.JSON: - _, _ = fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err) + _, _ = fmt.Fprintf(c.stderr, `{"error": "%s"}`, err) default: - _, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err) + _, _ = fmt.Fprintf(c.stderr, "Error: %s\n", err) } logClosersLock.Lock() @@ -118,7 +126,7 @@ func (c *Command[T]) Execute(ctx context.Context, commandType Type) int { // runCmd adds all child commands to the root command, sets flags // appropriately, and runs the root command. -func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug *bool, commandType Type) error { +func (c *Command[T]) runCmd(ctx context.Context, commandType Type) error { c.config = c.newConfig() configDir, err := c.config.DefaultConfigDir() @@ -147,29 +155,29 @@ func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug * cobra.OnInitialize(func() { err := c.initConfig() if err != nil { - switch *format { + switch c.format { case printer.JSON: - _, _ = fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err) + _, _ = fmt.Fprintf(c.stderr, `{"error": "%s"}`, err) default: - _, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err) + _, _ = fmt.Fprintf(c.stderr, "Error: %s\n", err) } os.Exit(cmdutils.FatalErrExitCode) } - ch.SetDebug(debug) + ch.SetDebug(&c.debug) - ch.Printer = printer.NewPrinter(format) + ch.Printer = printer.NewPrinter(&c.format) if strings.TrimSpace(logFile) == "" { - logOutput = os.Stderr + logOutput = c.stderr } else { if err := os.MkdirAll(filepath.Dir(logFile), 0700); err != nil { - switch *format { + switch c.format { case printer.JSON: - _, _ = fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err) + _, _ = fmt.Fprintf(c.stderr, `{"error": "%s"}`, err) default: - _, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err) + _, _ = fmt.Fprintf(c.stderr, "Error: %s\n", err) } os.Exit(cmdutils.FatalErrExitCode) @@ -177,11 +185,11 @@ func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug * fileLogOutput, err := os.OpenFile(logFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0700) if err != nil { - switch *format { + switch c.format { case printer.JSON: - _, _ = fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err) + _, _ = fmt.Fprintf(c.stderr, `{"error": "%s"}`, err) default: - _, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err) + _, _ = fmt.Fprintf(c.stderr, "Error: %s\n", err) } os.Exit(cmdutils.FatalErrExitCode) @@ -192,24 +200,19 @@ func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug * logClosers = append(logClosers, fileLogOutput.Close) if ch.Debug() { - logOutput = io.MultiWriter(fileLogOutput, os.Stderr) + logOutput = io.MultiWriter(fileLogOutput, c.stderr) } else { logOutput = fileLogOutput } } - switch *format { + switch c.format { case printer.JSON: ch.Logger = logging.New(logging.Zerolog, strings.ToLower(c.cli), logOutput) default: ch.Logger = logging.New(logging.Slog, strings.ToLower(c.cli), logOutput) } - - if ch.Debug() { - ch.Logger.SetLevel(types.TraceLevel) - } else { - ch.Logger.SetLevel(types.InfoLevel) - } + ch.Logger.SetLevel(c.logLevel) }) c.command.SilenceUsage = true @@ -222,7 +225,7 @@ func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug * c.config.RootPersistentFlags(c.command.PersistentFlags()) - c.command.PersistentFlags().VarP(printer.NewFormatValue(printer.Human, format), "format", "f", "Show output in a specific format. Possible values: [human, json]") + c.command.PersistentFlags().VarP(printer.NewFormatValue(printer.Human, &c.format), "format", "f", "Show output in a specific format. Possible values: [human, json]") if err = viper.BindPFlag("format", c.command.PersistentFlags().Lookup("format")); err != nil { return err } @@ -230,11 +233,20 @@ func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug * return []string{"human", "json"}, cobra.ShellCompDirectiveDefault }) - c.command.PersistentFlags().BoolVar(debug, "debug", false, "Enable debug mode") + c.command.PersistentFlags().BoolVar(&c.debug, "debug", false, "Enable debug mode") if err = viper.BindPFlag("debug", c.command.PersistentFlags().Lookup("debug")); err != nil { return err } + c.logLevel = types.InfoLevel + c.command.PersistentFlags().VarP(&c.logLevel, "log-level", "", "") + if err = viper.BindPFlag("log-level", c.command.PersistentFlags().Lookup("debug")); err != nil { + return err + } + _ = c.command.RegisterFlagCompletionFunc("log-level", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return []string{"fatal", "error", "warn", "info", "debug", "trace"}, cobra.ShellCompDirectiveDefault + }) + c.command.PersistentFlags().BoolVar(&color.NoColor, "no-color", false, "Disable color output") if err = viper.BindPFlag("no-color", c.command.PersistentFlags().Lookup("no-color")); err != nil { return err diff --git a/pkg/command/command_test.go b/pkg/command/command_test.go new file mode 100644 index 0000000..1ff28cf --- /dev/null +++ b/pkg/command/command_test.go @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: Apache-2.0 + +package command + +import ( + "context" + "fmt" + "testing" + + "github.com/loopholelabs/cmdutils" + "github.com/loopholelabs/logging/loggers/zerolog" + "github.com/loopholelabs/logging/types" + "github.com/stretchr/testify/require" +) + +func TestLogLevel(t *testing.T) { + t.Setenv("TEST_DISABLE_DEV_WARNING", "true") + + testCases := []struct { + name string + args []string + expectedLevel types.Level + expectError bool + }{ + { + name: "default is info", + args: []string{"run"}, + expectedLevel: types.InfoLevel, + }, + { + name: "case insensitive", + args: []string{"run", "--log-level", "WarN"}, + expectedLevel: types.WarnLevel, + }, + { + name: "missing", + args: []string{"run", "--log-level", ""}, + expectError: true, + }, + { + name: "invalid", + args: []string{"run", "--log-level", "not-valid"}, + expectError: true, + }, + { + name: "trace", + args: []string{"run", "--log-level", "trace"}, + expectedLevel: types.TraceLevel, + }, + { + name: "debug", + args: []string{"run", "--log-level", "debug"}, + expectedLevel: types.DebugLevel, + }, + { + name: "info", + args: []string{"run", "--log-level", "info"}, + expectedLevel: types.InfoLevel, + }, + { + name: "warn", + args: []string{"run", "--log-level", "warn"}, + expectedLevel: types.WarnLevel, + }, + { + name: "error", + args: []string{"run", "--log-level", "error"}, + expectedLevel: types.ErrorLevel, + }, + { + name: "fatal", + args: []string{"run", "--log-level", "fatal"}, + expectedLevel: types.FatalLevel, + }, + } + + fn := func(ch *cmdutils.Helper[*TestConfig]) error { + ch.Logger.Trace().Msg("TRACE") + ch.Logger.Debug().Msg("DEBUG") + ch.Logger.Info().Msg("INFO") + ch.Logger.Warn().Msg("WARN") + ch.Logger.Error().Msg("ERROR") + + // Skip FATAL level with zerolog because it calls os.Exit(). + if _, ok := ch.Logger.(*zerolog.Logger); !ok { + ch.Logger.Fatal().Msg("FATAL") + } + return nil + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%s/default", tc.name), func(t *testing.T) { + h := NewTestCommandHarness(t, fn) + + rc := h.Execute(context.Background(), tc.args) + if tc.expectError { + require.NotZero(t, rc) + require.Contains(t, h.Stderr(), "--log-level") + return + } + + require.Zero(t, rc, "expected no error, got:\n%s", h.Stderr()) + require.Equal(t, tc.expectedLevel, h.cmd.logLevel) + + for l := types.FatalLevel; l <= types.TraceLevel; l++ { + msg := fmt.Sprintf("msg=%s", l) + if tc.expectedLevel >= l { + require.Contains(t, h.Stderr(), msg) + } else { + require.NotContains(t, h.Stderr(), msg) + } + } + }) + + t.Run(fmt.Sprintf("%s/human", tc.name), func(t *testing.T) { + h := NewTestCommandHarness(t, fn) + + args := make([]string, len(tc.args)) + copy(args, tc.args) + args = append(args, "--format=human") + + rc := h.Execute(context.Background(), args) + if tc.expectError { + require.NotZero(t, rc) + require.Contains(t, h.Stderr(), "--log-level") + return + } + + require.Zero(t, rc, "expected no error, got:\n%s", h.Stderr()) + require.Equal(t, tc.expectedLevel, h.cmd.logLevel) + + for l := types.FatalLevel; l <= types.TraceLevel; l++ { + msg := fmt.Sprintf("msg=%s", l) + if tc.expectedLevel >= l { + require.Contains(t, h.Stderr(), msg) + } else { + require.NotContains(t, h.Stderr(), msg) + } + } + }) + + t.Run(fmt.Sprintf("%s/json", tc.name), func(t *testing.T) { + h := NewTestCommandHarness(t, fn) + + args := make([]string, len(tc.args)) + copy(args, tc.args) + args = append(args, "--format=json") + + rc := h.Execute(context.Background(), args) + if tc.expectError { + require.NotZero(t, rc) + require.Contains(t, h.Stderr(), "--log-level") + return + } + + require.Zero(t, rc, "expected no error, got:\n%s", h.Stderr()) + require.Equal(t, tc.expectedLevel, h.cmd.logLevel) + + for l := types.ErrorLevel; l <= types.TraceLevel; l++ { + msg := fmt.Sprintf(`"message":"%s"`, l) + if tc.expectedLevel >= l { + require.Contains(t, h.Stderr(), msg) + } else { + require.NotContains(t, h.Stderr(), msg) + } + } + }) + } +} diff --git a/pkg/command/testing.go b/pkg/command/testing.go new file mode 100644 index 0000000..7508862 --- /dev/null +++ b/pkg/command/testing.go @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: Apache-2.0 + +package command + +import ( + "bytes" + "context" + "path/filepath" + "testing" + + "github.com/loopholelabs/cmdutils" + "github.com/loopholelabs/cmdutils/pkg/version" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +type testCommandFn func(*cmdutils.Helper[*TestConfig]) error + +type TestCommandHarness struct { + t *testing.T + cmd *Command[*TestConfig] + stdout *bytes.Buffer + stderr *bytes.Buffer + config *TestConfig + defaultConfigFile string + defaultLogFile string + fn testCommandFn +} + +func NewTestCommandHarness(t *testing.T, fn testCommandFn) *TestCommandHarness { + h := &TestCommandHarness{ + t: t, + fn: fn, + stdout: new(bytes.Buffer), + stderr: new(bytes.Buffer), + defaultConfigFile: filepath.Join(t.TempDir(), "test-config.yml"), + defaultLogFile: filepath.Join(t.TempDir(), "test-log.log"), + } + + h.cmd = New("test", "A CLI test", "CLI test", true, + version.New[*TestConfig]("", "", "", "", ""), + NewTestConfigFn(h.defaultConfigFile, h.defaultLogFile), + []SetupCommand[*TestConfig]{ + h.setupCommandRun, + }, + ) + h.cmd.stdout = h.stdout + h.cmd.stderr = h.stderr + + return h +} + +func (h *TestCommandHarness) setupCommandRun(rootCmd *cobra.Command, ch *cmdutils.Helper[*TestConfig]) { + cmd := &cobra.Command{ + Use: "run", + Short: "Run test function", + RunE: func(cmd *cobra.Command, args []string) error { + h.config = ch.Config + if h.fn != nil { + return h.fn(ch) + } + return nil + }, + } + cmd.Flags().StringVar(&ch.Config.MyConfig, "my-config", "", "A sample test configuration") + + rootCmd.AddCommand(cmd) +} + +func (h *TestCommandHarness) Execute(ctx context.Context, args []string) int { + h.cmd.command.SetArgs(args) + return h.cmd.Execute(ctx, Noninteractive) +} + +func (h *TestCommandHarness) Stdout() string { + return h.stdout.String() +} + +func (h *TestCommandHarness) Stderr() string { + return h.stderr.String() +} + +type TestConfig struct { + cfgFile string + logFile string + + // Common configuration values. + Format string + Debug bool + NoColor bool `mapstructure:"no-color"` + + // Custom configuration values. + MyConfig string `mapstructure:"my-config"` +} + +func NewTestConfigFn(cfgFile string, logFile string) func() *TestConfig { + return func() *TestConfig { + return &TestConfig{ + cfgFile: cfgFile, + logFile: logFile, + } + } +} + +func (_ *TestConfig) RootPersistentFlags(flags *pflag.FlagSet) { return } +func (_ *TestConfig) GlobalRequiredFlags(cmd *cobra.Command) error { return nil } +func (_ *TestConfig) Validate() error { return nil } +func (c *TestConfig) DefaultConfigDir() (string, error) { return filepath.Dir(c.cfgFile), nil } +func (c *TestConfig) DefaultConfigFile() string { return c.cfgFile } +func (c *TestConfig) DefaultLogDir() (string, error) { return filepath.Dir(c.logFile), nil } +func (c *TestConfig) DefaultLogFile() string { return c.logFile } +func (c *TestConfig) SetConfigFile(cfg string) { c.cfgFile = cfg } +func (c *TestConfig) GetConfigFile() string { return c.cfgFile } +func (c *TestConfig) SetLogFile(l string) { c.logFile = l } +func (c *TestConfig) GetLogFile() string { return c.logFile }