Skip to content

Commit

Permalink
cmd: add --log-level CLI flag
Browse files Browse the repository at this point in the history
Allow specifying the log level directly as the global CLI flag
`--log-level`.
  • Loading branch information
lgfa29 committed Oct 15, 2024
1 parent 9298e14 commit 9a36b0e
Show file tree
Hide file tree
Showing 4 changed files with 330 additions and 31 deletions.
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ require (
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
Expand All @@ -28,6 +30,7 @@ require (
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rs/zerolog v1.33.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
Expand Down
74 changes: 43 additions & 31 deletions pkg/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ type Command[T config.Config] struct {
newConfig config.New[T]
config T
setupCommands []SetupCommand[T]

format printer.Format
debug bool
logLevel types.Level

// The following io.Writer values should be used when outputting text. They
// default to os.Stdout and os.Stderr but may be changed during tests.
stdout io.Writer
stderr io.Writer
}

var (
Expand Down Expand Up @@ -72,33 +81,32 @@ func New[T config.Config](cli string, short string, long string, noargs bool, ve
version: version,
newConfig: newConfig,
setupCommands: setupCommands,
stdout: os.Stdout,
stderr: os.Stderr,
}
}

func (c *Command[T]) Execute(ctx context.Context, commandType Type) int {
var format printer.Format
var debug bool

devEnv := fmt.Sprintf("%s_DISABLE_DEV_WARNING", strings.ToUpper(replacer.Replace(c.cli)))
devWarning := fmt.Sprintf("!! WARNING: You are using a self-compiled binary which is not officially supported.\n!! To dismiss this warning, set %s=true\n\n", devEnv)

if _, ok := os.LookupEnv(devEnv); !ok {
if c.version.GitCommit() == "" || c.version.GoVersion() == "" || c.version.BuildDate() == "" || c.version.Version() == "" || c.version.Platform() == "" {
_, _ = fmt.Fprintf(os.Stderr, devWarning)
_, _ = fmt.Fprintf(c.stderr, devWarning)
}
}

err := c.runCmd(ctx, &format, &debug, commandType)
err := c.runCmd(ctx, commandType)
if err == nil {
return 0
}

// print any user specific messages first
switch format {
switch c.format {
case printer.JSON:
_, _ = fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err)
_, _ = fmt.Fprintf(c.stderr, `{"error": "%s"}`, err)
default:
_, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err)
_, _ = fmt.Fprintf(c.stderr, "Error: %s\n", err)
}

logClosersLock.Lock()
Expand All @@ -118,7 +126,7 @@ func (c *Command[T]) Execute(ctx context.Context, commandType Type) int {

// runCmd adds all child commands to the root command, sets flags
// appropriately, and runs the root command.
func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug *bool, commandType Type) error {
func (c *Command[T]) runCmd(ctx context.Context, commandType Type) error {
c.config = c.newConfig()

configDir, err := c.config.DefaultConfigDir()
Expand Down Expand Up @@ -147,41 +155,41 @@ func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug *
cobra.OnInitialize(func() {
err := c.initConfig()
if err != nil {
switch *format {
switch c.format {
case printer.JSON:
_, _ = fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err)
_, _ = fmt.Fprintf(c.stderr, `{"error": "%s"}`, err)
default:
_, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err)
_, _ = fmt.Fprintf(c.stderr, "Error: %s\n", err)
}

os.Exit(cmdutils.FatalErrExitCode)
}

ch.SetDebug(debug)
ch.SetDebug(&c.debug)

ch.Printer = printer.NewPrinter(format)
ch.Printer = printer.NewPrinter(&c.format)

if strings.TrimSpace(logFile) == "" {
logOutput = os.Stderr
logOutput = c.stderr
} else {
if err := os.MkdirAll(filepath.Dir(logFile), 0700); err != nil {
switch *format {
switch c.format {
case printer.JSON:
_, _ = fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err)
_, _ = fmt.Fprintf(c.stderr, `{"error": "%s"}`, err)
default:
_, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err)
_, _ = fmt.Fprintf(c.stderr, "Error: %s\n", err)
}

os.Exit(cmdutils.FatalErrExitCode)
}

fileLogOutput, err := os.OpenFile(logFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0700)
if err != nil {
switch *format {
switch c.format {
case printer.JSON:
_, _ = fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err)
_, _ = fmt.Fprintf(c.stderr, `{"error": "%s"}`, err)
default:
_, _ = fmt.Fprintf(os.Stderr, "Error: %s\n", err)
_, _ = fmt.Fprintf(c.stderr, "Error: %s\n", err)
}

os.Exit(cmdutils.FatalErrExitCode)
Expand All @@ -192,24 +200,19 @@ func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug *
logClosers = append(logClosers, fileLogOutput.Close)

if ch.Debug() {
logOutput = io.MultiWriter(fileLogOutput, os.Stderr)
logOutput = io.MultiWriter(fileLogOutput, c.stderr)
} else {
logOutput = fileLogOutput
}
}

switch *format {
switch c.format {
case printer.JSON:
ch.Logger = logging.New(logging.Zerolog, strings.ToLower(c.cli), logOutput)
default:
ch.Logger = logging.New(logging.Slog, strings.ToLower(c.cli), logOutput)
}

if ch.Debug() {
ch.Logger.SetLevel(types.TraceLevel)
} else {
ch.Logger.SetLevel(types.InfoLevel)
}
ch.Logger.SetLevel(c.logLevel)
})

c.command.SilenceUsage = true
Expand All @@ -222,19 +225,28 @@ func (c *Command[T]) runCmd(ctx context.Context, format *printer.Format, debug *

c.config.RootPersistentFlags(c.command.PersistentFlags())

c.command.PersistentFlags().VarP(printer.NewFormatValue(printer.Human, format), "format", "f", "Show output in a specific format. Possible values: [human, json]")
c.command.PersistentFlags().VarP(printer.NewFormatValue(printer.Human, &c.format), "format", "f", "Show output in a specific format. Possible values: [human, json]")
if err = viper.BindPFlag("format", c.command.PersistentFlags().Lookup("format")); err != nil {
return err
}
_ = c.command.RegisterFlagCompletionFunc("format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return []string{"human", "json"}, cobra.ShellCompDirectiveDefault
})

c.command.PersistentFlags().BoolVar(debug, "debug", false, "Enable debug mode")
c.command.PersistentFlags().BoolVar(&c.debug, "debug", false, "Enable debug mode")
if err = viper.BindPFlag("debug", c.command.PersistentFlags().Lookup("debug")); err != nil {
return err
}

c.logLevel = types.InfoLevel
c.command.PersistentFlags().VarP(&c.logLevel, "log-level", "", "")
if err = viper.BindPFlag("log-level", c.command.PersistentFlags().Lookup("debug")); err != nil {
return err
}
_ = c.command.RegisterFlagCompletionFunc("log-level", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return []string{"fatal", "error", "warn", "info", "debug", "trace"}, cobra.ShellCompDirectiveDefault
})

c.command.PersistentFlags().BoolVar(&color.NoColor, "no-color", false, "Disable color output")
if err = viper.BindPFlag("no-color", c.command.PersistentFlags().Lookup("no-color")); err != nil {
return err
Expand Down
169 changes: 169 additions & 0 deletions pkg/command/command_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// SPDX-License-Identifier: Apache-2.0

package command

import (
"context"
"fmt"
"testing"

"github.com/loopholelabs/cmdutils"
"github.com/loopholelabs/logging/loggers/zerolog"
"github.com/loopholelabs/logging/types"
"github.com/stretchr/testify/require"
)

func TestLogLevel(t *testing.T) {
t.Setenv("TEST_DISABLE_DEV_WARNING", "true")

testCases := []struct {
name string
args []string
expectedLevel types.Level
expectError bool
}{
{
name: "default is info",
args: []string{"run"},
expectedLevel: types.InfoLevel,
},
{
name: "case insensitive",
args: []string{"run", "--log-level", "WarN"},
expectedLevel: types.WarnLevel,
},
{
name: "missing",
args: []string{"run", "--log-level", ""},
expectError: true,
},
{
name: "invalid",
args: []string{"run", "--log-level", "not-valid"},
expectError: true,
},
{
name: "trace",
args: []string{"run", "--log-level", "trace"},
expectedLevel: types.TraceLevel,
},
{
name: "debug",
args: []string{"run", "--log-level", "debug"},
expectedLevel: types.DebugLevel,
},
{
name: "info",
args: []string{"run", "--log-level", "info"},
expectedLevel: types.InfoLevel,
},
{
name: "warn",
args: []string{"run", "--log-level", "warn"},
expectedLevel: types.WarnLevel,
},
{
name: "error",
args: []string{"run", "--log-level", "error"},
expectedLevel: types.ErrorLevel,
},
{
name: "fatal",
args: []string{"run", "--log-level", "fatal"},
expectedLevel: types.FatalLevel,
},
}

fn := func(ch *cmdutils.Helper[*TestConfig]) error {
ch.Logger.Trace().Msg("TRACE")
ch.Logger.Debug().Msg("DEBUG")
ch.Logger.Info().Msg("INFO")
ch.Logger.Warn().Msg("WARN")
ch.Logger.Error().Msg("ERROR")

// Skip FATAL level with zerolog because it calls os.Exit().
if _, ok := ch.Logger.(*zerolog.Logger); !ok {
ch.Logger.Fatal().Msg("FATAL")
}
return nil
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("%s/default", tc.name), func(t *testing.T) {
h := NewTestCommandHarness(t, fn)

rc := h.Execute(context.Background(), tc.args)
if tc.expectError {
require.NotZero(t, rc)
require.Contains(t, h.Stderr(), "--log-level")
return
}

require.Zero(t, rc, "expected no error, got:\n%s", h.Stderr())
require.Equal(t, tc.expectedLevel, h.cmd.logLevel)

for l := types.FatalLevel; l <= types.TraceLevel; l++ {
msg := fmt.Sprintf("msg=%s", l)
if tc.expectedLevel >= l {
require.Contains(t, h.Stderr(), msg)
} else {
require.NotContains(t, h.Stderr(), msg)
}
}
})

t.Run(fmt.Sprintf("%s/human", tc.name), func(t *testing.T) {
h := NewTestCommandHarness(t, fn)

args := make([]string, len(tc.args))
copy(args, tc.args)
args = append(args, "--format=human")

rc := h.Execute(context.Background(), args)
if tc.expectError {
require.NotZero(t, rc)
require.Contains(t, h.Stderr(), "--log-level")
return
}

require.Zero(t, rc, "expected no error, got:\n%s", h.Stderr())
require.Equal(t, tc.expectedLevel, h.cmd.logLevel)

for l := types.FatalLevel; l <= types.TraceLevel; l++ {
msg := fmt.Sprintf("msg=%s", l)
if tc.expectedLevel >= l {
require.Contains(t, h.Stderr(), msg)
} else {
require.NotContains(t, h.Stderr(), msg)
}
}
})

t.Run(fmt.Sprintf("%s/json", tc.name), func(t *testing.T) {
h := NewTestCommandHarness(t, fn)

args := make([]string, len(tc.args))
copy(args, tc.args)
args = append(args, "--format=json")

rc := h.Execute(context.Background(), args)
if tc.expectError {
require.NotZero(t, rc)
require.Contains(t, h.Stderr(), "--log-level")
return
}

require.Zero(t, rc, "expected no error, got:\n%s", h.Stderr())
require.Equal(t, tc.expectedLevel, h.cmd.logLevel)

for l := types.ErrorLevel; l <= types.TraceLevel; l++ {
msg := fmt.Sprintf(`"message":"%s"`, l)
if tc.expectedLevel >= l {
require.Contains(t, h.Stderr(), msg)
} else {
require.NotContains(t, h.Stderr(), msg)
}
}
})
}
}
Loading

0 comments on commit 9a36b0e

Please sign in to comment.