Initial implementation of JobReport with SourceManager usage (#1557)

* Initial implementation of JobReport with SourceManager usage

* Limit concurrent units

* Only save the last JobReport per handle
This commit is contained in:
Miccah 2023-07-27 10:49:56 -05:00 committed by GitHub
parent 3897454dbb
commit e391e89f3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 339 additions and 19 deletions

37
pkg/sources/job_report.go Normal file
View file

@ -0,0 +1,37 @@
package sources
import (
"errors"
"sync"
"time"
)
// JobReport aggregates information about a run of a Source.
type JobReport struct {
SourceID int64
JobID int64
StartTime time.Time
EndTime time.Time
TotalChunks uint64
errors []error
errorsLock sync.Mutex
}
// AddError adds a non-nil error to the aggregate of errors encountered during
// scanning.
func (jr *JobReport) AddError(err error) {
if err == nil {
return
}
jr.errorsLock.Lock()
defer jr.errorsLock.Unlock()
jr.errors = append(jr.errors, err)
}
// Errors joins all aggregated errors into one. If there were no errors, nil is
// returned. errors.Is can be used to check for specific errors.
func (jr *JobReport) Errors() error {
jr.errorsLock.Lock()
defer jr.errorsLock.Unlock()
return errors.Join(jr.errors...)
}

View file

@ -1,9 +1,11 @@
package sources
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
@ -25,12 +27,20 @@ type SourceManager struct {
// Map of handle to source initializer.
handles map[handle]SourceInitFunc
handlesLock sync.Mutex
// Map of handle to job reports.
// TODO: Manage culling and flushing to the API.
report map[handle]*JobReport
reportLock sync.Mutex
// Pool limiting the amount of concurrent sources running.
pool errgroup.Group
pool errgroup.Group
concurrentUnits int
// Run the sources using source unit enumeration / chunking if available.
useSourceUnits bool
// Downstream chunks channel to be scanned.
outputChunks chan *Chunk
// Set to true when Wait() returns.
done bool
// Set when Wait() returns.
done bool
doneErr error
}
// apiClient is an interface for optionally communicating with an external API.
@ -56,12 +66,25 @@ func WithBufferedOutput(size int) func(*SourceManager) {
return func(mgr *SourceManager) { mgr.outputChunks = make(chan *Chunk, size) }
}
// WithSourceUnits enables using source unit enumeration and chunking if the
// source supports it.
func WithSourceUnits() func(*SourceManager) {
return func(mgr *SourceManager) { mgr.useSourceUnits = true }
}
// WithConcurrentUnits limits the number of units to be scanned concurrently.
// The default is unlimited.
func WithConcurrentUnits(n int) func(*SourceManager) {
return func(mgr *SourceManager) { mgr.concurrentUnits = n }
}
// NewManager creates a new manager with the provided options.
func NewManager(opts ...func(*SourceManager)) *SourceManager {
mgr := SourceManager{
// Default to the headless API. Can be overwritten by the WithAPI option.
api: &headlessAPI{},
handles: make(map[handle]SourceInitFunc),
report: make(map[handle]*JobReport),
outputChunks: make(chan *Chunk),
}
for _, opt := range opts {
@ -100,15 +123,24 @@ func (s *SourceManager) Run(ctx context.Context, handle handle) error {
ch := make(chan error)
s.pool.Go(func() error {
defer common.Recover(ctx)
// TODO: The manager should record these errors.
ch <- s.run(ctx, handle)
report, err := s.run(ctx, handle)
if report != nil {
s.reportLock.Lock()
s.report[handle] = report
s.reportLock.Unlock()
}
if err != nil {
ch <- err
return nil
}
ch <- report.Errors()
return nil
})
return <-ch
}
// ScheduleRun blocks until a resource is available to run the source, then
// asynchronously runs it. Error information is lost in this case.
// asynchronously runs it. Error information is stored and returned by Wait().
func (s *SourceManager) ScheduleRun(ctx context.Context, handle handle) error {
// Do preflight checks before waiting on the pool.
if err := s.preflightChecks(ctx, handle); err != nil {
@ -116,8 +148,14 @@ func (s *SourceManager) ScheduleRun(ctx context.Context, handle handle) error {
}
s.pool.Go(func() error {
defer common.Recover(ctx)
// TODO: The manager should record these errors.
_ = s.run(ctx, handle)
// The error is already saved in the report, so we can ignore
// it here.
report, _ := s.run(ctx, handle)
if report != nil {
s.reportLock.Lock()
s.report[handle] = report
s.reportLock.Unlock()
}
return nil
})
// TODO: Maybe wait for a signal here that initialization was successful?
@ -136,12 +174,34 @@ func (s *SourceManager) Chunks() <-chan *Chunk {
func (s *SourceManager) Wait() error {
// Check if the manager has been Waited.
if s.done {
return nil
return s.doneErr
}
// TODO: Aggregate all errors from all sources.
defer close(s.outputChunks)
defer func() { s.done = true }()
return s.pool.Wait()
// We are only using the errgroup for limiting concurrency.
// TODO: Maybe switch to using a semaphore.Weighted.
_ = s.pool.Wait()
// Aggregate all errors from all job reports.
// TODO: This should probably only be the fatal errors. We'll also need
// to rewrite this for when the reports start getting culled.
s.reportLock.Lock()
defer s.reportLock.Unlock()
errs := make([]error, 0, len(s.report))
for _, report := range s.report {
errs = append(errs, report.Errors())
}
s.doneErr = errors.Join(errs...)
return s.doneErr
}
// Report retrieves a scan report for a given handle. If no report exists or
// the Source has not finished, nil will be returned.
func (s *SourceManager) Report(handle handle) *JobReport {
s.reportLock.Lock()
defer s.reportLock.Unlock()
return s.report[handle]
}
// preflightChecks is a helper method to check the Manager or the context isn't
@ -159,23 +219,127 @@ func (s *SourceManager) preflightChecks(ctx context.Context, handle handle) erro
}
// run is a helper method to sychronously run the source. It does not check for
// acquired resources.
func (s *SourceManager) run(ctx context.Context, handle handle) error {
// acquired resources. Possible return values are:
//
// - *JobReport, nil
// Successfully ran the source, but the report could have errors.
//
// - *JobReport, error
// There was an error calling Init or Chunks. This sort of error indicates
// a fatal error and is also recorded in the report.
//
// - nil, error:
// There was an error from the API or the handle is invalid. The latter of
// which should never happen due to the preflightChecks.
func (s *SourceManager) run(ctx context.Context, handle handle) (*JobReport, error) {
jobID, err := s.api.GetJobID(ctx, int64(handle))
if err != nil {
return err
return nil, err
}
initFunc, ok := s.getInitFunc(handle)
if !ok {
return fmt.Errorf("unrecognized handle")
return nil, fmt.Errorf("unrecognized handle")
}
// Create a report for this run.
report := &JobReport{
SourceID: int64(handle),
JobID: jobID,
StartTime: time.Now(),
}
defer func() { report.EndTime = time.Now() }()
// Initialize the source.
source, err := initFunc(ctx, jobID, int64(handle))
if err != nil {
return err
report.AddError(err)
return report, err
}
// TODO: Support UnitChunker and SourceUnitEnumerator.
// TODO: This is where we can introspect on the chunks collected.
return source.Chunks(ctx, s.outputChunks)
// Check for the preferred method of tracking source units.
if enumChunker, ok := source.(SourceUnitEnumChunker); ok && s.useSourceUnits {
return s.runWithUnits(ctx, handle, enumChunker, report)
}
return s.runWithoutUnits(ctx, handle, source, report)
}
// runWithoutUnits is a helper method to run a Source. It has coarse-grained
// job reporting.
func (s *SourceManager) runWithoutUnits(ctx context.Context, handle handle, source Source, report *JobReport) (*JobReport, error) {
// Introspect on the chunks we get from the Chunks method.
ch := make(chan *Chunk)
var wg sync.WaitGroup
// Consume chunks and export chunks.
wg.Add(1)
go func() {
defer wg.Done()
for chunk := range ch {
atomic.AddUint64(&report.TotalChunks, 1)
_ = common.CancellableWrite(ctx, s.outputChunks, chunk)
}
}()
// Don't return from this function until the goroutine has finished
// outputting chunks to the downstream channel. Closing the channel
// will stop the goroutine, so that needs to happen first in the defer
// stack.
defer wg.Wait()
defer close(ch)
if err := source.Chunks(ctx, ch); err != nil {
report.AddError(err)
return report, err
}
return report, nil
}
// runWithUnits is a helper method to run a Source that is also a
// SourceUnitEnumChunker. This allows better introspection of what is getting
// scanned and any errors encountered.
func (s *SourceManager) runWithUnits(ctx context.Context, handle handle, source SourceUnitEnumChunker, report *JobReport) (*JobReport, error) {
reporter := &mgrUnitReporter{
unitCh: make(chan SourceUnit),
}
// Produce units.
go func() {
// TODO: Catch panics and add to report.
defer close(reporter.unitCh)
if err := source.Enumerate(ctx, reporter); err != nil {
report.AddError(err)
}
}()
var wg sync.WaitGroup
// TODO: Maybe switch to using a semaphore.Weighted.
var unitPool errgroup.Group
if s.concurrentUnits != 0 {
// Negative values indicated no limit.
unitPool.SetLimit(s.concurrentUnits)
}
for unit := range reporter.unitCh {
reporter := &mgrChunkReporter{
unitID: unit.SourceUnitID(),
chunkCh: make(chan *Chunk),
}
unit := unit
// Consume units and produce chunks.
unitPool.Go(func() error {
// TODO: Catch panics and add to report.
defer close(reporter.chunkCh)
if err := source.ChunkUnit(ctx, unit, reporter); err != nil {
report.AddError(err)
}
return nil
})
// Consume chunks and export chunks.
wg.Add(1)
go func() {
defer wg.Done()
for chunk := range reporter.chunkCh {
// TODO: Introspect on the chunks we got from this unit.
atomic.AddUint64(&report.TotalChunks, 1)
_ = common.CancellableWrite(ctx, s.outputChunks, chunk)
}
}()
}
wg.Wait()
// TODO: Return fatal errors.
return report, nil
}
// getInitFunc is a helper method for safe concurrent access to the
@ -201,3 +365,40 @@ func (api *headlessAPI) RegisterSource(ctx context.Context, name string, kind so
func (api *headlessAPI) GetJobID(ctx context.Context, id int64) (int64, error) {
return atomic.AddInt64(&api.jobIDCounter, 1), nil
}
// mgrUnitReporter implements the UnitReporter interface.
type mgrUnitReporter struct {
unitCh chan SourceUnit
unitErrs []error
unitErrsLock sync.Mutex
}
func (s *mgrUnitReporter) UnitOk(ctx context.Context, unit SourceUnit) error {
return common.CancellableWrite(ctx, s.unitCh, unit)
}
func (s *mgrUnitReporter) UnitErr(ctx context.Context, err error) error {
s.unitErrsLock.Lock()
defer s.unitErrsLock.Unlock()
s.unitErrs = append(s.unitErrs, err)
return nil
}
// mgrChunkReporter implements the ChunkReporter interface.
type mgrChunkReporter struct {
unitID string
chunkCh chan *Chunk
chunkErrs []error
chunkErrsLock sync.Mutex
}
func (s *mgrChunkReporter) ChunkOk(ctx context.Context, chunk Chunk) error {
return common.CancellableWrite(ctx, s.chunkCh, &chunk)
}
func (s *mgrChunkReporter) ChunkErr(ctx context.Context, err error) error {
s.chunkErrsLock.Lock()
defer s.chunkErrsLock.Unlock()
s.chunkErrs = append(s.chunkErrs, err)
return nil
}

View file

@ -29,6 +29,8 @@ func (d *DummySource) GetProgress() *Progress { return nil }
// Interface to easily test different chunking methods.
type chunker interface {
Chunks(context.Context, chan *Chunk) error
ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error
Enumerate(ctx context.Context, reporter UnitReporter) error
}
// Chunk method that writes count bytes to the channel before returning.
@ -45,6 +47,31 @@ func (c *counterChunker) Chunks(_ context.Context, ch chan *Chunk) error {
return nil
}
// countChunk implements SourceUnit.
type countChunk byte
func (c countChunk) SourceUnitID() string { return fmt.Sprintf("countChunk(%d)", c) }
func (c *counterChunker) Enumerate(ctx context.Context, reporter UnitReporter) error {
for i := 0; i < c.count; i++ {
if err := reporter.UnitOk(ctx, countChunk(byte(i))); err != nil {
return err
}
}
return nil
}
func (c *counterChunker) ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error {
return reporter.ChunkOk(ctx, Chunk{Data: []byte{byte(unit.(countChunk))}})
}
// Chunk method that always returns an error.
type errorChunker struct{ error }
func (c errorChunker) Chunks(context.Context, chan *Chunk) error { return c }
func (c errorChunker) Enumerate(context.Context, UnitReporter) error { return c }
func (c errorChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return c }
// enrollDummy is a helper function to enroll a DummySource with a SourceManager.
func enrollDummy(mgr *SourceManager, chunkMethod chunker) (handle, error) {
return mgr.Enroll(context.Background(), "dummy", 1337,
@ -117,3 +144,51 @@ func TestSourceManagerWait(t *testing.T) {
t.Fatalf("expected scheduling run to fail")
}
}
func TestSourceManagerError(t *testing.T) {
mgr := NewManager()
handle, err := enrollDummy(mgr, errorChunker{fmt.Errorf("oops")})
if err != nil {
t.Fatalf("unexpected error enrolling source: %v", err)
}
// A synchronous run should fail.
if err := mgr.Run(context.Background(), handle); err == nil {
t.Fatalf("expected run to fail")
}
// Scheduling a run should not fail, but the error should surface in
// Wait().
if err := mgr.ScheduleRun(context.Background(), handle); err != nil {
t.Fatalf("unexpected error scheduling run: %v", err)
}
if err := mgr.Wait(); err == nil {
t.Fatalf("expected wait to fail")
}
}
func TestSourceManagerReport(t *testing.T) {
for _, opts := range [][]func(*SourceManager){
{WithBufferedOutput(8)},
{WithBufferedOutput(8), WithSourceUnits()},
{WithBufferedOutput(8), WithSourceUnits(), WithConcurrentUnits(1)},
} {
mgr := NewManager(opts...)
handle, err := enrollDummy(mgr, &counterChunker{count: 4})
if err != nil {
t.Fatalf("unexpected error enrolling source: %v", err)
}
// Synchronously run the source.
if err := mgr.Run(context.Background(), handle); err != nil {
t.Fatalf("unexpected error running source: %v", err)
}
report := mgr.Report(handle)
if report == nil {
t.Fatalf("expected a report")
}
if err := report.Errors(); err != nil {
t.Fatalf("unexpected error in report: %v", err)
}
if report.TotalChunks != 4 {
t.Fatalf("expected report to have 4 chunks, got: %d", report.TotalChunks)
}
}
}

View file

@ -44,6 +44,13 @@ type Source interface {
GetProgress() *Progress
}
// SourceUnitEnumChunker are the two required interfaces to support enumerating
// and chunking of units.
type SourceUnitEnumChunker interface {
SourceUnitEnumerator
SourceUnitChunker
}
// SourceUnitUnmarshaller defines an optional interface a Source can implement
// to support units coming from an external source.
type SourceUnitUnmarshaller interface {