diff --git a/internal/ensure/requiredversion.go b/internal/ensure/requiredversion.go index 50055aec..a540b990 100644 --- a/internal/ensure/requiredversion.go +++ b/internal/ensure/requiredversion.go @@ -14,6 +14,7 @@ import ( "runtime" "syscall" + "github.com/DataDog/orchestrion/internal/goenv" "github.com/DataDog/orchestrion/internal/log" "github.com/DataDog/orchestrion/internal/version" "golang.org/x/tools/go/packages" @@ -128,8 +129,13 @@ func requiredVersion( // working directory is used. The version may be blank if a replace directive is in effect; in which // case the path value may indicate the location of the source code that is being used instead. func goModVersion(dir string) (moduleVersion string, moduleDir string, err error) { + gomod, err := goenv.GOMOD(dir) + if err != nil { + return "", "", err + } + cfg := &packages.Config{ - Dir: dir, + Dir: filepath.Dir(gomod), Mode: packages.NeedModule, Logf: func(format string, args ...any) { log.Tracef(format+"\n", args...) }, } diff --git a/internal/ensure/requiredversion_test.go b/internal/ensure/requiredversion_test.go index 3f196e8e..91a4be24 100644 --- a/internal/ensure/requiredversion_test.go +++ b/internal/ensure/requiredversion_test.go @@ -87,9 +87,7 @@ func TestGoModVersion(t *testing.T) { } t.Run("no-go-mod", func(t *testing.T) { - tmp, err := os.MkdirTemp("", "ensure-*") - require.NoError(t, err, "failed to create temporary directory") - defer os.RemoveAll(tmp) + tmp := t.TempDir() os.WriteFile(filepath.Join(tmp, "main.go"), []byte(` package main @@ -98,9 +96,9 @@ func TestGoModVersion(t *testing.T) { `), 0o644) require.NotPanics(t, func() { - _, _, err = goModVersion(tmp) + _, _, err := goModVersion(tmp) + require.ErrorIs(t, err, goenv.ErrNoGoMod) }) - require.ErrorContains(t, err, "go.mod file not found in current directory") }) } diff --git a/internal/goenv/goenv.go b/internal/goenv/goenv.go index a06fb72d..fd2f35dd 100644 --- a/internal/goenv/goenv.go +++ b/internal/goenv/goenv.go @@ -16,15 +16,13 @@ import ( var ( // ErrNoGoMod is returned when no GOMOD value could be identified. - ErrNoGoMod = errors.New("`go mod GOMOD` returned a blank string") + ErrNoGoMod = errors.New("`go env GOMOD` returned a blank string") ) -// GOMOD returns the current GOMOD environment variable (possibly from running `go env GOMOD`). -func GOMOD() (string, error) { - if goMod := os.Getenv("GOMOD"); goMod != "" { - return goMod, nil - } +// GOMOD returns the current GOMOD environment variable (from running `go env GOMOD`). +func GOMOD(dir string) (string, error) { cmd := exec.Command("go", "env", "GOMOD") + cmd.Dir = dir var stdout bytes.Buffer cmd.Stdout = &stdout if err := cmd.Run(); err != nil { diff --git a/internal/goenv/goenv_test.go b/internal/goenv/goenv_test.go index 0379c0d3..c33d087b 100644 --- a/internal/goenv/goenv_test.go +++ b/internal/goenv/goenv_test.go @@ -6,7 +6,6 @@ package goenv import ( - "os" "testing" "github.com/stretchr/testify/require" @@ -14,31 +13,14 @@ import ( func TestGOMOD(t *testing.T) { t.Run("without GOMOD environment variable", func(t *testing.T) { - t.Setenv("GOMOD", "") - - gomod, err := GOMOD() + gomod, err := GOMOD("") require.NoError(t, err) require.NotEmpty(t, gomod) }) t.Run("no GOMOD can be found", func(t *testing.T) { - t.Setenv("GOMOD", "") - - wd, _ := os.Getwd() - defer os.Chdir(wd) - os.Chdir(os.TempDir()) - - val, err := GOMOD() + val, err := GOMOD(t.TempDir()) require.Empty(t, val) require.ErrorIs(t, err, ErrNoGoMod) }) - - t.Run("with GOMOD environment variable", func(t *testing.T) { - expected := "/fake/path/to/go.mod" - t.Setenv("GOMOD", expected) - - gomod, err := GOMOD() - require.NoError(t, err) - require.EqualValues(t, expected, gomod) - }) } diff --git a/internal/pin/pin.go b/internal/pin/pin.go index ae7c5edd..18173aa8 100644 --- a/internal/pin/pin.go +++ b/internal/pin/pin.go @@ -91,7 +91,7 @@ func defaultOrchestrionToolGo() *dst.File { // PinOrchestrion applies or update the orchestrion pin file in the current // working directory, according to the supplied [Options]. func PinOrchestrion(opts Options) error { - goMod, err := goenv.GOMOD() + goMod, err := goenv.GOMOD("") if err != nil { return fmt.Errorf("getting GOMOD: %w", err) }