mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
add priority semaphore (#2336)
This commit is contained in:
parent
792266afa9
commit
f209b04d5d
1 changed files with 17 additions and 4 deletions
|
@ -19,8 +19,9 @@ type SourceManager struct {
|
|||
api apiClient
|
||||
hooks []JobProgressHook
|
||||
// Pool limiting the amount of concurrent sources running.
|
||||
sem semaphore.Semaphore
|
||||
wg sync.WaitGroup
|
||||
sem semaphore.Semaphore
|
||||
prioritySem semaphore.Semaphore
|
||||
wg sync.WaitGroup
|
||||
// Max number of units to scan concurrently per source.
|
||||
concurrentUnits int
|
||||
// Run the sources using source unit enumeration / chunking if available.
|
||||
|
@ -60,6 +61,13 @@ func WithConcurrentSources(concurrency int) func(*SourceManager) {
|
|||
}
|
||||
}
|
||||
|
||||
// WithConcurrentTargets limits the concurrent number of targets a manager can run.
|
||||
func WithConcurrentTargets(concurrency int) func(*SourceManager) {
|
||||
return func(mgr *SourceManager) {
|
||||
mgr.prioritySem.SetLimit(concurrency)
|
||||
}
|
||||
}
|
||||
|
||||
// WithBufferedOutput sets the size of the buffer used for the Chunks() channel.
|
||||
func WithBufferedOutput(size int) func(*SourceManager) {
|
||||
return func(mgr *SourceManager) { mgr.outputChunks = make(chan *Chunk, size) }
|
||||
|
@ -89,6 +97,7 @@ func NewManager(opts ...func(*SourceManager)) *SourceManager {
|
|||
// Default to the headless API. Can be overwritten by the WithAPI option.
|
||||
api: &headlessAPI{},
|
||||
sem: semaphore.New(runtime.NumCPU()),
|
||||
prioritySem: semaphore.New(runtime.NumCPU()),
|
||||
outputChunks: make(chan *Chunk),
|
||||
firstErr: make(chan error, 1),
|
||||
}
|
||||
|
@ -116,9 +125,13 @@ func (s *SourceManager) Run(ctx context.Context, sourceName string, source Sourc
|
|||
}, err
|
||||
}
|
||||
// Create a JobProgress object for tracking progress.
|
||||
sem := s.sem
|
||||
if len(targets) > 0 {
|
||||
sem = s.prioritySem
|
||||
}
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
// Context cancelled.
|
||||
progress.ReportError(Fatal{err})
|
||||
return progress.Ref(), Fatal{err}
|
||||
|
@ -127,7 +140,7 @@ func (s *SourceManager) Run(ctx context.Context, sourceName string, source Sourc
|
|||
go func() {
|
||||
// Call Finish after the semaphore has been released.
|
||||
defer progress.Finish()
|
||||
defer s.sem.Release(1)
|
||||
defer sem.Release(1)
|
||||
defer s.wg.Done()
|
||||
ctx := context.WithValues(ctx,
|
||||
"source_manager_worker_id", common.RandomID(5),
|
||||
|
|
Loading…
Reference in a new issue