diff --git a/pkg/sources/job_progress.go b/pkg/sources/job_progress.go index 029e758fa..430804d13 100644 --- a/pkg/sources/job_progress.go +++ b/pkg/sources/job_progress.go @@ -188,7 +188,8 @@ func (jp *JobProgress) executeHooks(todo func(hook JobProgressHook)) { hooksExecTime.WithLabelValues().Observe(float64(elapsed)) }(time.Now()) for _, hook := range jp.hooks { - // TODO: Non-blocking? + // Execute hooks synchronously so they can provide + // back-pressure to the source. todo(hook) } } diff --git a/pkg/sources/job_progress_hook.go b/pkg/sources/job_progress_hook.go index b1eeec932..0f6c45810 100644 --- a/pkg/sources/job_progress_hook.go +++ b/pkg/sources/job_progress_hook.go @@ -3,44 +3,61 @@ package sources import ( "errors" "fmt" - "strings" + "runtime" "sync" "time" - lru "github.com/hashicorp/golang-lru/v2" "github.com/trufflesecurity/trufflehog/v3/pkg/context" ) // UnitHook implements JobProgressHook for tracking the progress of each // individual unit. type UnitHook struct { - metrics *lru.Cache[string, *UnitMetrics] - mu sync.Mutex + metrics map[string]*UnitMetrics + mu sync.Mutex + finishedMetrics chan UnitMetrics + logBackPressure func() NoopHook } type UnitHookOpt func(*UnitHook) -func WithUnitHookCache(cache *lru.Cache[string, *UnitMetrics]) UnitHookOpt { - return func(hook *UnitHook) { hook.metrics = cache } +// WithUnitHookFinishBufferSize sets the buffer size for handling finished +// metrics (default is 1024). If the buffer fills, then scanning will stop +// until there is room. +func WithUnitHookFinishBufferSize(buf int) UnitHookOpt { + return func(hook *UnitHook) { + hook.finishedMetrics = make(chan UnitMetrics, buf) + } } -func NewUnitHook(ctx context.Context, opts ...UnitHookOpt) *UnitHook { - // lru.NewWithEvict can only fail if the size is < 0. - cache, _ := lru.NewWithEvict(1024, func(key string, value *UnitMetrics) { - if value.handled { - return - } - ctx.Logger().Error(fmt.Errorf("eviction"), "dropping unit metric", - "id", key, - "metric", value, - ) - }) - hook := UnitHook{metrics: cache} +func NewUnitHook(ctx context.Context, opts ...UnitHookOpt) (*UnitHook, <-chan UnitMetrics) { + var once sync.Once + hook := UnitHook{ + metrics: make(map[string]*UnitMetrics, runtime.NumCPU()), + finishedMetrics: make(chan UnitMetrics, 1024), + logBackPressure: func() { + once.Do(func() { + ctx.Logger().Info("back pressure detected in unit hook") + }) + }, + } for _, opt := range opts { opt(&hook) } - return &hook + go func() { + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + hooksChannelSize.WithLabelValues().Set(float64(len(hook.finishedMetrics))) + case <-ctx.Done(): + return + } + } + }() + return &hook, hook.finishedMetrics } // id is a helper method to generate an ID for the given job and unit. @@ -52,28 +69,51 @@ func (u *UnitHook) id(ref JobProgressRef, unit SourceUnit) string { return fmt.Sprintf("%d/%d/%s", ref.SourceID, ref.JobID, unitID) } +func (u *UnitHook) ejectFinishedMetrics(metrics UnitMetrics) { + // Intentionally block the hook from returning to supply back-pressure + // to the source. + select { + case u.finishedMetrics <- metrics: + return + default: + u.logBackPressure() + } + u.finishedMetrics <- metrics +} + func (u *UnitHook) StartUnitChunking(ref JobProgressRef, unit SourceUnit, start time.Time) { id := u.id(ref, unit) u.mu.Lock() defer u.mu.Unlock() - u.metrics.Add(id, &UnitMetrics{ + u.metrics[id] = &UnitMetrics{ Unit: unit, Parent: ref, StartTime: &start, - }) + } } func (u *UnitHook) EndUnitChunking(ref JobProgressRef, unit SourceUnit, end time.Time) { id := u.id(ref, unit) - u.mu.Lock() - defer u.mu.Unlock() - metrics, ok := u.metrics.Get(id) + metrics, ok := u.finishUnit(id) if !ok { return } metrics.EndTime = &end + u.ejectFinishedMetrics(*metrics) +} + +func (u *UnitHook) finishUnit(id string) (*UnitMetrics, bool) { + u.mu.Lock() + defer u.mu.Unlock() + + metrics, ok := u.metrics[id] + if !ok { + return nil, false + } + delete(u.metrics, id) + return metrics, true } func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk) { @@ -81,7 +121,7 @@ func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk u.mu.Lock() defer u.mu.Unlock() - metrics, ok := u.metrics.Get(id) + metrics, ok := u.metrics[id] if !ok && unit != nil { // The unit has been evicted. return @@ -92,7 +132,7 @@ func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk Parent: ref, StartTime: ref.Snapshot().StartTime, } - u.metrics.Add(id, metrics) + u.metrics[id] = metrics } metrics.TotalChunks++ metrics.TotalBytes += uint64(len(chunk.Data)) @@ -103,7 +143,7 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) { defer u.mu.Unlock() // Always add the error to the nil unit if it exists. - if metrics, ok := u.metrics.Get(u.id(ref, nil)); ok { + if metrics, ok := u.metrics[u.id(ref, nil)]; ok { metrics.Errors = append(metrics.Errors, err) } @@ -114,7 +154,7 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) { } id := u.id(ref, chunkErr.Unit) - metrics, ok := u.metrics.Get(id) + metrics, ok := u.metrics[id] if !ok { return } @@ -122,51 +162,36 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) { } func (u *UnitHook) Finish(ref JobProgressRef) { - u.mu.Lock() - defer u.mu.Unlock() // Clear out any metrics on this job. This covers the case for the // source running without unit support. - prefix := u.id(ref, nil) - for _, id := range u.metrics.Keys() { - if !strings.HasPrefix(id, prefix) { - continue - } - metric, ok := u.metrics.Get(id) - if !ok { - continue - } - // If the unit is nil, the source does not support units. - // Use the overall job metrics instead. - if metric.Unit == nil { - snap := ref.Snapshot() - metric.StartTime = snap.StartTime - metric.EndTime = snap.EndTime - metric.Errors = snap.Errors - } + id := u.id(ref, nil) + metrics, ok := u.finishUnit(id) + if !ok { + return } + snap := ref.Snapshot() + metrics.StartTime = snap.StartTime + metrics.EndTime = snap.EndTime + metrics.Errors = snap.Errors + u.ejectFinishedMetrics(*metrics) } -// UnitMetrics gets all the currently active or newly finished metrics for this -// job. If a unit returned from this method has finished, it will be removed -// from the cache and no longer returned in successive calls to UnitMetrics(). -func (u *UnitHook) UnitMetrics() []UnitMetrics { +// InProgressSnapshot gets all the currently active metrics across all jobs. +func (u *UnitHook) InProgressSnapshot() []UnitMetrics { u.mu.Lock() defer u.mu.Unlock() - output := make([]UnitMetrics, 0, u.metrics.Len()) - for _, id := range u.metrics.Keys() { - metric, ok := u.metrics.Get(id) - if !ok { - continue - } - output = append(output, *metric) - if metric.IsFinished() { - metric.handled = true - u.metrics.Remove(id) - } + output := make([]UnitMetrics, 0, len(u.metrics)) + for _, metrics := range u.metrics { + output = append(output, *metrics) } return output } +func (u *UnitHook) Close() error { + close(u.finishedMetrics) + return nil +} + type UnitMetrics struct { Unit SourceUnit `json:"unit,omitempty"` Parent JobProgressRef `json:"parent,omitempty"` @@ -179,9 +204,6 @@ type UnitMetrics struct { TotalBytes uint64 `json:"total_bytes"` // All errors encountered by this unit. Errors []error `json:"errors"` - // Flag to mark that these metrics were intentionally evicted from - // the cache. - handled bool } func (u UnitMetrics) IsFinished() bool { diff --git a/pkg/sources/metrics.go b/pkg/sources/metrics.go index 158059da7..394771dc6 100644 --- a/pkg/sources/metrics.go +++ b/pkg/sources/metrics.go @@ -14,4 +14,11 @@ var ( Help: "Time spent executing hooks (ms)", Buckets: []float64{5, 50, 500, 1000}, }, nil) + + hooksChannelSize = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: common.MetricsNamespace, + Subsystem: common.MetricsSubsystem, + Name: "hooks_channel_size", + Help: "Total number of metrics waiting in the finished channel.", + }, nil) ) diff --git a/pkg/sources/source_manager.go b/pkg/sources/source_manager.go index 25722ccf0..0d7af851a 100644 --- a/pkg/sources/source_manager.go +++ b/pkg/sources/source_manager.go @@ -2,6 +2,7 @@ package sources import ( "fmt" + "io" "runtime" "sync" "sync/atomic" @@ -182,6 +183,11 @@ func (s *SourceManager) Wait() error { } close(s.outputChunks) close(s.firstErr) + for _, hook := range s.hooks { + if hookCloser, ok := hook.(io.Closer); ok { + _ = hookCloser.Close() + } + } return s.waitErr } diff --git a/pkg/sources/source_manager_test.go b/pkg/sources/source_manager_test.go index d6a6ee4f6..815f0507e 100644 --- a/pkg/sources/source_manager_test.go +++ b/pkg/sources/source_manager_test.go @@ -5,8 +5,8 @@ import ( "fmt" "sort" "testing" + "time" - lru "github.com/hashicorp/golang-lru/v2" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/anypb" @@ -316,7 +316,7 @@ func TestSourceManagerAvailableCapacity(t *testing.T) { } func TestSourceManagerUnitHook(t *testing.T) { - hook := NewUnitHook(context.TODO()) + hook, ch := NewUnitHook(context.TODO()) input := []unitChunk{ {unit: "1 one", output: "bar"}, @@ -333,9 +333,13 @@ func TestSourceManagerUnitHook(t *testing.T) { ref, err := mgr.Run(context.Background(), "dummy", source) assert.NoError(t, err) <-ref.Done() + assert.NoError(t, mgr.Wait()) - metrics := hook.UnitMetrics() - assert.Equal(t, 3, len(metrics)) + assert.Equal(t, 0, len(hook.InProgressSnapshot())) + var metrics []UnitMetrics + for metric := range ch { + metrics = append(metrics, metric) + } sort.Slice(metrics, func(i, j int) bool { return metrics[i].Unit.SourceUnitID() < metrics[j].Unit.SourceUnitID() }) @@ -366,14 +370,10 @@ func TestSourceManagerUnitHook(t *testing.T) { assert.Equal(t, 1, len(m2.Errors)) } -// TestSourceManagerUnitHookNoBlock tests that the UnitHook drops metrics if -// they aren't handled fast enough. -func TestSourceManagerUnitHookNoBlock(t *testing.T) { - var evictedKeys []string - cache, _ := lru.NewWithEvict(1, func(key string, _ *UnitMetrics) { - evictedKeys = append(evictedKeys, key) - }) - hook := NewUnitHook(context.TODO(), WithUnitHookCache(cache)) +// TestSourceManagerUnitHookBackPressure tests that the UnitHook blocks if the +// finished metrics aren't handled fast enough. +func TestSourceManagerUnitHookBackPressure(t *testing.T) { + hook, ch := NewUnitHook(context.TODO(), WithUnitHookFinishBufferSize(0)) input := []unitChunk{ {unit: "one", output: "bar"}, @@ -389,18 +389,25 @@ func TestSourceManagerUnitHookNoBlock(t *testing.T) { assert.NoError(t, err) ref, err := mgr.Run(context.Background(), "dummy", source) assert.NoError(t, err) - <-ref.Done() - assert.Equal(t, 2, len(evictedKeys)) - metrics := hook.UnitMetrics() - assert.Equal(t, 1, len(metrics)) - assert.Equal(t, "three", metrics[0].Unit.SourceUnitID()) + var metrics []UnitMetrics + for i := 0; i < len(input); i++ { + select { + case <-ref.Done(): + t.Fatal("job should not finish until metrics have been collected") + case <-time.After(1 * time.Millisecond): + } + metrics = append(metrics, <-ch) + } + + assert.NoError(t, mgr.Wait()) + assert.Equal(t, 3, len(metrics), metrics) } // TestSourceManagerUnitHookNoUnits tests whether the UnitHook works for // sources that don't support units. func TestSourceManagerUnitHookNoUnits(t *testing.T) { - hook := NewUnitHook(context.TODO()) + hook, ch := NewUnitHook(context.TODO()) mgr := NewManager( WithBufferedOutput(8), @@ -412,8 +419,12 @@ func TestSourceManagerUnitHookNoUnits(t *testing.T) { ref, err := mgr.Run(context.Background(), "dummy", source) assert.NoError(t, err) <-ref.Done() + assert.NoError(t, mgr.Wait()) - metrics := hook.UnitMetrics() + var metrics []UnitMetrics + for metric := range ch { + metrics = append(metrics, metric) + } assert.Equal(t, 1, len(metrics)) m := metrics[0]