diff --git a/completions.go b/completions.go index 95e1dc1fb..92b9c8f08 100644 --- a/completions.go +++ b/completions.go @@ -17,6 +17,7 @@ package cobra import ( "fmt" "os" + "regexp" "strings" "sync" @@ -895,3 +896,31 @@ func CompError(msg string) { func CompErrorln(msg string) { CompError(fmt.Sprintf("%s\n", msg)) } + +// configEnvVarGlobalPrefix should not be changed: users will be using it explicitly. +const configEnvVarGlobalPrefix = "COBRA" + +var configEnvVarPrefixSubstRegexp = regexp.MustCompile(`[^A-Z0-9_]`) + +// configEnvVar returns the name of the program-specific configuration environment +// variable. It has the format _ where is the name of the +// root command in upper case, with all non-ASCII-alphanumeric characters replaced by `_`. +func configEnvVar(name, suffix string) string { + // This format should not be changed: users will be using it explicitly. + v := strings.ToUpper(fmt.Sprintf("%s_%s", name, suffix)) + v = configEnvVarPrefixSubstRegexp.ReplaceAllString(v, "_") + return v +} + +// GetEnvConfig returns the value of the configuration environment variable +// _ where is the name of the root command in upper +// case, with all non-ASCII-alphanumeric characters replaced by `_`. +// If the value is empty or not set, the value of the environment variable +// COBRA_ is returned instead. +func GetEnvConfig(cmd *Command, suffix string) string { + v := os.Getenv(configEnvVar(cmd.Root().Name(), suffix)) + if v == "" { + v = os.Getenv(configEnvVar(configEnvVarGlobalPrefix, suffix)) + } + return v +} diff --git a/completions_test.go b/completions_test.go index 017246600..9f7c2ad72 100644 --- a/completions_test.go +++ b/completions_test.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" "context" + "os" "strings" "testing" ) @@ -3277,3 +3278,78 @@ Completion ended with directive: ShellCompDirectiveNoFileComp }) } } + +func TestGetEnvConfig(t *testing.T) { + testCases := []struct { + use string + suffix string + cmdVar string + globalVar string + cmdVal string + globalVal string + expected string + }{ + { + use: "root", + suffix: "test", + cmdVar: "ROOT_TEST", + globalVar: "COBRA_TEST", + cmdVal: "cmd", + globalVal: "global", + expected: "cmd", + }, + { + use: "root", + suffix: "test", + cmdVar: "ROOT_TEST", + globalVar: "COBRA_TEST", + cmdVal: "", + globalVal: "global", + expected: "global", + }, + { + use: "root", + suffix: "test", + cmdVar: "ROOT_TEST", + globalVar: "COBRA_TEST", + cmdVal: "", + globalVal: "", + expected: "", + }, + { + use: "foo.bar", + suffix: "test", + cmdVar: "FOO_BAR_TEST", + globalVar: "COBRA_TEST", + cmdVal: "cmd", + globalVal: "global", + expected: "cmd", + }, + { + use: "quux-BAZ", + suffix: "test", + cmdVar: "QUUX_BAZ_TEST", + globalVar: "COBRA_TEST", + cmdVal: "cmd", + globalVal: "global", + expected: "cmd", + }, + } + + for _, tc := range testCases { + // Could make env handling cleaner with t.Setenv with Go >= 1.17 + func() { + err := os.Setenv(tc.cmdVar, tc.cmdVal) + defer assertNoErr(t, os.Unsetenv(tc.cmdVar)) + assertNoErr(t, err) + err = os.Setenv(tc.globalVar, tc.globalVal) + defer assertNoErr(t, os.Unsetenv(tc.globalVar)) + assertNoErr(t, err) + cmd := &Command{Use: tc.use} + got := GetEnvConfig(cmd, tc.suffix) + if got != tc.expected { + t.Errorf("expected: %q, got: %q", tc.expected, got) + } + }() + } +}