Mirror of https://github.com/trufflesecurity/trufflehog.git, synced 2024-11-10 07:04:24 +00:00
Initial implementation of JobReport with SourceManager usage (#1557)
* Initial implementation of JobReport with SourceManager usage
* Limit concurrent units
* Only save the last JobReport per handle

parent 3897454dbb, commit e391e89f3e
4 changed files with 339 additions and 19 deletions
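
Taken together, the changes let a caller schedule sources on the manager, drain the downstream chunk channel, and then inspect a per-handle JobReport once Wait() returns. The sketch below is hypothetical caller code written as if it lived inside the sources package (the helper name scheduleAndReport and the already-enrolled handle are assumptions; enrollment itself appears in the test file further down):

package sources

import (
	"fmt"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

// scheduleAndReport is a hypothetical helper illustrating the new API:
// schedule a run, drain the downstream chunks, wait, then read the report.
func scheduleAndReport(ctx context.Context, mgr *SourceManager, h handle) error {
	// Drain the downstream chunk channel so a running source is never
	// blocked writing to it; Wait() closes the channel when all runs end.
	go func() {
		for range mgr.Chunks() {
		}
	}()
	if err := mgr.ScheduleRun(ctx, h); err != nil {
		return err // preflight failure: unknown handle or context already done
	}
	// Wait blocks until every scheduled source finishes and joins the
	// errors recorded in each JobReport.
	waitErr := mgr.Wait()
	if report := mgr.Report(h); report != nil {
		fmt.Printf("job %d: %d chunks in %s\n",
			report.JobID, report.TotalChunks, report.EndTime.Sub(report.StartTime))
	}
	return waitErr
}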
pkg/sources/job_report.go (new file, +37 lines)

@@ -0,0 +1,37 @@
+package sources
+
+import (
+	"errors"
+	"sync"
+	"time"
+)
+
+// JobReport aggregates information about a run of a Source.
+type JobReport struct {
+	SourceID    int64
+	JobID       int64
+	StartTime   time.Time
+	EndTime     time.Time
+	TotalChunks uint64
+	errors      []error
+	errorsLock  sync.Mutex
+}
+
+// AddError adds a non-nil error to the aggregate of errors encountered during
+// scanning.
+func (jr *JobReport) AddError(err error) {
+	if err == nil {
+		return
+	}
+	jr.errorsLock.Lock()
+	defer jr.errorsLock.Unlock()
+	jr.errors = append(jr.errors, err)
+}
+
+// Errors joins all aggregated errors into one. If there were no errors, nil is
+// returned. errors.Is can be used to check for specific errors.
+func (jr *JobReport) Errors() error {
+	jr.errorsLock.Lock()
+	defer jr.errorsLock.Unlock()
+	return errors.Join(jr.errors...)
+}
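
Since Errors() uses errors.Join, callers can still match individual errors with errors.Is after the fact. A minimal sketch of that contract (hypothetical code, assumed to sit in the sources package next to the file above):

func exampleJobReportErrors() {
	errRateLimited := errors.New("rate limited") // hypothetical sentinel error

	report := &JobReport{SourceID: 1, JobID: 1, StartTime: time.Now()}
	report.AddError(errRateLimited)
	report.AddError(nil) // nil errors are ignored by AddError
	report.EndTime = time.Now()

	if errors.Is(report.Errors(), errRateLimited) {
		// The joined error still matches the individual sentinel.
	}
}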
pkg/sources/source_manager.go

@@ -1,9 +1,11 @@
 package sources
 
 import (
+	"errors"
 	"fmt"
 	"sync"
 	"sync/atomic"
+	"time"
 
 	"github.com/trufflesecurity/trufflehog/v3/pkg/common"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
@@ -25,12 +27,20 @@ type SourceManager struct {
 	// Map of handle to source initializer.
 	handles     map[handle]SourceInitFunc
 	handlesLock sync.Mutex
+	// Map of handle to job reports.
+	// TODO: Manage culling and flushing to the API.
+	report     map[handle]*JobReport
+	reportLock sync.Mutex
 	// Pool limiting the amount of concurrent sources running.
-	pool errgroup.Group
+	pool            errgroup.Group
+	concurrentUnits int
+	// Run the sources using source unit enumeration / chunking if available.
+	useSourceUnits bool
 	// Downstream chunks channel to be scanned.
 	outputChunks chan *Chunk
-	// Set to true when Wait() returns.
-	done bool
+	// Set when Wait() returns.
+	done    bool
+	doneErr error
 }
 
 // apiClient is an interface for optionally communicating with an external API.
@@ -56,12 +66,25 @@ func WithBufferedOutput(size int) func(*SourceManager) {
 	return func(mgr *SourceManager) { mgr.outputChunks = make(chan *Chunk, size) }
 }
 
+// WithSourceUnits enables using source unit enumeration and chunking if the
+// source supports it.
+func WithSourceUnits() func(*SourceManager) {
+	return func(mgr *SourceManager) { mgr.useSourceUnits = true }
+}
+
+// WithConcurrentUnits limits the number of units to be scanned concurrently.
+// The default is unlimited.
+func WithConcurrentUnits(n int) func(*SourceManager) {
+	return func(mgr *SourceManager) { mgr.concurrentUnits = n }
+}
+
 // NewManager creates a new manager with the provided options.
 func NewManager(opts ...func(*SourceManager)) *SourceManager {
 	mgr := SourceManager{
 		// Default to the headless API. Can be overwritten by the WithAPI option.
 		api:          &headlessAPI{},
 		handles:      make(map[handle]SourceInitFunc),
+		report:       make(map[handle]*JobReport),
 		outputChunks: make(chan *Chunk),
 	}
 	for _, opt := range opts {
@@ -100,15 +123,24 @@ func (s *SourceManager) Run(ctx context.Context, handle handle) error {
 	ch := make(chan error)
 	s.pool.Go(func() error {
 		defer common.Recover(ctx)
-		// TODO: The manager should record these errors.
-		ch <- s.run(ctx, handle)
+		report, err := s.run(ctx, handle)
+		if report != nil {
+			s.reportLock.Lock()
+			s.report[handle] = report
+			s.reportLock.Unlock()
+		}
+		if err != nil {
+			ch <- err
+			return nil
+		}
+		ch <- report.Errors()
 		return nil
 	})
 	return <-ch
 }
 
 // ScheduleRun blocks until a resource is available to run the source, then
-// asynchronously runs it. Error information is lost in this case.
+// asynchronously runs it. Error information is stored and returned by Wait().
 func (s *SourceManager) ScheduleRun(ctx context.Context, handle handle) error {
 	// Do preflight checks before waiting on the pool.
 	if err := s.preflightChecks(ctx, handle); err != nil {
@@ -116,8 +148,14 @@ func (s *SourceManager) ScheduleRun(ctx context.Context, handle handle) error {
 	}
 	s.pool.Go(func() error {
 		defer common.Recover(ctx)
-		// TODO: The manager should record these errors.
-		_ = s.run(ctx, handle)
+		// The error is already saved in the report, so we can ignore
+		// it here.
+		report, _ := s.run(ctx, handle)
+		if report != nil {
+			s.reportLock.Lock()
+			s.report[handle] = report
+			s.reportLock.Unlock()
+		}
 		return nil
 	})
 	// TODO: Maybe wait for a signal here that initialization was successful?
@@ -136,12 +174,34 @@ func (s *SourceManager) Chunks() <-chan *Chunk {
 func (s *SourceManager) Wait() error {
 	// Check if the manager has been Waited.
 	if s.done {
-		return nil
+		return s.doneErr
 	}
-	// TODO: Aggregate all errors from all sources.
 	defer close(s.outputChunks)
 	defer func() { s.done = true }()
-	return s.pool.Wait()
+
+	// We are only using the errgroup for limiting concurrency.
+	// TODO: Maybe switch to using a semaphore.Weighted.
+	_ = s.pool.Wait()
+
+	// Aggregate all errors from all job reports.
+	// TODO: This should probably only be the fatal errors. We'll also need
+	// to rewrite this for when the reports start getting culled.
+	s.reportLock.Lock()
+	defer s.reportLock.Unlock()
+	errs := make([]error, 0, len(s.report))
+	for _, report := range s.report {
+		errs = append(errs, report.Errors())
+	}
+	s.doneErr = errors.Join(errs...)
+	return s.doneErr
+}
+
+// Report retrieves a scan report for a given handle. If no report exists or
+// the Source has not finished, nil will be returned.
+func (s *SourceManager) Report(handle handle) *JobReport {
+	s.reportLock.Lock()
+	defer s.reportLock.Unlock()
+	return s.report[handle]
 }
 
 // preflightChecks is a helper method to check the Manager or the context isn't
@@ -159,23 +219,127 @@ func (s *SourceManager) preflightChecks(ctx context.Context, handle handle) erro
 }
 
 // run is a helper method to sychronously run the source. It does not check for
-// acquired resources.
-func (s *SourceManager) run(ctx context.Context, handle handle) error {
+// acquired resources. Possible return values are:
+//
+//   - *JobReport, nil
+//     Successfully ran the source, but the report could have errors.
+//
+//   - *JobReport, error
+//     There was an error calling Init or Chunks. This sort of error indicates
+//     a fatal error and is also recorded in the report.
+//
+//   - nil, error:
+//     There was an error from the API or the handle is invalid. The latter of
+//     which should never happen due to the preflightChecks.
+func (s *SourceManager) run(ctx context.Context, handle handle) (*JobReport, error) {
 	jobID, err := s.api.GetJobID(ctx, int64(handle))
 	if err != nil {
-		return err
+		return nil, err
 	}
 	initFunc, ok := s.getInitFunc(handle)
 	if !ok {
-		return fmt.Errorf("unrecognized handle")
+		return nil, fmt.Errorf("unrecognized handle")
 	}
+	// Create a report for this run.
+	report := &JobReport{
+		SourceID:  int64(handle),
+		JobID:     jobID,
+		StartTime: time.Now(),
+	}
+	defer func() { report.EndTime = time.Now() }()
+
+	// Initialize the source.
 	source, err := initFunc(ctx, jobID, int64(handle))
 	if err != nil {
-		return err
+		report.AddError(err)
+		return report, err
 	}
-	// TODO: Support UnitChunker and SourceUnitEnumerator.
-	// TODO: This is where we can introspect on the chunks collected.
-	return source.Chunks(ctx, s.outputChunks)
+	// Check for the preferred method of tracking source units.
+	if enumChunker, ok := source.(SourceUnitEnumChunker); ok && s.useSourceUnits {
+		return s.runWithUnits(ctx, handle, enumChunker, report)
+	}
+	return s.runWithoutUnits(ctx, handle, source, report)
+}
+
+// runWithoutUnits is a helper method to run a Source. It has coarse-grained
+// job reporting.
+func (s *SourceManager) runWithoutUnits(ctx context.Context, handle handle, source Source, report *JobReport) (*JobReport, error) {
+	// Introspect on the chunks we get from the Chunks method.
+	ch := make(chan *Chunk)
+	var wg sync.WaitGroup
+	// Consume chunks and export chunks.
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for chunk := range ch {
+			atomic.AddUint64(&report.TotalChunks, 1)
+			_ = common.CancellableWrite(ctx, s.outputChunks, chunk)
+		}
+	}()
+	// Don't return from this function until the goroutine has finished
+	// outputting chunks to the downstream channel. Closing the channel
+	// will stop the goroutine, so that needs to happen first in the defer
+	// stack.
+	defer wg.Wait()
+	defer close(ch)
+	if err := source.Chunks(ctx, ch); err != nil {
+		report.AddError(err)
+		return report, err
+	}
+	return report, nil
+}
+
+// runWithUnits is a helper method to run a Source that is also a
+// SourceUnitEnumChunker. This allows better introspection of what is getting
+// scanned and any errors encountered.
+func (s *SourceManager) runWithUnits(ctx context.Context, handle handle, source SourceUnitEnumChunker, report *JobReport) (*JobReport, error) {
+	reporter := &mgrUnitReporter{
+		unitCh: make(chan SourceUnit),
+	}
+	// Produce units.
+	go func() {
+		// TODO: Catch panics and add to report.
+		defer close(reporter.unitCh)
+		if err := source.Enumerate(ctx, reporter); err != nil {
+			report.AddError(err)
+		}
+	}()
+	var wg sync.WaitGroup
+	// TODO: Maybe switch to using a semaphore.Weighted.
+	var unitPool errgroup.Group
+	if s.concurrentUnits != 0 {
+		// Negative values indicated no limit.
+		unitPool.SetLimit(s.concurrentUnits)
+	}
+	for unit := range reporter.unitCh {
+		reporter := &mgrChunkReporter{
+			unitID:  unit.SourceUnitID(),
+			chunkCh: make(chan *Chunk),
+		}
+		unit := unit
+		// Consume units and produce chunks.
+		unitPool.Go(func() error {
+			// TODO: Catch panics and add to report.
+			defer close(reporter.chunkCh)
+			if err := source.ChunkUnit(ctx, unit, reporter); err != nil {
+				report.AddError(err)
+			}
+			return nil
+		})
+		// Consume chunks and export chunks.
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for chunk := range reporter.chunkCh {
+				// TODO: Introspect on the chunks we got from this unit.
+				atomic.AddUint64(&report.TotalChunks, 1)
+				_ = common.CancellableWrite(ctx, s.outputChunks, chunk)
+			}
+		}()
+	}
+	wg.Wait()
+	// TODO: Return fatal errors.
+	return report, nil
+}
 
 // getInitFunc is a helper method for safe concurrent access to the
@@ -201,3 +365,40 @@ func (api *headlessAPI) RegisterSource(ctx context.Context, name string, kind so
 func (api *headlessAPI) GetJobID(ctx context.Context, id int64) (int64, error) {
 	return atomic.AddInt64(&api.jobIDCounter, 1), nil
 }
+
+// mgrUnitReporter implements the UnitReporter interface.
+type mgrUnitReporter struct {
+	unitCh       chan SourceUnit
+	unitErrs     []error
+	unitErrsLock sync.Mutex
+}
+
+func (s *mgrUnitReporter) UnitOk(ctx context.Context, unit SourceUnit) error {
+	return common.CancellableWrite(ctx, s.unitCh, unit)
+}
+
+func (s *mgrUnitReporter) UnitErr(ctx context.Context, err error) error {
+	s.unitErrsLock.Lock()
+	defer s.unitErrsLock.Unlock()
+	s.unitErrs = append(s.unitErrs, err)
+	return nil
+}
+
+// mgrChunkReporter implements the ChunkReporter interface.
+type mgrChunkReporter struct {
+	unitID        string
+	chunkCh       chan *Chunk
+	chunkErrs     []error
+	chunkErrsLock sync.Mutex
+}
+
+func (s *mgrChunkReporter) ChunkOk(ctx context.Context, chunk Chunk) error {
+	return common.CancellableWrite(ctx, s.chunkCh, &chunk)
+}
+
+func (s *mgrChunkReporter) ChunkErr(ctx context.Context, err error) error {
+	s.chunkErrsLock.Lock()
+	defer s.chunkErrsLock.Unlock()
+	s.chunkErrs = append(s.chunkErrs, err)
+	return nil
+}
pkg/sources/source_manager_test.go

@@ -29,6 +29,8 @@ func (d *DummySource) GetProgress() *Progress { return nil }
 // Interface to easily test different chunking methods.
 type chunker interface {
 	Chunks(context.Context, chan *Chunk) error
+	ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error
+	Enumerate(ctx context.Context, reporter UnitReporter) error
 }
 
 // Chunk method that writes count bytes to the channel before returning.
@@ -45,6 +47,31 @@ func (c *counterChunker) Chunks(_ context.Context, ch chan *Chunk) error {
 	return nil
 }
+
+// countChunk implements SourceUnit.
+type countChunk byte
+
+func (c countChunk) SourceUnitID() string { return fmt.Sprintf("countChunk(%d)", c) }
+
+func (c *counterChunker) Enumerate(ctx context.Context, reporter UnitReporter) error {
+	for i := 0; i < c.count; i++ {
+		if err := reporter.UnitOk(ctx, countChunk(byte(i))); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (c *counterChunker) ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error {
+	return reporter.ChunkOk(ctx, Chunk{Data: []byte{byte(unit.(countChunk))}})
+}
+
+// Chunk method that always returns an error.
+type errorChunker struct{ error }
+
+func (c errorChunker) Chunks(context.Context, chan *Chunk) error { return c }
+func (c errorChunker) Enumerate(context.Context, UnitReporter) error { return c }
+func (c errorChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return c }
 
 // enrollDummy is a helper function to enroll a DummySource with a SourceManager.
 func enrollDummy(mgr *SourceManager, chunkMethod chunker) (handle, error) {
 	return mgr.Enroll(context.Background(), "dummy", 1337,
@@ -117,3 +144,51 @@ func TestSourceManagerWait(t *testing.T) {
 		t.Fatalf("expected scheduling run to fail")
 	}
 }
+
+func TestSourceManagerError(t *testing.T) {
+	mgr := NewManager()
+	handle, err := enrollDummy(mgr, errorChunker{fmt.Errorf("oops")})
+	if err != nil {
+		t.Fatalf("unexpected error enrolling source: %v", err)
+	}
+	// A synchronous run should fail.
+	if err := mgr.Run(context.Background(), handle); err == nil {
+		t.Fatalf("expected run to fail")
+	}
+	// Scheduling a run should not fail, but the error should surface in
+	// Wait().
+	if err := mgr.ScheduleRun(context.Background(), handle); err != nil {
+		t.Fatalf("unexpected error scheduling run: %v", err)
+	}
+	if err := mgr.Wait(); err == nil {
+		t.Fatalf("expected wait to fail")
+	}
+}
+
+func TestSourceManagerReport(t *testing.T) {
+	for _, opts := range [][]func(*SourceManager){
+		{WithBufferedOutput(8)},
+		{WithBufferedOutput(8), WithSourceUnits()},
+		{WithBufferedOutput(8), WithSourceUnits(), WithConcurrentUnits(1)},
+	} {
+		mgr := NewManager(opts...)
+		handle, err := enrollDummy(mgr, &counterChunker{count: 4})
+		if err != nil {
+			t.Fatalf("unexpected error enrolling source: %v", err)
+		}
+		// Synchronously run the source.
+		if err := mgr.Run(context.Background(), handle); err != nil {
+			t.Fatalf("unexpected error running source: %v", err)
+		}
+		report := mgr.Report(handle)
+		if report == nil {
+			t.Fatalf("expected a report")
+		}
+		if err := report.Errors(); err != nil {
+			t.Fatalf("unexpected error in report: %v", err)
+		}
+		if report.TotalChunks != 4 {
+			t.Fatalf("expected report to have 4 chunks, got: %d", report.TotalChunks)
+		}
+	}
+}
pkg/sources/sources.go

@@ -44,6 +44,13 @@ type Source interface {
 	GetProgress() *Progress
 }
 
+// SourceUnitEnumChunker are the two required interfaces to support enumerating
+// and chunking of units.
+type SourceUnitEnumChunker interface {
+	SourceUnitEnumerator
+	SourceUnitChunker
+}
+
 // SourceUnitUnmarshaller defines an optional interface a Source can implement
 // to support units coming from an external source.
 type SourceUnitUnmarshaller interface {
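
The counterChunker in the test file above is the minimal working example of this pair of interfaces. For a slightly more concrete picture, a source fragment that lists repositories might opt in roughly as follows (hypothetical type names, assumed to live in the sources package with its existing imports; Init, Chunks, and the rest of the Source interface are omitted):

// repoLister enumerates each repository as a unit, then chunks one unit at a
// time, so the manager can attribute chunk counts and errors per unit.
type repoLister struct{ repos []string }

// repoUnit implements SourceUnit.
type repoUnit string

func (r repoUnit) SourceUnitID() string { return string(r) }

func (l *repoLister) Enumerate(ctx context.Context, reporter UnitReporter) error {
	for _, repo := range l.repos {
		if err := reporter.UnitOk(ctx, repoUnit(repo)); err != nil {
			return err
		}
	}
	return nil
}

func (l *repoLister) ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error {
	// A real source would fetch and split the unit's data here.
	return reporter.ChunkOk(ctx, Chunk{Data: []byte(unit.SourceUnitID())})
}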