mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
Add ability to dynamically scale concurrently running sources (#1790)
* Add ability to dynamically scale concurrently running sources. Refactor SourceManager to use a counting semaphore to allow for dynamically changing limits. This complicated `Wait() error`, which needs to return the first error encountered. We previously got that for free using `errgroup.Group`; however, now we need to handle that ourselves. `Wait()` needs to return an error for use in the engine to set the correct exit code. * Group third party imports together
This commit is contained in:
parent
a8c89c59b9
commit
efa404942a
3 changed files with 62 additions and 24 deletions
1
go.mod
1
go.mod
|
@ -54,6 +54,7 @@ require (
|
|||
github.com/launchdarkly/go-server-sdk/v6 v6.1.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/lrstanley/bubblezone v0.0.0-20221222153816-e95291e2243e
|
||||
github.com/marusama/semaphore/v2 v2.5.0
|
||||
github.com/mattn/go-isatty v0.0.18
|
||||
github.com/mattn/go-sqlite3 v1.14.17
|
||||
github.com/mholt/archiver/v4 v4.0.0-alpha.8
|
||||
|
|
2
go.sum
2
go.sum
|
@ -491,6 +491,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69
|
|||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/marusama/semaphore/v2 v2.5.0 h1:o/1QJD9DBYOWRnDhPwDVAXQn6mQYD0gZaS1Tpx6DJGM=
|
||||
github.com/marusama/semaphore/v2 v2.5.0/go.mod h1:z9nMiNUekt/LTpTUQdpp+4sJeYqUGpwMHfW0Z8V8fnQ=
|
||||
github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
|
||||
github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
|
|
|
@ -2,10 +2,12 @@ package sources
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/marusama/semaphore/v2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
|
@ -17,9 +19,8 @@ type SourceManager struct {
|
|||
api apiClient
|
||||
hooks []JobProgressHook
|
||||
// Pool limiting the amount of concurrent sources running.
|
||||
pool errgroup.Group
|
||||
poolLimit int
|
||||
currentRunningCount int32
|
||||
sem semaphore.Semaphore
|
||||
wg sync.WaitGroup
|
||||
// Max number of units to scan concurrently per source.
|
||||
concurrentUnits int
|
||||
// Run the sources using source unit enumeration / chunking if available.
|
||||
|
@ -27,7 +28,9 @@ type SourceManager struct {
|
|||
// Downstream chunks channel to be scanned.
|
||||
outputChunks chan *Chunk
|
||||
// Set when Wait() returns.
|
||||
done bool
|
||||
firstErr chan error
|
||||
waitErr error
|
||||
done bool
|
||||
}
|
||||
|
||||
// apiClient is an interface for optionally communicating with an external API.
|
||||
|
@ -52,8 +55,7 @@ func WithReportHook(hook JobProgressHook) func(*SourceManager) {
|
|||
// WithConcurrentSources limits the concurrent number of sources a manager can run.
|
||||
func WithConcurrentSources(concurrency int) func(*SourceManager) {
|
||||
return func(mgr *SourceManager) {
|
||||
mgr.pool.SetLimit(concurrency)
|
||||
mgr.poolLimit = concurrency
|
||||
mgr.sem.SetLimit(concurrency)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +81,9 @@ func NewManager(opts ...func(*SourceManager)) *SourceManager {
|
|||
mgr := SourceManager{
|
||||
// Default to the headless API. Can be overwritten by the WithAPI option.
|
||||
api: &headlessAPI{},
|
||||
sem: semaphore.New(runtime.NumCPU()),
|
||||
outputChunks: make(chan *Chunk),
|
||||
firstErr: make(chan error, 1),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&mgr)
|
||||
|
@ -98,9 +102,7 @@ func (s *SourceManager) Run(ctx context.Context, sourceName string, source Sourc
|
|||
return s.asyncRun(ctx, sourceName, source)
|
||||
}
|
||||
|
||||
// asyncRun is a helper method to asynchronously run the Source. It calls out
|
||||
// to the API to get a job ID for this run, creates a JobProgress object, then
|
||||
// waits for an available goroutine to asynchronously run it.
|
||||
// asyncRun is a helper method to asynchronously run the Source.
|
||||
func (s *SourceManager) asyncRun(ctx context.Context, sourceName string, source Source) (JobProgressRef, error) {
|
||||
sourceID, jobID := source.SourceID(), source.JobID()
|
||||
// Do preflight checks before waiting on the pool.
|
||||
|
@ -114,17 +116,28 @@ func (s *SourceManager) asyncRun(ctx context.Context, sourceName string, source
|
|||
// Create a JobProgress object for tracking progress.
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))
|
||||
s.pool.Go(func() error {
|
||||
atomic.AddInt32(&s.currentRunningCount, 1)
|
||||
defer atomic.AddInt32(&s.currentRunningCount, -1)
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
// Context cancelled.
|
||||
progress.ReportError(Fatal{err})
|
||||
return progress.Ref(), Fatal{err}
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.sem.Release(1)
|
||||
defer s.wg.Done()
|
||||
ctx := context.WithValues(ctx,
|
||||
"job_id", jobID,
|
||||
"source_manager_worker_id", common.RandomID(5),
|
||||
)
|
||||
defer common.Recover(ctx)
|
||||
defer cancel(nil)
|
||||
return s.run(ctx, source, progress)
|
||||
})
|
||||
if err := s.run(ctx, source, progress); err != nil {
|
||||
select {
|
||||
case s.firstErr <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
return progress.Ref(), nil
|
||||
}
|
||||
|
||||
|
@ -141,13 +154,19 @@ func (s *SourceManager) Chunks() <-chan *Chunk {
|
|||
func (s *SourceManager) Wait() error {
|
||||
// Check if the manager has been Waited.
|
||||
if s.done {
|
||||
return s.pool.Wait()
|
||||
return s.waitErr
|
||||
}
|
||||
defer close(s.outputChunks)
|
||||
defer func() { s.done = true }()
|
||||
s.done = true
|
||||
|
||||
// Return the first error returned by run.
|
||||
return s.pool.Wait()
|
||||
s.wg.Wait()
|
||||
select {
|
||||
case s.waitErr = <-s.firstErr:
|
||||
default:
|
||||
}
|
||||
close(s.outputChunks)
|
||||
close(s.firstErr)
|
||||
return s.waitErr
|
||||
}
|
||||
|
||||
// ScanChunk injects a chunk into the output stream of chunks to be scanned.
|
||||
|
@ -158,13 +177,29 @@ func (s *SourceManager) ScanChunk(chunk *Chunk) {
|
|||
}
|
||||
|
||||
// AvailableCapacity returns the number of concurrent jobs the manager can
|
||||
// accommodate at this time. If there is no limit, -1 is returned.
|
||||
// accommodate at this time.
|
||||
func (s *SourceManager) AvailableCapacity() int {
|
||||
if s.poolLimit == 0 {
|
||||
return -1
|
||||
}
|
||||
runCount := atomic.LoadInt32(&s.currentRunningCount)
|
||||
return s.poolLimit - int(runCount)
|
||||
return s.sem.GetLimit() - s.sem.GetCount()
|
||||
}
|
||||
|
||||
// MaxConcurrentSources returns the maximum configured limit of concurrent
|
||||
// sources the manager will run.
|
||||
func (s *SourceManager) MaxConcurrentSources() int {
|
||||
return s.sem.GetLimit()
|
||||
}
|
||||
|
||||
// ConcurrentSources returns the current number of concurrently running
|
||||
// sources.
|
||||
func (s *SourceManager) ConcurrentSources() int {
|
||||
return s.sem.GetCount()
|
||||
}
|
||||
|
||||
// SetMaxConcurrentSources sets the maximum number of concurrently running
|
||||
// sources. If the count is lower than the already existing number of
|
||||
// concurrently running sources, no sources will be scheduled to run until the
|
||||
// existing sources complete.
|
||||
func (s *SourceManager) SetMaxConcurrentSources(maxRunCount int) {
|
||||
s.sem.SetLimit(maxRunCount)
|
||||
}
|
||||
|
||||
// preflightChecks is a helper method to check the Manager or the context isn't
|
||||
|
|
Loading…
Reference in a new issue