Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let the commands store flagComp functions internally (and avoid global state) #2012

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
8 changes: 5 additions & 3 deletions bash_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,10 @@ func writeLocalNonPersistentFlag(buf io.StringWriter, flag *pflag.Flag) {

// prepareCustomAnnotationsForFlags setup annotations for go completions for registered flags
func prepareCustomAnnotationsForFlags(cmd *Command) {
flagCompletionMutex.RLock()
defer flagCompletionMutex.RUnlock()
for flag := range flagCompletionFunctions {
cmd.initializeCompletionStorage()
cmd.flagCompletionMutex.RLock()
defer cmd.flagCompletionMutex.RUnlock()
for flag := range cmd.flagCompletionFunctions {
// Make sure the completion script calls the __*_go_custom_completion function for
// every registered flag. We need to do this here (and not when the flag was registered
// for completion) so that we can know the root command name for the prefix
Expand Down Expand Up @@ -644,6 +645,7 @@ func writeCmdAliases(buf io.StringWriter, cmd *Command) {
WriteStringAndCheck(buf, ` fi`)
WriteStringAndCheck(buf, "\n")
}

func writeArgAliases(buf io.StringWriter, cmd *Command) {
WriteStringAndCheck(buf, " noun_aliases=()\n")
sort.Strings(cmd.ArgAliases)
Expand Down
8 changes: 8 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"path/filepath"
"sort"
"strings"
"sync"

flag "github.com/spf13/pflag"
)
Expand Down Expand Up @@ -163,6 +164,13 @@ type Command struct {
// that we can use on every pflag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName

// flagsCompletions contrains completions for arbitrary lists of flags.
// Those flags may or may not actually strictly belong to the command in the function,
// but registering completions for them through the command allows for garbage-collecting.
flagCompletionFunctions map[*flag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
maxlandon marked this conversation as resolved.
Show resolved Hide resolved
// lock for reading and writing from flagCompletionFunctions
flagCompletionMutex *sync.RWMutex
maxlandon marked this conversation as resolved.
Show resolved Hide resolved

// usageFunc is usage func defined by user.
usageFunc func(*Command) error
// usageTemplate is usage template defined by user.
Expand Down
63 changes: 42 additions & 21 deletions completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ const (
ShellCompNoDescRequestCmd = "__completeNoDesc"
)

// Global map of flag completion functions. Make sure to use flagCompletionMutex before you try to read and write from it.
var flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){}

// lock for reading and writing from flagCompletionFunctions
var flagCompletionMutex = &sync.RWMutex{}

// ShellCompDirective is a bit map representing the different behaviors the shell
// can be instructed to have once completions have been provided.
type ShellCompDirective int
Expand Down Expand Up @@ -135,23 +129,41 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman
if flag == nil {
return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName)
}
flagCompletionMutex.Lock()
defer flagCompletionMutex.Unlock()
// Ensure none of our relevant fields are nil.
c.initializeCompletionStorage()
maxlandon marked this conversation as resolved.
Show resolved Hide resolved

if _, exists := flagCompletionFunctions[flag]; exists {
c.flagCompletionMutex.Lock()
defer c.flagCompletionMutex.Unlock()

// And attempt to bind the completion.
if _, exists := c.flagCompletionFunctions[flag]; exists {
return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName)
}
flagCompletionFunctions[flag] = f
c.flagCompletionFunctions[flag] = f
return nil
}

// GetFlagCompletion returns the completion function for the given flag, if available.
func GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
flagCompletionMutex.RLock()
defer flagCompletionMutex.RUnlock()
func (c *Command) GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change this to a command method.

c.initializeCompletionStorage()
maxlandon marked this conversation as resolved.
Show resolved Hide resolved

c.flagCompletionMutex.RLock()
defer c.flagCompletionMutex.RUnlock()

completionFunc, exists := c.flagCompletionFunctions[flag]

// If found it here, return now
if completionFunc != nil && exists {
return completionFunc, exists
}

// If we are already at the root command level, return anyway
if !c.HasParent() {
return nil, false
}

completionFunc, exists := flagCompletionFunctions[flag]
return completionFunc, exists
// Or walk up the command tree.
return c.Parent().GetFlagCompletion(flag)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Recursively walk up the command parents' tree until flagComp is found.
    Behavior identical regardless of if the flagComp is searched with a flag pointer, or a flag name.

}

// GetFlagCompletionByName returns the completion function for the given flag in the command by name, if available.
Expand All @@ -161,7 +173,19 @@ func (c *Command) GetFlagCompletionByName(flagName string) (func(cmd *Command, a
return nil, false
}

return GetFlagCompletion(flag)
return c.GetFlagCompletion(flag)
}

// initializeCompletionStorage is (and should be) called in all
// functions that make use of the command's flag completion functions.
func (c *Command) initializeCompletionStorage() {
maxlandon marked this conversation as resolved.
Show resolved Hide resolved
if c.flagCompletionMutex == nil {
c.flagCompletionMutex = new(sync.RWMutex)
}

if c.flagCompletionFunctions == nil {
c.flagCompletionFunctions = make(map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), 0)
}
}

// Returns a string listing the different directive enabled in the specified parameter
Expand Down Expand Up @@ -507,9 +531,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
// Find the completion function for the flag or command
var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective)
if flag != nil && flagCompletion {
flagCompletionMutex.RLock()
completionFn = flagCompletionFunctions[flag]
flagCompletionMutex.RUnlock()
completionFn, _ = finalCmd.GetFlagCompletion(flag)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Use our new recursive methods to lookup what to complete when __complete is called.
    Passes all tests

} else {
completionFn = finalCmd.ValidArgsFunction
}
Expand Down Expand Up @@ -833,7 +855,6 @@ to your powershell profile.
return cmd.Root().GenPowerShellCompletion(out)
}
return cmd.Root().GenPowerShellCompletionWithDesc(out)

},
}
if haveNoDescFlag {
Expand Down Expand Up @@ -873,7 +894,7 @@ func CompDebug(msg string, printToStdErr bool) {
// variable BASH_COMP_DEBUG_FILE to the path of some file to be used.
if path := os.Getenv("BASH_COMP_DEBUG_FILE"); path != "" {
f, err := os.OpenFile(path,
os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err == nil {
defer f.Close()
WriteStringAndCheck(f, msg)
Expand Down
Loading