Skip to content

Commit

Permalink
VAULT-25848 replace mholt/archiver with native go calls (#27228) (#27250
Browse files Browse the repository at this point in the history
)

* VAULT-25848 update product code to remove mholt/archiver dependency

* VAULT-25848 replace tests, still WIP while I figure out if there's a bug caught by TestDebugCommand_PartialPermissions

* VAULT-25848 actually remove the dep

* VAULT-25848 add headers for directories, improve test

* Comment cleanup

* Typo

* Use %w

* Typo
  • Loading branch information
VioletHynes authored Jun 3, 2024
1 parent 6d31e2a commit 2c9fcfd
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 188 deletions.
119 changes: 103 additions & 16 deletions command/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
package command

import (
"archive/tar"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/url"
"os"
"path/filepath"
Expand All @@ -26,7 +28,6 @@ import (
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/version"
"github.com/mholt/archiver/v3"
"github.com/oklog/run"
"github.com/posener/complete"
)
Expand Down Expand Up @@ -374,7 +375,7 @@ func (c *DebugCommand) generateIndex() error {
}

// Write out file
if err := ioutil.WriteFile(filepath.Join(c.flagOutput, "index.json"), bytes, 0o600); err != nil {
if err := os.WriteFile(filepath.Join(c.flagOutput, "index.json"), bytes, 0o600); err != nil {
return fmt.Errorf("error generating index file; %s", err)
}

Expand Down Expand Up @@ -777,7 +778,7 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
return
}

err = ioutil.WriteFile(filepath.Join(dirName, target+".prof"), data, 0o600)
err = os.WriteFile(filepath.Join(dirName, target+".prof"), data, 0o600)
if err != nil {
c.captureError("pprof."+target, err)
}
Expand All @@ -795,13 +796,13 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
return
}

err = ioutil.WriteFile(filepath.Join(dirName, "goroutines.txt"), data, 0o600)
err = os.WriteFile(filepath.Join(dirName, "goroutines.txt"), data, 0o600)
if err != nil {
c.captureError("pprof.goroutines-text", err)
}
}()

// If the our remaining duration is less than the interval value
// If our remaining duration is less than the interval value
// skip profile and trace.
runDuration := currentTimestamp.Sub(startTime)
if (c.flagDuration+debugDurationGrace)-runDuration < c.flagInterval {
Expand All @@ -819,7 +820,7 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
return
}

err = ioutil.WriteFile(filepath.Join(dirName, "profile.prof"), data, 0o600)
err = os.WriteFile(filepath.Join(dirName, "profile.prof"), data, 0o600)
if err != nil {
c.captureError("pprof.profile", err)
}
Expand All @@ -835,7 +836,7 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
return
}

err = ioutil.WriteFile(filepath.Join(dirName, "trace.out"), data, 0o600)
err = os.WriteFile(filepath.Join(dirName, "trace.out"), data, 0o600)
if err != nil {
c.captureError("pprof.trace", err)
}
Expand Down Expand Up @@ -971,7 +972,7 @@ func (c *DebugCommand) persistCollection(collection []map[string]interface{}, ou
if err != nil {
return err
}
if err := ioutil.WriteFile(filepath.Join(c.flagOutput, outFile), bytes, 0o600); err != nil {
if err := os.WriteFile(filepath.Join(c.flagOutput, outFile), bytes, 0o600); err != nil {
return err
}

Expand All @@ -983,14 +984,100 @@ func (c *DebugCommand) compress(dst string) error {
defer osutil.Umask(osutil.Umask(0o077))
}

tgz := archiver.NewTarGz()
if err := tgz.Archive([]string{c.flagOutput}, dst); err != nil {
return fmt.Errorf("failed to compress data: %s", err)
if err := archiveToTgz(c.flagOutput, dst); err != nil {
return fmt.Errorf("failed to compress data: %w", err)
}

// If everything is fine up to this point, remove original directory
if err := os.RemoveAll(c.flagOutput); err != nil {
return fmt.Errorf("failed to remove data directory: %s", err)
return fmt.Errorf("failed to remove data directory: %w", err)
}

return nil
}

// archiveToTgz compresses all the files in sourceDir to a
// a tarball at destination.
func archiveToTgz(sourceDir, destination string) error {
file, err := os.Create(destination)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()

gzipWriter := gzip.NewWriter(file)
defer gzipWriter.Close()

tarWriter := tar.NewWriter(gzipWriter)
defer tarWriter.Close()

err = filepath.Walk(sourceDir,
func(filePath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
return addFileToTar(sourceDir, filePath, tarWriter)
})

return err
}

// addFileToTar takes a file at filePath and adds it to the tar
// being written to by tarWriter, alongside its header.
// The tar header name will be relative. Example: If we're tarring
// a file in ~/a/b/c/foo/bar.json, the header name will be foo/bar.json
func addFileToTar(sourceDir, filePath string, tarWriter *tar.Writer) error {
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file %q: %w", filePath, err)
}
defer file.Close()

stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file %q: %w", filePath, err)
}

var link string
mode := stat.Mode()
if mode&os.ModeSymlink != 0 {
if link, err = os.Readlink(filePath); err != nil {
return fmt.Errorf("failed to read symlink for file %q: %w", filePath, err)
}
}
tarHeader, err := tar.FileInfoHeader(stat, link)
if err != nil {
return fmt.Errorf("failed to create tar header for file %q: %w", filePath, err)
}

// The tar header name should be relative, so remove the sourceDir from it,
// but preserve the last directory name.
// Example: If we're tarring a file in ~/a/b/c/foo/bar.json
// The name should be foo/bar.json
sourceDirExceptLastDir := filepath.Dir(sourceDir)
headerName := strings.TrimPrefix(filepath.Clean(filePath), filepath.Clean(sourceDirExceptLastDir)+"/")

// Directories should end with a slash.
if stat.IsDir() && !strings.HasSuffix(headerName, "/") {
headerName += "/"
}
tarHeader.Name = headerName

err = tarWriter.WriteHeader(tarHeader)
if err != nil {
return fmt.Errorf("failed to write tar header for file %q: %w", filePath, err)
}

// If it's not a regular file (e.g. link or directory) we shouldn't
// copy the file. The body of a tar entry (i.e. what's done by the
// below io.Copy call) is only required for tar files of TypeReg.
if tarHeader.Typeflag != tar.TypeReg {
return nil
}

_, err = io.Copy(tarWriter, file)
if err != nil {
return fmt.Errorf("failed to copy file %q into tarball: %w", filePath, err)
}

return nil
Expand All @@ -1007,7 +1094,7 @@ func pprofTarget(ctx context.Context, client *api.Client, target string, params
}
defer resp.Body.Close()

data, err := ioutil.ReadAll(resp.Body)
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
Expand All @@ -1027,7 +1114,7 @@ func pprofProfile(ctx context.Context, client *api.Client, duration time.Duratio
}
defer resp.Body.Close()

data, err := ioutil.ReadAll(resp.Body)
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
Expand All @@ -1047,7 +1134,7 @@ func pprofTrace(ctx context.Context, client *api.Client, duration time.Duration)
}
defer resp.Body.Close()

data, err := ioutil.ReadAll(resp.Body)
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 2c9fcfd

Please sign in to comment.