Make worker multipliers configurable (#3267)

This commit is contained in:
Dustin Decker 2024-09-04 11:36:26 -07:00 committed by GitHub
parent 7eb5b5b12c
commit db0108f731
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -135,6 +135,15 @@ type Config struct {
// that have been detected by multiple detectors.
// By default, it is set to true.
VerificationOverlap bool
// DetectorWorkerMultiplier is multiplied by the engine's concurrency to determine the number of
// detector workers to spawn. Values below 1 fall back to the default of 8.
DetectorWorkerMultiplier int
// NotificationWorkerMultiplier is multiplied by the engine's concurrency to determine the number of
// notification workers to spawn. Values below 1 fall back to the default of 1.
NotificationWorkerMultiplier int
// VerificationOverlapWorkerMultiplier is multiplied by the engine's concurrency to determine the number of
// verification overlap workers to spawn. Values below 1 fall back to the default of 1.
VerificationOverlapWorkerMultiplier int
}
// Engine represents the core scanning engine responsible for detecting secrets in input data.
@ -195,24 +204,34 @@ type Engine struct {
// Note: bad hack only used for testing.
verificationOverlapTracker *verificationOverlapTracker
// detectorWorkerMultiplier is used to calculate the number of detector workers.
detectorWorkerMultiplier int
// notificationWorkerMultiplier is used to calculate the number of notification workers.
notificationWorkerMultiplier int
// verificationOverlapWorkerMultiplier is used to calculate the number of verification overlap workers.
verificationOverlapWorkerMultiplier int
}
// NewEngine creates a new Engine instance with the provided configuration.
func NewEngine(ctx context.Context, cfg *Config) (*Engine, error) {
engine := &Engine{
concurrency: cfg.Concurrency,
decoders: cfg.Decoders,
detectors: cfg.Detectors,
dispatcher: cfg.Dispatcher,
verify: cfg.Verify,
filterUnverified: cfg.FilterUnverified,
filterEntropy: cfg.FilterEntropy,
printAvgDetectorTime: cfg.PrintAvgDetectorTime,
retainFalsePositives: cfg.LogFilteredUnverified,
verificationOverlap: cfg.VerificationOverlap,
sourceManager: cfg.SourceManager,
scanEntireChunk: cfg.ShouldScanEntireChunk,
detectorVerificationOverrides: cfg.DetectorVerificationOverrides,
concurrency: cfg.Concurrency,
decoders: cfg.Decoders,
detectors: cfg.Detectors,
dispatcher: cfg.Dispatcher,
verify: cfg.Verify,
filterUnverified: cfg.FilterUnverified,
filterEntropy: cfg.FilterEntropy,
printAvgDetectorTime: cfg.PrintAvgDetectorTime,
retainFalsePositives: cfg.LogFilteredUnverified,
verificationOverlap: cfg.VerificationOverlap,
sourceManager: cfg.SourceManager,
scanEntireChunk: cfg.ShouldScanEntireChunk,
detectorVerificationOverrides: cfg.DetectorVerificationOverrides,
detectorWorkerMultiplier: cfg.DetectorWorkerMultiplier,
notificationWorkerMultiplier: cfg.NotificationWorkerMultiplier,
verificationOverlapWorkerMultiplier: cfg.VerificationOverlapWorkerMultiplier,
}
if engine.sourceManager == nil {
return nil, fmt.Errorf("source manager is required")
@ -303,7 +322,19 @@ func (e *Engine) setDefaults(ctx context.Context) {
ctx.Logger().Info("No concurrency specified, defaulting to max", "cpu", numCPU)
e.concurrency = numCPU
}
ctx.Logger().V(3).Info("engine started", "workers", e.concurrency)
if e.detectorWorkerMultiplier < 1 {
// Detector workers are bound by network I/O, so their default multiplier is higher than the other workers'.
e.detectorWorkerMultiplier = 8
}
if e.notificationWorkerMultiplier < 1 {
e.notificationWorkerMultiplier = 1
}
if e.verificationOverlapWorkerMultiplier < 1 {
e.verificationOverlapWorkerMultiplier = 1
}
// Default decoders handle common encoding formats.
if len(e.decoders) == 0 {
@ -625,9 +656,10 @@ func (e *Engine) startScannerWorkers(ctx context.Context) {
}
func (e *Engine) startDetectorWorkers(ctx context.Context) {
const detectorWorkerMultiplier = 4
ctx.Logger().V(2).Info("starting detector workers", "count", e.concurrency*detectorWorkerMultiplier)
for worker := uint64(0); worker < uint64(e.concurrency*detectorWorkerMultiplier); worker++ {
numWorkers := e.concurrency * e.detectorWorkerMultiplier
ctx.Logger().V(2).Info("starting detector workers", "count", numWorkers)
for worker := 0; worker < numWorkers; worker++ {
e.wgDetectorWorkers.Add(1)
go func() {
ctx := context.WithValue(ctx, "detector_worker_id", common.RandomID(5))
@ -639,8 +671,10 @@ func (e *Engine) startDetectorWorkers(ctx context.Context) {
}
func (e *Engine) startVerificationOverlapWorkers(ctx context.Context) {
ctx.Logger().V(2).Info("starting verificationOverlap workers", "count", e.concurrency)
for worker := uint64(0); worker < uint64(e.concurrency); worker++ {
numWorkers := e.concurrency * e.verificationOverlapWorkerMultiplier
ctx.Logger().V(2).Info("starting verificationOverlap workers", "count", numWorkers)
for worker := 0; worker < numWorkers; worker++ {
e.verificationOverlapWg.Add(1)
go func() {
ctx := context.WithValue(ctx, "verification_overlap_worker_id", common.RandomID(5))
@ -652,13 +686,10 @@ func (e *Engine) startVerificationOverlapWorkers(ctx context.Context) {
}
func (e *Engine) startNotifierWorkers(ctx context.Context) {
const notifierWorkerRatio = 2
maxNotifierWorkers := 1
if numWorkers := e.concurrency / notifierWorkerRatio; numWorkers > 0 {
maxNotifierWorkers = numWorkers
}
ctx.Logger().V(2).Info("starting notifier workers", "count", maxNotifierWorkers)
for worker := 0; worker < maxNotifierWorkers; worker++ {
numWorkers := e.notificationWorkerMultiplier * e.concurrency
ctx.Logger().V(2).Info("starting notifier workers", "count", numWorkers)
for worker := 0; worker < numWorkers; worker++ {
e.WgNotifier.Add(1)
go func() {
ctx := context.WithValue(ctx, "notifier_worker_id", common.RandomID(5))