Add ability to dynamically scale concurrently running sources (#1790)

* Add ability to dynamically scale concurrently running sources

Refactor SourceManager to use a counting semaphore to allow for
dynamically changing limits. This complicated `Wait() error`, which needs
to return the first error encountered. We previously got that for free
using `errgroup.Group`; however, now we need to handle that ourselves.
`Wait()` needs to return an error for use in the engine to set the
correct exit code.

* Group third-party imports together
This commit is contained in:
Miccah 2023-09-20 16:49:56 -07:00 committed by GitHub
parent a8c89c59b9
commit efa404942a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 24 deletions

1
go.mod
View file

@ -54,6 +54,7 @@ require (
github.com/launchdarkly/go-server-sdk/v6 v6.1.0
github.com/lib/pq v1.10.9
github.com/lrstanley/bubblezone v0.0.0-20221222153816-e95291e2243e
github.com/marusama/semaphore/v2 v2.5.0
github.com/mattn/go-isatty v0.0.18
github.com/mattn/go-sqlite3 v1.14.17
github.com/mholt/archiver/v4 v4.0.0-alpha.8

2
go.sum
View file

@ -491,6 +491,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/marusama/semaphore/v2 v2.5.0 h1:o/1QJD9DBYOWRnDhPwDVAXQn6mQYD0gZaS1Tpx6DJGM=
github.com/marusama/semaphore/v2 v2.5.0/go.mod h1:z9nMiNUekt/LTpTUQdpp+4sJeYqUGpwMHfW0Z8V8fnQ=
github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=

View file

@ -2,10 +2,12 @@ package sources
import (
"fmt"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/marusama/semaphore/v2"
"golang.org/x/sync/errgroup"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
@ -17,9 +19,8 @@ type SourceManager struct {
api apiClient
hooks []JobProgressHook
// Pool limiting the amount of concurrent sources running.
pool errgroup.Group
poolLimit int
currentRunningCount int32
sem semaphore.Semaphore
wg sync.WaitGroup
// Max number of units to scan concurrently per source.
concurrentUnits int
// Run the sources using source unit enumeration / chunking if available.
@ -27,7 +28,9 @@ type SourceManager struct {
// Downstream chunks channel to be scanned.
outputChunks chan *Chunk
// Set when Wait() returns.
done bool
firstErr chan error
waitErr error
done bool
}
// apiClient is an interface for optionally communicating with an external API.
@ -52,8 +55,7 @@ func WithReportHook(hook JobProgressHook) func(*SourceManager) {
// WithConcurrentSources limits the concurrent number of sources a manager can run.
func WithConcurrentSources(concurrency int) func(*SourceManager) {
return func(mgr *SourceManager) {
mgr.pool.SetLimit(concurrency)
mgr.poolLimit = concurrency
mgr.sem.SetLimit(concurrency)
}
}
@ -79,7 +81,9 @@ func NewManager(opts ...func(*SourceManager)) *SourceManager {
mgr := SourceManager{
// Default to the headless API. Can be overwritten by the WithAPI option.
api: &headlessAPI{},
sem: semaphore.New(runtime.NumCPU()),
outputChunks: make(chan *Chunk),
firstErr: make(chan error, 1),
}
for _, opt := range opts {
opt(&mgr)
@ -98,9 +102,7 @@ func (s *SourceManager) Run(ctx context.Context, sourceName string, source Sourc
return s.asyncRun(ctx, sourceName, source)
}
// asyncRun is a helper method to asynchronously run the Source. It calls out
// to the API to get a job ID for this run, creates a JobProgress object, then
// waits for an available goroutine to asynchronously run it.
// asyncRun is a helper method to asynchronously run the Source.
func (s *SourceManager) asyncRun(ctx context.Context, sourceName string, source Source) (JobProgressRef, error) {
sourceID, jobID := source.SourceID(), source.JobID()
// Do preflight checks before waiting on the pool.
@ -114,17 +116,28 @@ func (s *SourceManager) asyncRun(ctx context.Context, sourceName string, source
// Create a JobProgress object for tracking progress.
ctx, cancel := context.WithCancelCause(ctx)
progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))
s.pool.Go(func() error {
atomic.AddInt32(&s.currentRunningCount, 1)
defer atomic.AddInt32(&s.currentRunningCount, -1)
if err := s.sem.Acquire(ctx, 1); err != nil {
// Context cancelled.
progress.ReportError(Fatal{err})
return progress.Ref(), Fatal{err}
}
s.wg.Add(1)
go func() {
defer s.sem.Release(1)
defer s.wg.Done()
ctx := context.WithValues(ctx,
"job_id", jobID,
"source_manager_worker_id", common.RandomID(5),
)
defer common.Recover(ctx)
defer cancel(nil)
return s.run(ctx, source, progress)
})
if err := s.run(ctx, source, progress); err != nil {
select {
case s.firstErr <- err:
default:
}
}
}()
return progress.Ref(), nil
}
@ -141,13 +154,19 @@ func (s *SourceManager) Chunks() <-chan *Chunk {
func (s *SourceManager) Wait() error {
// Check if the manager has been Waited.
if s.done {
return s.pool.Wait()
return s.waitErr
}
defer close(s.outputChunks)
defer func() { s.done = true }()
s.done = true
// Return the first error returned by run.
return s.pool.Wait()
s.wg.Wait()
select {
case s.waitErr = <-s.firstErr:
default:
}
close(s.outputChunks)
close(s.firstErr)
return s.waitErr
}
// ScanChunk injects a chunk into the output stream of chunks to be scanned.
@ -158,13 +177,29 @@ func (s *SourceManager) ScanChunk(chunk *Chunk) {
}
// AvailableCapacity returns the number of concurrent jobs the manager can
// accommodate at this time. If there is no limit, -1 is returned.
// accommodate at this time.
func (s *SourceManager) AvailableCapacity() int {
if s.poolLimit == 0 {
return -1
}
runCount := atomic.LoadInt32(&s.currentRunningCount)
return s.poolLimit - int(runCount)
return s.sem.GetLimit() - s.sem.GetCount()
}
// MaxConcurrentSources returns the maximum configured limit of concurrent
// sources the manager will run.
//
// The value is read directly from the manager's semaphore limit, so it
// reflects the most recent limit applied to that semaphore (for example via
// WithConcurrentSources or SetMaxConcurrentSources).
func (s *SourceManager) MaxConcurrentSources() int {
return s.sem.GetLimit()
}
// ConcurrentSources returns the current number of concurrently running
// sources.
//
// The value is the manager's semaphore count. asyncRun acquires one unit of
// the semaphore before starting a source's goroutine and releases it when
// that goroutine exits, so the count tracks sources that are in flight.
// NOTE(review): presumably a point-in-time snapshot that may be stale by the
// time the caller uses it; confirm against the semaphore library's docs.
func (s *SourceManager) ConcurrentSources() int {
return s.sem.GetCount()
}
// SetMaxConcurrentSources sets the maximum number of concurrently running
// sources. If the count is lower than the already existing number of
// concurrently running sources, no sources will be scheduled to run until the
// existing sources complete.
//
// maxRunCount is forwarded unchanged to the semaphore's SetLimit; no
// validation is performed here.
// NOTE(review): behavior for zero or negative values depends on the semaphore
// implementation; verify before passing such values.
func (s *SourceManager) SetMaxConcurrentSources(maxRunCount int) {
s.sem.SetLimit(maxRunCount)
}
// preflightChecks is a helper method to check the Manager or the context isn't