From 589693b84870c0228d123904ff62fd4bca24d301 Mon Sep 17 00:00:00 2001 From: Volodymyr Khoroz Date: Thu, 12 Oct 2023 20:35:57 +0300 Subject: [PATCH] Feature: allow running persistent run hooks of all parents Currently, only one of the persistent pre-runs and post-runs is executed. It is always the first one found in the parents chain, starting at this command. Expected behavior is to execute all parents' persistent pre-runs and post-runs. Dependent projects implemented various workarounds for this: - manually building persistent hook chains (in every hook). - applying some kind of monkey-patching on top of Cobra. This change eliminates the necessity for such workarounds by allowing to set a global variable EnableTraverseRunHooks. Tickets: - https://github.com/spf13/cobra/issues/216 - https://github.com/spf13/cobra/issues/252 Signed-off-by: Volodymyr Khoroz --- cobra.go | 11 +++-- command.go | 28 +++++++++++-- command_test.go | 106 +++++++++++++++++++++--------------------------- 3 files changed, 79 insertions(+), 66 deletions(-) diff --git a/cobra.go b/cobra.go index f23f5092f6..381c7b37d0 100644 --- a/cobra.go +++ b/cobra.go @@ -43,9 +43,10 @@ var initializers []func() var finalizers []func() const ( - defaultPrefixMatching = false - defaultCommandSorting = true - defaultCaseInsensitive = false + defaultPrefixMatching = false + defaultCommandSorting = true + defaultCaseInsensitive = false + defaultTraverseRunHooks = false ) // EnablePrefixMatching allows setting automatic prefix matching. Automatic prefix matching can be a dangerous thing @@ -60,6 +61,10 @@ var EnableCommandSorting = defaultCommandSorting // EnableCaseInsensitive allows case-insensitive commands names. (case sensitive by default) var EnableCaseInsensitive = defaultCaseInsensitive +// EnableTraverseRunHooks executes persistent pre-run and post-run hooks from all parents. +// By default, only the first run hook to be found is executed. +var EnableTraverseRunHooks = defaultTraverseRunHooks + // MousetrapHelpText enables an information splash screen on Windows // if the CLI is started from explorer.exe. // To disable the mousetrap, just set this variable to blank string (""). diff --git a/command.go b/command.go index 909d8585ef..ae3e4e0080 100644 --- a/command.go +++ b/command.go @@ -934,15 +934,31 @@ func (c *Command) execute(a []string) (err error) { return err } + parents := make([]*Command, 0, 5) for p := c; p != nil; p = p.Parent() { + if EnableTraverseRunHooks { + // When EnableTraverseRunHooks is set: + // - Execute all persistent pre-runs from the root parent till this command. + // - Execute all persistent post-runs from this command till the root parent. + parents = append([]*Command{p}, parents...) + } else { + // Otherwise, execute only the first found persistent hook. + parents = append(parents, p) + } + } + for _, p := range parents { if p.PersistentPreRunE != nil { if err := p.PersistentPreRunE(c, argWoFlags); err != nil { return err } - break + if !EnableTraverseRunHooks { + break + } } else if p.PersistentPreRun != nil { p.PersistentPreRun(c, argWoFlags) - break + if !EnableTraverseRunHooks { + break + } } } if c.PreRunE != nil { @@ -979,10 +995,14 @@ func (c *Command) execute(a []string) (err error) { if err := p.PersistentPostRunE(c, argWoFlags); err != nil { return err } - break + if !EnableTraverseRunHooks { + break + } } else if p.PersistentPostRun != nil { p.PersistentPostRun(c, argWoFlags) - break + if !EnableTraverseRunHooks { + break + } } } diff --git a/command_test.go b/command_test.go index 4afb7f7b8e..4d8908d2fc 100644 --- a/command_test.go +++ b/command_test.go @@ -1530,57 +1530,73 @@ func TestHooks(t *testing.T) { } func TestPersistentHooks(t *testing.T) { - var ( - parentPersPreArgs string - parentPreArgs string - parentRunArgs string - parentPostArgs string - parentPersPostArgs string - ) + EnableTraverseRunHooks = true + testPersistentHooks(t, []string{ + "parent PersistentPreRun", + "child PersistentPreRun", + "child PreRun", + "child Run", + "child PostRun", + "child PersistentPostRun", + "parent PersistentPostRun", + }) - var ( - childPersPreArgs string - childPreArgs string - childRunArgs string - childPostArgs string - childPersPostArgs string - ) + EnableTraverseRunHooks = false + testPersistentHooks(t, []string{ + "child PersistentPreRun", + "child PreRun", + "child Run", + "child PostRun", + "child PersistentPostRun", + }) +} + +func testPersistentHooks(t *testing.T, expectedHookRunOrder []string) { + var hookRunOrder []string + + validateHook := func(args []string, hookName string) { + hookRunOrder = append(hookRunOrder, hookName) + got := strings.Join(args, " ") + if onetwo != got { + t.Errorf("Expected %s %q, got %q", hookName, onetwo, got) + } + } parentCmd := &Command{ Use: "parent", PersistentPreRun: func(_ *Command, args []string) { - parentPersPreArgs = strings.Join(args, " ") + validateHook(args, "parent PersistentPreRun") }, PreRun: func(_ *Command, args []string) { - parentPreArgs = strings.Join(args, " ") + validateHook(args, "parent PreRun") }, Run: func(_ *Command, args []string) { - parentRunArgs = strings.Join(args, " ") + validateHook(args, "parent Run") }, PostRun: func(_ *Command, args []string) { - parentPostArgs = strings.Join(args, " ") + validateHook(args, "parent PostRun") }, PersistentPostRun: func(_ *Command, args []string) { - parentPersPostArgs = strings.Join(args, " ") + validateHook(args, "parent PersistentPostRun") }, } childCmd := &Command{ Use: "child", PersistentPreRun: func(_ *Command, args []string) { - childPersPreArgs = strings.Join(args, " ") + validateHook(args, "child PersistentPreRun") }, PreRun: func(_ *Command, args []string) { - childPreArgs = strings.Join(args, " ") + validateHook(args, "child PreRun") }, Run: func(_ *Command, args []string) { - childRunArgs = strings.Join(args, " ") + validateHook(args, "child Run") }, PostRun: func(_ *Command, args []string) { - childPostArgs = strings.Join(args, " ") + validateHook(args, "child PostRun") }, PersistentPostRun: func(_ *Command, args []string) { - childPersPostArgs = strings.Join(args, " ") + validateHook(args, "child PersistentPostRun") }, } parentCmd.AddCommand(childCmd) @@ -1593,41 +1609,13 @@ func TestPersistentHooks(t *testing.T) { t.Errorf("Unexpected error: %v", err) } - for _, v := range []struct { - name string - got string - }{ - // TODO: currently PersistentPreRun* defined in parent does not - // run if the matching child subcommand has PersistentPreRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - {"parentPersPreArgs", parentPersPreArgs}, - {"parentPreArgs", parentPreArgs}, - {"parentRunArgs", parentRunArgs}, - {"parentPostArgs", parentPostArgs}, - // TODO: currently PersistentPostRun* defined in parent does not - // run if the matching child subcommand has PersistentPostRun. - // If the behavior changes (https://github.com/spf13/cobra/issues/252) - // this test must be fixed. - {"parentPersPostArgs", parentPersPostArgs}, - } { - if v.got != "" { - t.Errorf("Expected blank %s, got %q", v.name, v.got) - } - } - - for _, v := range []struct { - name string - got string - }{ - {"childPersPreArgs", childPersPreArgs}, - {"childPreArgs", childPreArgs}, - {"childRunArgs", childRunArgs}, - {"childPostArgs", childPostArgs}, - {"childPersPostArgs", childPersPostArgs}, - } { - if v.got != onetwo { - t.Errorf("Expected %s %q, got %q", v.name, onetwo, v.got) + for idx, exp := range expectedHookRunOrder { + if len(hookRunOrder) > idx { + if act := hookRunOrder[idx]; act != exp { + t.Errorf("Expected %q at %d, got %q", exp, idx, act) + } + } else { + t.Errorf("Expected %q at %d, got nothing", exp, idx) } } }