diff --git a/script.go b/script.go index 76c7185..cff9351 100644 --- a/script.go +++ b/script.go @@ -112,18 +112,6 @@ func Get(url string) *Pipe { return NewPipe().Get(url) } -// getStderr obtains the stderr writer on the pipe. This field -// is protected by a mutex since stderr is accessed inside a -// goroutine from [Pipe.Exec] and [Pipe.ExecForEach]. -func (p *Pipe) getStderr() io.Writer { - if p.mu == nil { // uninitialised pipe - return nil - } - p.mu.Lock() - defer p.mu.Unlock() - return p.stderr -} - // IfExists tests whether path exists, and creates a pipe whose error status // reflects the result. If the file doesn't exist, the pipe's error status will // be set, and if the file does exist, the pipe will have no error status. This @@ -427,7 +415,7 @@ func (p *Pipe) Exec(cmdLine string) *Pipe { cmd.Stdin = r cmd.Stdout = w cmd.Stderr = w - pipeStderr := p.getStderr() + pipeStderr := p.stdErr() if pipeStderr != nil { cmd.Stderr = pipeStderr } @@ -468,7 +456,7 @@ func (p *Pipe) ExecForEach(cmdLine string) *Pipe { cmd := exec.Command(args[0], args[1:]...) cmd.Stdout = w cmd.Stderr = w - pipeStderr := p.getStderr() + pipeStderr := p.stdErr() if pipeStderr != nil { cmd.Stderr = pipeStderr } @@ -806,18 +794,6 @@ func (p *Pipe) SetError(err error) { p.err = err } -// setStderr sets the stderr writer on the pipe. This field -// is protected by a mutex since stderr is accessed inside a -// goroutine from [Pipe.Exec] and [Pipe.ExecForEach]. -func (p *Pipe) setStderr(stderr io.Writer) { - if p.mu == nil { // uninitialised pipe - return - } - p.mu.Lock() - defer p.mu.Unlock() - p.stderr = stderr -} - // SHA256Sum returns the hex-encoded SHA-256 hash of the entire contents of the // pipe, or an error. func (p *Pipe) SHA256Sum() (string, error) { @@ -866,6 +842,18 @@ func (p *Pipe) Slice() ([]string, error) { return result, p.Error() } +// stdErr returns the pipe's configured standard error writer for commands run +// via [Pipe.Exec] and [Pipe.ExecForEach]. The default is nil, which means that +// error output will go to the pipe. +func (p *Pipe) stdErr() io.Writer { + if p.mu == nil { // uninitialised pipe + return nil + } + p.mu.Lock() + defer p.mu.Unlock() + return p.stderr +} + // Stdout copies the pipe's contents to its configured standard output (using // [Pipe.WithStdout]), or to [os.Stdout] otherwise, and returns the number of // bytes successfully written, together with any error. @@ -940,11 +928,12 @@ func (p *Pipe) WithReader(r io.Reader) *Pipe { return p } -// WithStderr redirects the standard error output for commands run via -// [Pipe.Exec] or [Pipe.ExecForEach] to the writer w, instead of going to the -// pipe as it normally would. +// WithStderr sets the standard error output for [Pipe.Exec] or +// [Pipe.ExecForEach] commands to w, instead of the pipe. func (p *Pipe) WithStderr(w io.Writer) *Pipe { - p.setStderr(w) + p.mu.Lock() + defer p.mu.Unlock() + p.stderr = w return p } diff --git a/script_test.go b/script_test.go index bff16e3..e006e7b 100644 --- a/script_test.go +++ b/script_test.go @@ -1971,14 +1971,9 @@ func TestEncodeBase64_CorrectlyEncodesInputBytes(t *testing.T) { } } -// TestWithStdErrAfterExec_DoesNotResultInRaceCondition is a regression test -// that was added to test against a race condition for [Pipe.stderr]. -func TestWithStdErrAfterExec_DoesNotResultInRaceCondition(t *testing.T) { +func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) { t.Parallel() - stdOut := new(bytes.Buffer) - stdErr := new(bytes.Buffer) - - _, err := script.Exec("echo").WithStdout(stdOut).WithStderr(stdErr).Stdout() + err := script.Exec("echo").WithStderr(nil).Wait() if err != nil { t.Fatal(err) }