diff --git a/node/pkg/common/scissors.go b/node/pkg/common/scissors.go index 8ea40e68b6..aab8b61feb 100644 --- a/node/pkg/common/scissors.go +++ b/node/pkg/common/scissors.go @@ -24,37 +24,7 @@ var ( // Start a go routine with recovering from any panic by sending an error to a error channel func RunWithScissors(ctx context.Context, errC chan error, name string, runnable supervisor.Runnable) { - ScissorsErrorsCaught.WithLabelValues(name).Add(0) - ScissorsPanicsCaught.WithLabelValues(name).Add(0) - go func() { - defer func() { - if r := recover(); r != nil { - var err error - switch x := r.(type) { - case error: - err = fmt.Errorf("%s: %w", name, x) - default: - err = fmt.Errorf("%s: %v", name, x) - } - // We don't want this to hang if the listener has already gone away. - select { - case errC <- err: - default: - } - ScissorsPanicsCaught.WithLabelValues(name).Inc() - - } - }() - err := runnable(ctx) - if err != nil { - // We don't want this to hang if the listener has already gone away. - select { - case errC <- err: - default: - } - ScissorsErrorsCaught.WithLabelValues(name).Inc() - } - }() + StartRunnable(ctx, errC, true, name, runnable) } func WrapWithScissors(runnable supervisor.Runnable, name string) supervisor.Runnable { @@ -76,3 +46,48 @@ func WrapWithScissors(runnable supervisor.Runnable, name string) supervisor.Runn return runnable(ctx) } } + +// StartRunnable starts a go routine with the ability to recover from errors by publishing them to an error channel. If catchPanics is true, +// it will also catch panics and publish the panic message to the error channel. If catchPanics is false, the panic will be propagated upward. +func StartRunnable(ctx context.Context, errC chan error, catchPanics bool, name string, runnable supervisor.Runnable) { + ScissorsErrorsCaught.WithLabelValues(name).Add(0) + if catchPanics { + ScissorsPanicsCaught.WithLabelValues(name).Add(0) + } + go func() { + if catchPanics { + defer func() { + if r := recover(); r != nil { + var err error + switch x := r.(type) { + case error: + err = fmt.Errorf("%s: %w", name, x) + default: + err = fmt.Errorf("%s: %v", name, x) + } + // We don't want this to hang if the listener has already gone away. + select { + case errC <- err: + default: + } + ScissorsPanicsCaught.WithLabelValues(name).Inc() + + } + }() + } + startRunnable(ctx, errC, name, runnable) + }() +} + +// startRunnable is used by StartRunnable. It is a separate function so we can call it directly from tests. +func startRunnable(ctx context.Context, errC chan error, name string, runnable supervisor.Runnable) { + err := runnable(ctx) + if err != nil { + // We don't want this to hang if the listener has already gone away. + select { + case errC <- err: + default: + } + ScissorsErrorsCaught.WithLabelValues(name).Inc() + } +} diff --git a/node/pkg/common/scissors_test.go b/node/pkg/common/scissors_test.go index 7912eca907..96b7a2f058 100644 --- a/node/pkg/common/scissors_test.go +++ b/node/pkg/common/scissors_test.go @@ -203,3 +203,135 @@ func TestRunWithScissorsErrorDoesNotBlockWhenNoListener(t *testing.T) { assert.Equal(t, 1.0, getCounterValue(ScissorsErrorsCaught, "TestRunWithScissorsErrorDoesNotBlockWhenNoListener")) assert.Equal(t, 0.0, getCounterValue(ScissorsPanicsCaught, "TestRunWithScissorsErrorDoesNotBlockWhenNoListener")) } + +func TestStartRunnable_CleanExit(t *testing.T) { + ctx := context.Background() + errC := make(chan error) + + itRan := make(chan bool, 1) + StartRunnable(ctx, errC, true, "TestStartRunnable_CleanExit", func(ctx context.Context) error { + itRan <- true + return nil + }) + + shouldHaveRun := <-itRan + require.Equal(t, true, shouldHaveRun) + + // Need to wait a bit to make sure the scissors code completes without hanging. + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, 0.0, getCounterValue(ScissorsErrorsCaught, "TestStartRunnable_CleanExit")) + assert.Equal(t, 0.0, getCounterValue(ScissorsPanicsCaught, "TestStartRunnable_CleanExit")) +} + +func TestStartRunnable_OnError(t *testing.T) { + ctx := context.Background() + errC := make(chan error) + + itRan := make(chan bool, 1) + StartRunnable(ctx, errC, true, "TestStartRunnable_OnError", func(ctx context.Context) error { + itRan <- true + return fmt.Errorf("Some random error") + }) + + var err error + select { + case <-ctx.Done(): + break + case err = <-errC: + break + } + + shouldHaveRun := <-itRan + require.Equal(t, true, shouldHaveRun) + assert.Error(t, err) + assert.Equal(t, "Some random error", err.Error()) + assert.Equal(t, 1.0, getCounterValue(ScissorsErrorsCaught, "TestStartRunnable_OnError")) + assert.Equal(t, 0.0, getCounterValue(ScissorsPanicsCaught, "TestStartRunnable_OnError")) +} + +func TestStartRunnable_DontCatchPanics_OnPanic(t *testing.T) { + ctx := context.Background() + errC := make(chan error) + + itRan := make(chan bool, 1) + itPanicked := make(chan bool, 1) + + // We can't use StartRunnable() because we cannot test for a panic in another go routine. + // This verifies that startRunnable() lets the panic through so it gets caught here, allowing us to test for it. + func() { + defer func() { + if r := recover(); r != nil { + itPanicked <- true + } + itRan <- true + }() + + startRunnable(ctx, errC, "TestStartRunnable_DontCatchPanics_OnPanic", func(ctx context.Context) error { + panic("Some random panic") + }) + }() + + var shouldHaveRun bool + select { + case <-ctx.Done(): + break + case shouldHaveRun = <-itRan: + break + } + + require.Equal(t, true, shouldHaveRun) + + require.Equal(t, 1, len(itPanicked)) + shouldHavePanicked := <-itPanicked + require.Equal(t, true, shouldHavePanicked) + + assert.Equal(t, 0.0, getCounterValue(ScissorsErrorsCaught, "TestStartRunnable_DontCatchPanics_OnPanic")) + assert.Equal(t, 0.0, getCounterValue(ScissorsPanicsCaught, "TestStartRunnable_DontCatchPanics_OnPanic")) +} + +func TestStartRunnable_CatchPanics_OnPanic(t *testing.T) { + ctx := context.Background() + errC := make(chan error) + + itRan := make(chan bool, 1) + StartRunnable(ctx, errC, true, "TestStartRunnable_CatchPanics_OnPanic", func(ctx context.Context) error { + itRan <- true + panic("Some random panic") + }) + + var err error + select { + case <-ctx.Done(): + break + case err = <-errC: + break + } + + shouldHaveRun := <-itRan + require.Equal(t, true, shouldHaveRun) + assert.Error(t, err) + assert.Equal(t, "TestStartRunnable_CatchPanics_OnPanic: Some random panic", err.Error()) + assert.Equal(t, 0.0, getCounterValue(ScissorsErrorsCaught, "TestStartRunnable_CatchPanics_OnPanic")) + assert.Equal(t, 1.0, getCounterValue(ScissorsPanicsCaught, "TestStartRunnable_CatchPanics_OnPanic")) +} + +func TestStartRunnable_DoesNotBlockWhenNoListener(t *testing.T) { + ctx := context.Background() + errC := make(chan error) + + itRan := make(chan bool, 1) + StartRunnable(ctx, errC, true, "TestStartRunnable_DoesNotBlockWhenNoListener", func(ctx context.Context) error { + itRan <- true + panic("Some random panic") + }) + + shouldHaveRun := <-itRan + require.Equal(t, true, shouldHaveRun) + + // Need to wait a bit to make sure the scissors code completes without hanging. + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, 0.0, getCounterValue(ScissorsErrorsCaught, "TestStartRunnable_DoesNotBlockWhenNoListener")) + assert.Equal(t, 1.0, getCounterValue(ScissorsPanicsCaught, "TestStartRunnable_DoesNotBlockWhenNoListener")) +}