diff --git a/expect.go b/expect.go index d340777..ee26707 100644 --- a/expect.go +++ b/expect.go @@ -86,7 +86,7 @@ func (tt *TermTest) ExpectCustom(consumer consumer, opts ...SetExpectOpt) (rerr return fmt.Errorf("could not create expect options: %w", err) } - cons, err := tt.outputProducer.addConsumer(consumer, expectOpts.ToConsumerOpts()...) + cons, err := tt.outputProducer.addConsumer(tt, consumer, expectOpts.ToConsumerOpts()...) if err != nil { return fmt.Errorf("could not add consumer: %w", err) } @@ -180,11 +180,11 @@ func (tt *TermTest) expectExitCode(exitCode int, match bool, opts ...SetExpectOp select { case <-time.After(timeoutV): return fmt.Errorf("after %s: %w", timeoutV, TimeoutError) - case err := <-waitChan(tt.cmd.Wait): - if err != nil && (tt.cmd.ProcessState == nil || tt.cmd.ProcessState.ExitCode() == 0) { - return fmt.Errorf("cmd wait failed: %w", err) + case state := <-tt.Exited(false): // do not wait for unread output since it's not read by this select{} + if state.Err != nil && (state.ProcessState == nil || state.ProcessState.ExitCode() == 0) { + return fmt.Errorf("cmd wait failed: %w", state.Err) } - if err := tt.assertExitCode(tt.cmd.ProcessState.ExitCode(), exitCode, match); err != nil { + if err := tt.assertExitCode(state.ProcessState.ExitCode(), exitCode, match); err != nil { return err } } diff --git a/helpers.go b/helpers.go index 3144ad2..9092a75 100644 --- a/helpers.go +++ b/helpers.go @@ -24,11 +24,11 @@ type cmdExit struct { } // waitForCmdExit turns process.wait() into a channel so that it can be used within a select{} statement -func waitForCmdExit(cmd *exec.Cmd) chan cmdExit { - exit := make(chan cmdExit, 1) +func waitForCmdExit(cmd *exec.Cmd) chan *cmdExit { + exit := make(chan *cmdExit, 1) go func() { err := cmd.Wait() - exit <- cmdExit{ProcessState: cmd.ProcessState, Err: err} + exit <- &cmdExit{ProcessState: cmd.ProcessState, Err: err} }() return exit } @@ -36,7 +36,7 @@ func waitForCmdExit(cmd *exec.Cmd) chan cmdExit { func waitChan[T any](wait func() T) chan T { done := make(chan T) go func() { - wait() + done <- wait() close(done) }() return done diff --git a/outputconsumer.go b/outputconsumer.go index 1665e4b..5e51c26 100644 --- a/outputconsumer.go +++ b/outputconsumer.go @@ -15,6 +15,7 @@ type outputConsumer struct { opts *OutputConsumerOpts isalive bool mutex *sync.Mutex + tt *TermTest } type OutputConsumerOpts struct { @@ -36,7 +37,7 @@ func OptsConsTimeout(timeout time.Duration) func(o *OutputConsumerOpts) { } } -func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer { +func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outputConsumer { oc := &outputConsumer{ consume: consume, opts: &OutputConsumerOpts{ @@ -46,6 +47,7 @@ func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer { waiter: make(chan error, 1), isalive: true, mutex: &sync.Mutex{}, + tt: tt, } for _, optSetter := range opts { @@ -101,5 +103,11 @@ func (e *outputConsumer) wait() error { e.mutex.Lock() e.opts.Logger.Println("Encountered timeout") return fmt.Errorf("after %s: %w", e.opts.Timeout, TimeoutError) + case state := <-e.tt.Exited(true): // allow for output to be read first by first case in this select{} + e.mutex.Lock() + if state.Err != nil { + e.opts.Logger.Println("Encountered error waiting for process to exit: %s\n", state.Err.Error()) + } + return fmt.Errorf("process exited (status: %d)", state.ProcessState.ExitCode()) } } diff --git a/outputproducer.go b/outputproducer.go index 34ae967..8c68b7a 100644 --- a/outputproducer.go +++ b/outputproducer.go @@ -238,12 +238,12 @@ func (o *outputProducer) flushConsumers() error { return nil } -func (o *outputProducer) addConsumer(consume consumer, opts ...SetConsOpt) (*outputConsumer, error) { +func (o *outputProducer) addConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) (*outputConsumer, error) { o.opts.Logger.Printf("adding consumer") defer o.opts.Logger.Printf("added consumer") opts = append(opts, OptConsInherit(o.opts)) - listener := newOutputConsumer(consume, opts...) + listener := newOutputConsumer(tt, consume, opts...) o.consumers = append(o.consumers, listener) if err := o.flushConsumers(); err != nil { diff --git a/termtest.go b/termtest.go index 7677bce..d363ad0 100644 --- a/termtest.go +++ b/termtest.go @@ -24,6 +24,7 @@ type TermTest struct { outputProducer *outputProducer listenError chan error opts *Opts + exited *cmdExit } type ErrorHandler func(*TermTest, error) error @@ -50,6 +51,9 @@ type SetOpt func(o *Opts) error const DefaultCols = 140 const DefaultRows = 10 +var processExitPollInterval = 10 * time.Millisecond +var processExitExtraWait = 500 * time.Millisecond + func NewOpts() *Opts { return &Opts{ Logger: VoidLogger, @@ -234,6 +238,10 @@ func (tt *TermTest) start() (rerr error) { }() wg.Wait() + go func() { + tt.exited = <-waitForCmdExit(tt.cmd) + }() + return nil } @@ -316,6 +324,28 @@ func (tt *TermTest) SendCtrlC() { tt.Send(string([]byte{0x03})) // 0x03 is ASCII character for ^C } +// Exited returns a channel that sends the given termtest's command cmdExit info when available. +// This can be used within a select{} statement. +// If waitExtra is given, waits a little bit before sending cmdExit info. This allows any fellow +// switch cases with output consumers to handle unprocessed stdout. If there are no such cases +// (e.g. ExpectExit(), where we want to catch an exit ASAP), waitExtra should be false. +func (tt *TermTest) Exited(waitExtra bool) chan *cmdExit { + return waitChan(func() *cmdExit { + ticker := time.NewTicker(processExitPollInterval) + for { + select { + case <-ticker.C: + if tt.exited != nil { + if waitExtra { // allow sibling output consumer cases to handle their output + time.Sleep(processExitExtraWait) + } + return tt.exited + } + } + } + }) +} + func (tt *TermTest) errorHandler(rerr *error) { err := *rerr if err == nil {