// Source: trufflehog/pkg/sources/source_manager.go (427 lines, 13 KiB, Go)

package sources
import (
"fmt"
"io"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/marusama/semaphore/v2"
"golang.org/x/sync/errgroup"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
)
// SourceManager schedules sources to run concurrently, forwards the chunks
// they produce to a single output channel, and remembers the first fatal
// error reported by any run. It contains a sync.WaitGroup and must not be
// copied after first use.
type SourceManager struct {
	// api supplies source and job identifiers (see GetIDs).
	api apiClient
	// hooks are handed to every job's progress tracker. Hooks that also
	// implement io.Closer are closed by Wait.
	hooks []JobProgressHook

	// Pools limiting the amount of concurrent sources running. sem is used
	// for runs without targets; prioritySem is used when targets are given.
	sem         semaphore.Semaphore
	prioritySem semaphore.Semaphore
	// wg tracks the goroutines started by Run.
	wg sync.WaitGroup

	// Max number of units to scan concurrently per source.
	concurrentUnits int
	// Run the sources using source unit enumeration / chunking if available.
	// Checked at runtime to allow feature flagging.
	useSourceUnitsFunc func() bool

	// Downstream chunks channel to be scanned.
	outputChunks chan *Chunk
	// firstErr holds the first fatal error returned by a run, if any.
	firstErr chan error

	// Set when Wait() returns.
	waitErr error
	done    bool
}
// apiClient abstracts an optional external API that the manager consults for
// identifiers. A local implementation is used when no API is configured.
type apiClient interface {
	// GetIDs notifies the API that the named source of the given kind is
	// about to run, and returns the source ID and job ID to use when
	// initializing that source.
	GetIDs(ctx context.Context, name string, kind sourcespb.SourceType) (SourceID, JobID, error)
}
// WithAPI sets the API client the manager uses to obtain source and job IDs
// (see GetIDs), replacing the default local implementation.
//
// The client is NOT registered as a JobProgressHook by this option; pass it to
// WithReportHook as well if it should receive job progress events.
func WithAPI(api apiClient) func(*SourceManager) {
	return func(mgr *SourceManager) {
		mgr.api = api
	}
}
// WithReportHook appends hook to the hooks that are attached to every job's
// progress tracker. The option may be supplied multiple times to register
// several hooks.
func WithReportHook(hook JobProgressHook) func(*SourceManager) {
	register := func(mgr *SourceManager) {
		mgr.hooks = append(mgr.hooks, hook)
	}
	return register
}
// WithConcurrentSources sets how many sources (run without targets) the
// manager may execute at the same time.
func WithConcurrentSources(concurrency int) func(*SourceManager) {
	return func(mgr *SourceManager) { mgr.sem.SetLimit(concurrency) }
}
// WithConcurrentTargets sets how many targeted runs (Run called with at least
// one target) the manager may execute at the same time.
func WithConcurrentTargets(concurrency int) func(*SourceManager) {
	return func(mgr *SourceManager) { mgr.prioritySem.SetLimit(concurrency) }
}
// WithBufferedOutput replaces the channel returned by Chunks() with one that
// buffers up to size chunks.
func WithBufferedOutput(size int) func(*SourceManager) {
	return func(mgr *SourceManager) {
		mgr.outputChunks = make(chan *Chunk, size)
	}
}
// WithSourceUnits unconditionally turns on source unit enumeration and
// chunking for sources that support it.
func WithSourceUnits() func(*SourceManager) {
	alwaysEnabled := func() bool { return true }
	return func(mgr *SourceManager) { mgr.useSourceUnitsFunc = alwaysEnabled }
}
// WithSourceUnitsFunc installs f as the toggle that decides, each time a
// source is run, whether source unit enumeration and chunking should be used
// for sources that support it.
func WithSourceUnitsFunc(f func() bool) func(*SourceManager) {
	return func(mgr *SourceManager) {
		mgr.useSourceUnitsFunc = f
	}
}
// WithConcurrentUnits sets the maximum number of units scanned concurrently
// for a single source. The default is unlimited.
func WithConcurrentUnits(n int) func(*SourceManager) {
	return func(mgr *SourceManager) {
		mgr.concurrentUnits = n
	}
}
// NewManager builds a SourceManager and applies opts to it in order.
//
// Before options are applied the manager uses the local headless API, limits
// both semaphores to runtime.NumCPU(), and exposes an unbuffered output
// channel.
func NewManager(opts ...func(*SourceManager)) *SourceManager {
	cpus := runtime.NumCPU()
	mgr := &SourceManager{
		// The headless API is the default; WithAPI can replace it.
		api:          &headlessAPI{},
		sem:          semaphore.New(cpus),
		prioritySem:  semaphore.New(cpus),
		outputChunks: make(chan *Chunk),
		// Buffer of one so the first failing run never blocks on send.
		firstErr: make(chan error, 1),
	}
	for i := range opts {
		opts[i](mgr)
	}
	return mgr
}
// GetIDs asks the configured API client for the source ID and job ID to use
// for the named source of the given kind, returning the client's results
// unchanged.
func (s *SourceManager) GetIDs(ctx context.Context, sourceName string, kind sourcespb.SourceType) (SourceID, JobID, error) {
	sourceID, jobID, err := s.api.GetIDs(ctx, sourceName, kind)
	return sourceID, jobID, err
}
// Run waits for a free slot in the pool and then runs the source in a new
// goroutine. Error information is stored and accessible via the returned
// JobProgressRef as it becomes available.
//
// Runs with at least one target draw from the priority pool; all other runs
// draw from the regular pool. An error is returned only when the preflight
// checks fail or the context is done before a slot could be acquired.
func (s *SourceManager) Run(ctx context.Context, sourceName string, source Source, targets ...ChunkingTarget) (JobProgressRef, error) {
	sourceID, jobID := source.SourceID(), source.JobID()

	// Do preflight checks before waiting on the pool.
	if err := s.preflightChecks(ctx); err != nil {
		ref := JobProgressRef{
			SourceName: sourceName,
			SourceID:   sourceID,
			JobID:      jobID,
		}
		return ref, err
	}

	// Pick the pool this run is accounted against.
	pool := s.sem
	if len(targets) > 0 {
		pool = s.prioritySem
	}

	// Create a JobProgress object for tracking progress.
	ctx, cancel := context.WithCancelCause(ctx)
	progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))

	if err := pool.Acquire(ctx, 1); err != nil {
		// Context cancelled.
		fatal := Fatal{err}
		progress.ReportError(fatal)
		return progress.Ref(), fatal
	}

	s.wg.Add(1)
	go func() {
		// Deferred calls run in reverse order: cancel, Recover, wg.Done,
		// Release and finally Finish, so Finish is called after the
		// semaphore has been released.
		defer progress.Finish()
		defer pool.Release(1)
		defer s.wg.Done()
		ctx := context.WithValues(ctx,
			"source_manager_worker_id", common.RandomID(5),
		)
		defer common.Recover(ctx)
		defer cancel(nil)

		err := s.run(ctx, source, progress, targets...)
		if err == nil {
			return
		}
		// Keep only the first error; later ones are dropped.
		select {
		case s.firstErr <- err:
		default:
		}
	}()
	return progress.Ref(), nil
}
// Chunks exposes, as a receive-only channel, every chunk produced by the
// sources this manager runs. The channel is closed by Wait.
func (s *SourceManager) Chunks() <-chan *Chunk {
	var out <-chan *Chunk = s.outputChunks
	return out
}
// Wait blocks until every running source has completed, then closes the
// channel returned by Chunks() and closes any hooks that implement io.Closer.
// It returns the first fatal error reported by a run, or nil if there was
// none. Subsequent calls return the same result immediately.
//
// The manager should not be reused after calling this method. This
// implementation is not thread safe and should only be called by one thread.
func (s *SourceManager) Wait() error {
	if s.done {
		// Already waited: hand back the cached result.
		return s.waitErr
	}
	s.done = true

	s.wg.Wait()

	// Pick up the first error returned by run, if one was recorded.
	select {
	case err := <-s.firstErr:
		s.waitErr = err
	default:
	}

	close(s.outputChunks)
	close(s.firstErr)

	for i := range s.hooks {
		closer, isCloser := s.hooks[i].(io.Closer)
		if !isCloser {
			continue
		}
		// Best-effort cleanup; close errors are intentionally ignored.
		_ = closer.Close()
	}
	return s.waitErr
}
// ScanChunk pushes chunk directly onto the output channel so it is scanned
// alongside chunks produced by sources. The send blocks until the chunk is
// received or buffered, and must not happen after Wait has closed the channel.
// This method should rarely be used. TODO(THOG-1577): Remove when dependencies
// no longer rely on this functionality.
func (s *SourceManager) ScanChunk(chunk *Chunk) {
	out := s.outputChunks
	out <- chunk
}
// AvailableCapacity returns the number of additional concurrent jobs the
// manager can accommodate at this time, based on the regular source pool.
//
// The limit can be lowered below the number of sources already running (see
// SetMaxConcurrentSources), in which case limit-count would be negative; the
// result is clamped so that callers never see a negative capacity. The limit
// and count are read separately, so the value is a best-effort snapshot.
func (s *SourceManager) AvailableCapacity() int {
	capacity := s.sem.GetLimit() - s.sem.GetCount()
	if capacity < 0 {
		return 0
	}
	return capacity
}
// MaxConcurrentSources reports the currently configured upper bound on the
// number of sources the manager will run at the same time.
func (s *SourceManager) MaxConcurrentSources() int {
	limit := s.sem.GetLimit()
	return limit
}
// ConcurrentSources reports how many slots of the regular source pool are
// currently held, i.e. the number of sources running concurrently right now.
func (s *SourceManager) ConcurrentSources() int {
	running := s.sem.GetCount()
	return running
}
// SetMaxConcurrentSources changes the maximum number of sources allowed to
// run concurrently. When the new maximum is below the number of sources that
// are already running, no further sources are scheduled until enough of the
// existing ones complete.
func (s *SourceManager) SetMaxConcurrentSources(maxRunCount int) {
	pool := s.sem
	pool.SetLimit(maxRunCount)
}
// preflightChecks verifies that a new run may be started: it fails if the
// manager has already been waited on, and otherwise returns the context's
// error (nil while the context is still live).
func (s *SourceManager) preflightChecks(ctx context.Context) error {
	if !s.done {
		return ctx.Err()
	}
	// The manager has been Waited and cannot accept more work.
	return fmt.Errorf("manager is done")
}
// run is a helper method to synchronously run the source. It does not check
// for acquired resources. An error is returned if there was a fatal error
// during the run. This information is also recorded in the JobProgress.
//
// The start and end times of the run are recorded on report, and if ctx
// carries a cancellation cause when run returns, that cause is reported as a
// Fatal error. Unit-based scanning is used only when no targets were given,
// a units toggle is configured and currently returns true, and the source
// implements SourceUnitEnumChunker; otherwise the source's Chunks method is
// used.
func (s *SourceManager) run(ctx context.Context, source Source, report *JobProgress, targets ...ChunkingTarget) error {
	report.Start(time.Now())
	defer func() { report.End(time.Now()) }()
	defer func() {
		if err := context.Cause(ctx); err != nil {
			report.ReportError(Fatal{err})
		}
	}()
	report.TrackProgress(source.GetProgress())

	// Attach identifying values to the context unless the caller already
	// provided them. Value returns nil (not "") for a key that was never
	// set, and a nil interface never compares equal to "", so both cases
	// have to be treated as "missing".
	missing := func(key string) bool {
		v := ctx.Value(key)
		return v == nil || v == ""
	}
	if missing("job_id") {
		ctx = context.WithValue(ctx, "job_id", report.JobID)
	}
	if missing("source_id") {
		ctx = context.WithValue(ctx, "source_id", report.SourceID)
	}
	if missing("source_name") {
		ctx = context.WithValue(ctx, "source_name", report.SourceName)
	}
	if missing("source_type") {
		ctx = context.WithValue(ctx, "source_type", source.Type().String())
	}

	// Check for the preferred method of tracking source units. Targeted
	// runs always take the non-unit path.
	canUseSourceUnits := len(targets) == 0 && s.useSourceUnitsFunc != nil
	if enumChunker, ok := source.(SourceUnitEnumChunker); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
		ctx.Logger().Info("running source",
			"with_units", true)
		return s.runWithUnits(ctx, enumChunker, report)
	}
	ctx.Logger().Info("running source",
		"with_units", false,
		"target_count", len(targets),
		"source_manager_units_configurable", s.useSourceUnitsFunc != nil)
	return s.runWithoutUnits(ctx, source, report, targets...)
}
// runWithoutUnits runs a Source through its Chunks method, giving only
// coarse-grained job reporting. Every chunk the source emits is stamped with
// the source's job ID, counted on the report, and forwarded downstream. A
// failure from Chunks is reported and returned as a Fatal error.
func (s *SourceManager) runWithoutUnits(ctx context.Context, source Source, report *JobProgress, targets ...ChunkingTarget) error {
	// Intermediate channel so the chunks can be inspected before they are
	// passed on.
	produced := make(chan *Chunk, 1)

	var forwarder sync.WaitGroup
	forwarder.Add(1)
	go func() {
		defer forwarder.Done()
		for c := range produced {
			c.JobID = source.JobID()
			report.ReportChunk(nil, c)
			s.outputChunks <- c
		}
	}()

	// On the way out, close the intermediate channel first (which lets the
	// forwarding goroutine drain and stop) and only then wait for it, so
	// this function does not return before all chunks went downstream.
	defer func() {
		close(produced)
		forwarder.Wait()
	}()

	err := source.Chunks(ctx, produced, targets...)
	if err == nil {
		return nil
	}
	fatal := Fatal{err}
	report.ReportError(fatal)
	return fatal
}
// runWithUnits is a helper method to run a Source that is also a
// SourceUnitEnumChunker. This allows better introspection of what is getting
// scanned and any errors encountered.
//
// One goroutine enumerates units; each unit is chunked in a pool bounded by
// the manager's concurrentUnits setting, and a per-unit goroutine forwards
// the resulting chunks downstream. The function returns only after the
// enumerator, all chunkers and all forwarders have finished, so every
// progress event (including EndEnumerating) is recorded before the caller
// marks the job as ended. The first fatal error encountered, if any, is
// returned.
func (s *SourceManager) runWithUnits(ctx context.Context, source SourceUnitEnumChunker, report *JobProgress) error {
	unitReporter := &mgrUnitReporter{
		unitCh: make(chan SourceUnit, 1),
		report: report,
	}
	// Create a function that will save the first error encountered (if
	// any) and discard the rest.
	fatalErr := make(chan error, 1)
	catchFirstFatal := func(err error) {
		select {
		case fatalErr <- err:
		default:
		}
	}

	// wg tracks the enumerating goroutine and every chunk-forwarding
	// goroutine.
	var wg sync.WaitGroup

	// Produce units.
	wg.Add(1)
	go func() {
		// Registered first so it runs last: the enumerator only counts as
		// finished once EndEnumerating has been recorded.
		defer wg.Done()
		// TODO: Catch panics and add to report.
		report.StartEnumerating(time.Now())
		defer func() { report.EndEnumerating(time.Now()) }()
		defer close(unitReporter.unitCh)
		ctx.Logger().V(2).Info("enumerating source")
		if err := source.Enumerate(ctx, unitReporter); err != nil {
			report.ReportError(Fatal{err})
			catchFirstFatal(Fatal{err})
		}
	}()

	// TODO: Maybe switch to using a semaphore.Weighted.
	var unitPool errgroup.Group
	if s.concurrentUnits != 0 {
		// Negative values indicate no limit.
		unitPool.SetLimit(s.concurrentUnits)
	}

	// The dynamic type of source does not change, so resolve it once.
	src, isSource := source.(Source)

	for unit := range unitReporter.unitCh {
		unit := unit
		chunkReporter := &mgrChunkReporter{
			unit:    unit,
			chunkCh: make(chan *Chunk, 1),
			report:  report,
		}
		// Consume units and produce chunks.
		unitPool.Go(func() error {
			report.StartUnitChunking(unit, time.Now())
			// TODO: Catch panics and add to report.
			defer close(chunkReporter.chunkCh)
			ctx := context.WithValue(ctx, "unit", unit.SourceUnitID())
			ctx.Logger().V(3).Info("chunking unit")
			if err := source.ChunkUnit(ctx, unit, chunkReporter); err != nil {
				report.ReportError(Fatal{ChunkError{Unit: unit, Err: err}})
				catchFirstFatal(Fatal{err})
			}
			return nil
		})
		// Consume chunks and export chunks.
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer func() { report.EndUnitChunking(unit, time.Now()) }()
			for chunk := range chunkReporter.chunkCh {
				if isSource {
					chunk.JobID = src.JobID()
				}
				s.outputChunks <- chunk
			}
		}()
	}

	// Wait for the enumerator and all forwarders. A forwarder only stops
	// once its chunker has closed the chunk channel, so by this point every
	// chunker is at most returning.
	wg.Wait()
	// Make sure the pool's goroutines have fully exited as well. The pool
	// functions always return nil, so there is no error to handle here.
	_ = unitPool.Wait()

	select {
	case err := <-fatalErr:
		return err
	default:
		return nil
	}
}
// headlessAPI is the local, in-process implementation of the apiClient
// interface. It hands out identifiers from two monotonically increasing
// counters.
type headlessAPI struct {
	// Last source ID and last job ID handed out, respectively.
	sourceIDCounter, jobIDCounter int64
}
// GetIDs atomically increments both counters and returns the new values as
// the source ID and job ID. All arguments are ignored and the returned error
// is always nil.
func (api *headlessAPI) GetIDs(context.Context, string, sourcespb.SourceType) (SourceID, JobID, error) {
	nextSource := atomic.AddInt64(&api.sourceIDCounter, 1)
	nextJob := atomic.AddInt64(&api.jobIDCounter, 1)
	return SourceID(nextSource), JobID(nextJob), nil
}
// mgrUnitReporter is the manager's implementation of the UnitReporter
// interface.
type mgrUnitReporter struct {
	// unitCh receives every successfully enumerated unit.
	unitCh chan SourceUnit
	// report is the job progress that units and errors are recorded on.
	report *JobProgress
}
// UnitOk records unit on the job progress and then writes it to the unit
// channel via common.CancellableWrite, returning that call's error.
func (s *mgrUnitReporter) UnitOk(ctx context.Context, unit SourceUnit) error {
	s.report.ReportUnit(unit)
	err := common.CancellableWrite(ctx, s.unitCh, unit)
	return err
}
// UnitErr records an enumeration error on the job progress. It always
// returns nil; ctx is not used.
func (s *mgrUnitReporter) UnitErr(ctx context.Context, err error) error {
	progress := s.report
	progress.ReportError(err)
	return nil
}
// mgrChunkReporter is the manager's implementation of the ChunkReporter
// interface for a single unit.
type mgrChunkReporter struct {
	// unit is the source unit whose chunks and errors are being reported.
	unit SourceUnit
	// chunkCh receives every chunk produced for the unit.
	chunkCh chan *Chunk
	// report is the job progress that chunks and errors are recorded on.
	report *JobProgress
}
// ChunkOk records the chunk against the reporter's unit on the job progress
// and then writes a pointer to it to the chunk channel via
// common.CancellableWrite, returning that call's error.
func (s *mgrChunkReporter) ChunkOk(ctx context.Context, chunk Chunk) error {
	// chunk is this call's own copy, so its address can be handed on.
	ptr := &chunk
	s.report.ReportChunk(s.unit, ptr)
	return common.CancellableWrite(ctx, s.chunkCh, ptr)
}
// ChunkErr records err on the job progress as a ChunkError tied to the
// reporter's unit. It always returns nil; ctx is not used.
func (s *mgrChunkReporter) ChunkErr(ctx context.Context, err error) error {
	chunkErr := ChunkError{Unit: s.unit, Err: err}
	s.report.ReportError(chunkErr)
	return nil
}