add priority semaphore (#2336)

This commit is contained in:
ahrav 2024-01-24 16:43:56 -08:00 committed by GitHub
parent 792266afa9
commit f209b04d5d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -19,8 +19,9 @@ type SourceManager struct {
api apiClient
hooks []JobProgressHook
// Pool limiting the amount of concurrent sources running.
sem semaphore.Semaphore
wg sync.WaitGroup
sem semaphore.Semaphore
prioritySem semaphore.Semaphore
wg sync.WaitGroup
// Max number of units to scan concurrently per source.
concurrentUnits int
// Run the sources using source unit enumeration / chunking if available.
@ -60,6 +61,13 @@ func WithConcurrentSources(concurrency int) func(*SourceManager) {
}
}
// WithConcurrentTargets limits the concurrent number of targets a manager can run.
func WithConcurrentTargets(concurrency int) func(*SourceManager) {
return func(mgr *SourceManager) {
mgr.prioritySem.SetLimit(concurrency)
}
}
// WithBufferedOutput sets the size of the buffer used for the Chunks() channel.
func WithBufferedOutput(size int) func(*SourceManager) {
return func(mgr *SourceManager) { mgr.outputChunks = make(chan *Chunk, size) }
@ -89,6 +97,7 @@ func NewManager(opts ...func(*SourceManager)) *SourceManager {
// Default to the headless API. Can be overwritten by the WithAPI option.
api: &headlessAPI{},
sem: semaphore.New(runtime.NumCPU()),
prioritySem: semaphore.New(runtime.NumCPU()),
outputChunks: make(chan *Chunk),
firstErr: make(chan error, 1),
}
@ -116,9 +125,13 @@ func (s *SourceManager) Run(ctx context.Context, sourceName string, source Sourc
}, err
}
// Create a JobProgress object for tracking progress.
sem := s.sem
if len(targets) > 0 {
sem = s.prioritySem
}
ctx, cancel := context.WithCancelCause(ctx)
progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))
if err := s.sem.Acquire(ctx, 1); err != nil {
if err := sem.Acquire(ctx, 1); err != nil {
// Context cancelled.
progress.ReportError(Fatal{err})
return progress.Ref(), Fatal{err}
@ -127,7 +140,7 @@ func (s *SourceManager) Run(ctx context.Context, sourceName string, source Sourc
go func() {
// Call Finish after the semaphore has been released.
defer progress.Finish()
defer s.sem.Release(1)
defer sem.Release(1)
defer s.wg.Done()
ctx := context.WithValues(ctx,
"source_manager_worker_id", common.RandomID(5),