Skip to content

Commit

Permalink
[scan-9] Update enumeration logic (#3626)
Browse files Browse the repository at this point in the history
* renaming to enumeration

* update enumeration

* comments

* remove commented out func

---------

Co-authored-by: Miccah Castorina <m.castorina93@gmail.com>
  • Loading branch information
0x1 and mcastorina authored Nov 25, 2024
1 parent f119adc commit 1276d26
Show file tree
Hide file tree
Showing 18 changed files with 171 additions and 29 deletions.
2 changes: 1 addition & 1 deletion pkg/engine/circleci.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ func (e *Engine) ScanCircleCI(ctx context.Context, token string) (sources.JobPro
if err := circleSource.Init(ctx, "trufflehog - Circle CI", jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, circleSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, circleSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ func (e *Engine) ScanDocker(ctx context.Context, c sources.DockerConfig) (source
if err := dockerSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, dockerSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, dockerSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/elasticsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ func (e *Engine) ScanElasticsearch(ctx context.Context, c sources.ElasticsearchC
if err := elasticsearchSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, elasticsearchSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, elasticsearchSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ func (e *Engine) ScanFileSystem(ctx context.Context, c sources.FilesystemConfig)
if err := fileSystemSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, fileSystemSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, fileSystemSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (e *Engine) ScanGCS(ctx context.Context, c sources.GCSConfig) (sources.JobP
if err := gcsSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, int(c.Concurrency)); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, gcsSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, gcsSource)
}

func isAuthValid(ctx context.Context, c sources.GCSConfig, connection *sourcespb.GCS) bool {
Expand Down
2 changes: 1 addition & 1 deletion pkg/engine/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) (sources.JobP
return sources.JobProgressRef{}, err
}

return e.sourceManager.Run(ctx, sourceName, gitSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, gitSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,5 @@ func (e *Engine) ScanGitHub(ctx context.Context, c sources.GithubConfig) (source
return sources.JobProgressRef{}, err
}
githubSource.WithScanOptions(scanOptions)
return e.sourceManager.Run(ctx, sourceName, githubSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, githubSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/github_experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@ func (e *Engine) ScanGitHubExperimental(ctx context.Context, c sources.GitHubExp
return sources.JobProgressRef{}, err
}
githubExperimentalSource.WithScanOptions(scanOptions)
return e.sourceManager.Run(ctx, sourceName, githubExperimentalSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, githubExperimentalSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ func (e *Engine) ScanGitLab(ctx context.Context, c sources.GitlabConfig) (source
return sources.JobProgressRef{}, err
}
gitlabSource.WithScanOptions(scanOptions)
return e.sourceManager.Run(ctx, sourceName, gitlabSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, gitlabSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,5 @@ func (e *Engine) ScanHuggingface(ctx context.Context, c HuggingfaceConfig) (sour
if err := huggingfaceSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, huggingfaceSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, huggingfaceSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/jenkins.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,5 @@ func (e *Engine) ScanJenkins(ctx context.Context, jenkinsConfig JenkinsConfig) (
if err := jenkinsSource.Init(ctx, "trufflehog - Jenkins", jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, jenkinsSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, jenkinsSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/postman.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ func (e *Engine) ScanPostman(ctx context.Context, c sources.PostmanConfig) (sour
if err := postmanSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, postmanSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, postmanSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ func (e *Engine) ScanS3(ctx context.Context, c sources.S3Config) (sources.JobPro
if err := s3Source.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, s3Source)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, s3Source)
}
2 changes: 1 addition & 1 deletion pkg/engine/syslog.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ func (e *Engine) ScanSyslog(ctx context.Context, c sources.SyslogConfig) (source
}
syslogSource.InjectConnection(connection)

return e.sourceManager.Run(ctx, sourceName, syslogSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, syslogSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/travisci.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ func (e *Engine) ScanTravisCI(ctx context.Context, token string) (sources.JobPro
if err := travisSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, travisSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, travisSource)
}
122 changes: 120 additions & 2 deletions pkg/sources/source_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ func (s *SourceManager) GetIDs(ctx context.Context, sourceName string, kind sour
return s.api.GetIDs(ctx, sourceName, kind)
}

// Run blocks until a resource is available to run the source, then
// EnumerateAndScan blocks until a resource is available to run the source, then
// asynchronously runs it. Error information is stored and accessible via the
// JobProgressRef as it becomes available.
func (s *SourceManager) Run(ctx context.Context, sourceName string, source Source, targets ...ChunkingTarget) (JobProgressRef, error) {
func (s *SourceManager) EnumerateAndScan(ctx context.Context, sourceName string, source Source, targets ...ChunkingTarget) (JobProgressRef, error) {
sourceID, jobID := source.SourceID(), source.JobID()
// Do preflight checks before waiting on the pool.
if err := s.preflightChecks(ctx); err != nil {
Expand Down Expand Up @@ -169,6 +169,54 @@ func (s *SourceManager) Run(ctx context.Context, sourceName string, source Sourc
return progress.Ref(), nil
}

func (s *SourceManager) Enumerate(ctx context.Context, sourceName string, source Source, reporter UnitReporter) (JobProgressRef, error) {
sourceID, jobID := source.SourceID(), source.JobID()
// Do preflight checks before waiting on the pool.
if err := s.preflightChecks(ctx); err != nil {
return JobProgressRef{
SourceName: sourceName,
SourceID: sourceID,
JobID: jobID,
}, err
}

// Create a JobProgress object for tracking progress.
sem := s.sem
ctx, cancel := context.WithCancelCause(ctx)
progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))
if err := sem.Acquire(ctx, 1); err != nil {
// Context cancelled.
progress.ReportError(Fatal{err})
return progress.Ref(), Fatal{err}
}

// Wrap the passed in reporter so we update the progress information.
reporter = baseUnitReporter{
child: reporter,
progress: progress,
}

s.wg.Add(1)
go func() {
// Call Finish after the semaphore has been released.
defer progress.Finish()
defer sem.Release(1)
defer s.wg.Done()
ctx := context.WithValues(ctx,
"source_manager_worker_id", common.RandomID(5),
)
defer common.Recover(ctx)
defer cancel(nil)
if err := s.enumerate(ctx, source, progress, reporter); err != nil {
select {
case s.firstErr <- err:
default:
}
}
}()
return progress.Ref(), nil
}

// Chunks returns the read only channel of all the chunks produced by all of
// the sources managed by this manager.
func (s *SourceManager) Chunks() <-chan *Chunk {
Expand Down Expand Up @@ -286,6 +334,75 @@ func (s *SourceManager) run(ctx context.Context, source Source, report *JobProgr
return s.runWithoutUnits(ctx, source, report, targets...)
}

// enumerate is a helper method to enumerate a Source.
func (s *SourceManager) enumerate(ctx context.Context, source Source, report *JobProgress, reporter UnitReporter) error {
report.Start(time.Now())
defer func() { report.End(time.Now()) }()

defer func() {
if err := context.Cause(ctx); err != nil {
report.ReportError(Fatal{err})
}
}()

report.TrackProgress(source.GetProgress())
if ctx.Value("job_id") == "" {
ctx = context.WithValue(ctx, "job_id", report.JobID)
}
if ctx.Value("source_id") == "" {
ctx = context.WithValue(ctx, "source_id", report.SourceID)
}
if ctx.Value("source_name") == "" {
ctx = context.WithValue(ctx, "source_name", report.SourceName)
}
if ctx.Value("source_type") == "" {
ctx = context.WithValue(ctx, "source_type", source.Type().String())
}

// Check for the preferred method of tracking source units.
canUseSourceUnits := s.useSourceUnitsFunc != nil
if enumChunker, ok := source.(SourceUnitEnumerator); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
ctx.Logger().Info("running source",
"with_units", true)
return s.enumerateWithUnits(ctx, enumChunker, report, reporter)
}
return fmt.Errorf("Enumeration not supported or configured for source: %s", source.Type().String())
}

// enumerateWithUnits is a helper method to enumerate a Source that is also a
// SourceUnitEnumerator. This allows better introspection of what is getting
// enumerated and any errors encountered.
func (s *SourceManager) enumerateWithUnits(ctx context.Context, source SourceUnitEnumerator, report *JobProgress, reporter UnitReporter) error {
// Create a function that will save the first error encountered (if
// any) and discard the rest.
fatalErr := make(chan error, 1)
catchFirstFatal := func(err error) {
select {
case fatalErr <- err:
default:
}
}

// Produce units.
func() {
// TODO: Catch panics and add to report.
report.StartEnumerating(time.Now())
defer func() { report.EndEnumerating(time.Now()) }()
ctx.Logger().V(2).Info("enumerating source with units")
if err := source.Enumerate(ctx, reporter); err != nil {
report.ReportError(Fatal{err})
catchFirstFatal(Fatal{err})
}
}()

select {
case err := <-fatalErr:
return err
default:
return nil
}
}

// runWithoutUnits is a helper method to run a Source. It has coarse-grained
// job reporting.
func (s *SourceManager) runWithoutUnits(ctx context.Context, source Source, report *JobProgress, targets ...ChunkingTarget) error {
Expand All @@ -302,6 +419,7 @@ func (s *SourceManager) runWithoutUnits(ctx context.Context, source Source, repo
s.outputChunks <- chunk
}
}()

// Don't return from this function until the goroutine has finished
// outputting chunks to the downstream channel. Closing the channel
// will stop the goroutine, so that needs to happen first in the defer
Expand Down
24 changes: 12 additions & 12 deletions pkg/sources/source_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestSourceManagerRun(t *testing.T) {
source, err := buildDummy(&counterChunker{count: 1})
assert.NoError(t, err)
for i := 0; i < 3; i++ {
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
<-ref.Done()
assert.NoError(t, err)
assert.NoError(t, ref.Snapshot().FatalError())
Expand All @@ -132,7 +132,7 @@ func TestSourceManagerWait(t *testing.T) {
source, err := buildDummy(&counterChunker{count: 1})
assert.NoError(t, err)
// Asynchronously run the source.
_, err = mgr.Run(context.Background(), "dummy", source)
_, err = mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
// Read the 1 chunk we're expecting so Waiting completes.
<-mgr.Chunks()
Expand All @@ -141,15 +141,15 @@ func TestSourceManagerWait(t *testing.T) {
// Run should return an error now.
_, err = buildDummy(&counterChunker{count: 1})
assert.NoError(t, err)
_, err = mgr.Run(context.Background(), "dummy", source)
_, err = mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.Error(t, err)
}

func TestSourceManagerError(t *testing.T) {
mgr := NewManager()
source, err := buildDummy(errorChunker{fmt.Errorf("oops")})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.Error(t, ref.Snapshot().FatalError())
Expand All @@ -165,7 +165,7 @@ func TestSourceManagerReport(t *testing.T) {
mgr := NewManager(opts...)
source, err := buildDummy(&counterChunker{count: 4})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.Equal(t, 0, len(ref.Snapshot().Errors))
Expand Down Expand Up @@ -230,7 +230,7 @@ func TestSourceManagerNonFatalError(t *testing.T) {
mgr := NewManager(WithBufferedOutput(8), WithSourceUnits())
source, err := buildDummy(&unitChunker{input})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
report := ref.Snapshot()
Expand All @@ -247,7 +247,7 @@ func TestSourceManagerContextCancelled(t *testing.T) {
assert.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
ref, err := mgr.Run(ctx, "dummy", source)
ref, err := mgr.EnumerateAndScan(ctx, "dummy", source)
assert.NoError(t, err)

cancel()
Expand Down Expand Up @@ -291,7 +291,7 @@ func TestSourceManagerCancelRun(t *testing.T) {
}})
assert.NoError(t, err)

ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)

cancelErr := fmt.Errorf("abort! abort!")
Expand All @@ -313,7 +313,7 @@ func TestSourceManagerAvailableCapacity(t *testing.T) {
assert.NoError(t, err)

assert.Equal(t, 1337, mgr.AvailableCapacity())
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)

<-start // Wait for start signal.
Expand All @@ -338,7 +338,7 @@ func TestSourceManagerUnitHook(t *testing.T) {
)
source, err := buildDummy(&unitChunker{input})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.NoError(t, mgr.Wait())
Expand Down Expand Up @@ -399,7 +399,7 @@ func TestSourceManagerUnitHookBackPressure(t *testing.T) {
)
source, err := buildDummy(&unitChunker{input})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)

var metrics []UnitMetrics
Expand Down Expand Up @@ -428,7 +428,7 @@ func TestSourceManagerUnitHookNoUnits(t *testing.T) {
source, err := buildDummy(&counterChunker{count: 5})
assert.NoError(t, err)

ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.NoError(t, mgr.Wait())
Expand Down
24 changes: 24 additions & 0 deletions pkg/sources/sources.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,30 @@ type SourceUnitEnumerator interface {
Enumerate(ctx context.Context, reporter UnitReporter) error
}

// BaseUnitReporter is a helper struct that implements the UnitReporter interface
// and includes a JobProgress reference.
type baseUnitReporter struct {
child UnitReporter
progress *JobProgress
}

func (b baseUnitReporter) UnitOk(ctx context.Context, unit SourceUnit) error {
b.progress.ReportUnit(unit)
if b.child != nil {
return b.child.UnitOk(ctx, unit)
}
return nil
}

func (b baseUnitReporter) UnitErr(ctx context.Context, err error) error {
b.progress.ReportError(err)
if b.child != nil {
return b.child.UnitErr(ctx, err)
}
return nil
}


// UnitReporter defines the interface a source will use to report whether a
// unit was found during enumeration. Either method may be called any number of
// times. Implementors of this interface should allow for concurrent calls.
Expand Down

0 comments on commit 1276d26

Please sign in to comment.