Refactor UnitHook to block the scan if finished metrics aren't handled (#2309)

* Refactor UnitHook to block the scan if finished metrics aren't handled

* Log once when back-pressure is detected

* Add hook channel size metric

* Use plural "metrics" for consistency

* Replace LRU cache with map
This commit is contained in:
Miccah 2024-02-08 14:50:58 -08:00 committed by GitHub
parent adc09c0533
commit dd4d4a8a96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 132 additions and 85 deletions

View file

@ -188,7 +188,8 @@ func (jp *JobProgress) executeHooks(todo func(hook JobProgressHook)) {
hooksExecTime.WithLabelValues().Observe(float64(elapsed))
}(time.Now())
for _, hook := range jp.hooks {
// TODO: Non-blocking?
// Execute hooks synchronously so they can provide
// back-pressure to the source.
todo(hook)
}
}

View file

@ -3,44 +3,61 @@ package sources
import (
"errors"
"fmt"
"strings"
"runtime"
"sync"
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)
// UnitHook implements JobProgressHook for tracking the progress of each
// individual unit.
type UnitHook struct {
metrics *lru.Cache[string, *UnitMetrics]
mu sync.Mutex
metrics map[string]*UnitMetrics
mu sync.Mutex
finishedMetrics chan UnitMetrics
logBackPressure func()
NoopHook
}
type UnitHookOpt func(*UnitHook)
func WithUnitHookCache(cache *lru.Cache[string, *UnitMetrics]) UnitHookOpt {
return func(hook *UnitHook) { hook.metrics = cache }
// WithUnitHookFinishBufferSize sets the buffer size for handling finished
// metrics (default is 1024). If the buffer fills, then scanning will stop
// until there is room.
func WithUnitHookFinishBufferSize(buf int) UnitHookOpt {
return func(hook *UnitHook) {
hook.finishedMetrics = make(chan UnitMetrics, buf)
}
}
func NewUnitHook(ctx context.Context, opts ...UnitHookOpt) *UnitHook {
// lru.NewWithEvict can only fail if the size is < 0.
cache, _ := lru.NewWithEvict(1024, func(key string, value *UnitMetrics) {
if value.handled {
return
}
ctx.Logger().Error(fmt.Errorf("eviction"), "dropping unit metric",
"id", key,
"metric", value,
)
})
hook := UnitHook{metrics: cache}
func NewUnitHook(ctx context.Context, opts ...UnitHookOpt) (*UnitHook, <-chan UnitMetrics) {
var once sync.Once
hook := UnitHook{
metrics: make(map[string]*UnitMetrics, runtime.NumCPU()),
finishedMetrics: make(chan UnitMetrics, 1024),
logBackPressure: func() {
once.Do(func() {
ctx.Logger().Info("back pressure detected in unit hook")
})
},
}
for _, opt := range opts {
opt(&hook)
}
return &hook
go func() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
hooksChannelSize.WithLabelValues().Set(float64(len(hook.finishedMetrics)))
case <-ctx.Done():
return
}
}
}()
return &hook, hook.finishedMetrics
}
// id is a helper method to generate an ID for the given job and unit.
@ -52,28 +69,51 @@ func (u *UnitHook) id(ref JobProgressRef, unit SourceUnit) string {
return fmt.Sprintf("%d/%d/%s", ref.SourceID, ref.JobID, unitID)
}
func (u *UnitHook) ejectFinishedMetrics(metrics UnitMetrics) {
// Intentionally block the hook from returning to supply back-pressure
// to the source.
select {
case u.finishedMetrics <- metrics:
return
default:
u.logBackPressure()
}
u.finishedMetrics <- metrics
}
func (u *UnitHook) StartUnitChunking(ref JobProgressRef, unit SourceUnit, start time.Time) {
id := u.id(ref, unit)
u.mu.Lock()
defer u.mu.Unlock()
u.metrics.Add(id, &UnitMetrics{
u.metrics[id] = &UnitMetrics{
Unit: unit,
Parent: ref,
StartTime: &start,
})
}
}
func (u *UnitHook) EndUnitChunking(ref JobProgressRef, unit SourceUnit, end time.Time) {
id := u.id(ref, unit)
u.mu.Lock()
defer u.mu.Unlock()
metrics, ok := u.metrics.Get(id)
metrics, ok := u.finishUnit(id)
if !ok {
return
}
metrics.EndTime = &end
u.ejectFinishedMetrics(*metrics)
}
func (u *UnitHook) finishUnit(id string) (*UnitMetrics, bool) {
u.mu.Lock()
defer u.mu.Unlock()
metrics, ok := u.metrics[id]
if !ok {
return nil, false
}
delete(u.metrics, id)
return metrics, true
}
func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk) {
@ -81,7 +121,7 @@ func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk
u.mu.Lock()
defer u.mu.Unlock()
metrics, ok := u.metrics.Get(id)
metrics, ok := u.metrics[id]
if !ok && unit != nil {
// The unit has been evicted.
return
@ -92,7 +132,7 @@ func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk
Parent: ref,
StartTime: ref.Snapshot().StartTime,
}
u.metrics.Add(id, metrics)
u.metrics[id] = metrics
}
metrics.TotalChunks++
metrics.TotalBytes += uint64(len(chunk.Data))
@ -103,7 +143,7 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) {
defer u.mu.Unlock()
// Always add the error to the nil unit if it exists.
if metrics, ok := u.metrics.Get(u.id(ref, nil)); ok {
if metrics, ok := u.metrics[u.id(ref, nil)]; ok {
metrics.Errors = append(metrics.Errors, err)
}
@ -114,7 +154,7 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) {
}
id := u.id(ref, chunkErr.Unit)
metrics, ok := u.metrics.Get(id)
metrics, ok := u.metrics[id]
if !ok {
return
}
@ -122,51 +162,36 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) {
}
func (u *UnitHook) Finish(ref JobProgressRef) {
u.mu.Lock()
defer u.mu.Unlock()
// Clear out any metrics on this job. This covers the case for the
// source running without unit support.
prefix := u.id(ref, nil)
for _, id := range u.metrics.Keys() {
if !strings.HasPrefix(id, prefix) {
continue
}
metric, ok := u.metrics.Get(id)
if !ok {
continue
}
// If the unit is nil, the source does not support units.
// Use the overall job metrics instead.
if metric.Unit == nil {
snap := ref.Snapshot()
metric.StartTime = snap.StartTime
metric.EndTime = snap.EndTime
metric.Errors = snap.Errors
}
id := u.id(ref, nil)
metrics, ok := u.finishUnit(id)
if !ok {
return
}
snap := ref.Snapshot()
metrics.StartTime = snap.StartTime
metrics.EndTime = snap.EndTime
metrics.Errors = snap.Errors
u.ejectFinishedMetrics(*metrics)
}
// UnitMetrics gets all the currently active or newly finished metrics for this
// job. If a unit returned from this method has finished, it will be removed
// from the cache and no longer returned in successive calls to UnitMetrics().
func (u *UnitHook) UnitMetrics() []UnitMetrics {
// InProgressSnapshot gets all the currently active metrics across all jobs.
func (u *UnitHook) InProgressSnapshot() []UnitMetrics {
u.mu.Lock()
defer u.mu.Unlock()
output := make([]UnitMetrics, 0, u.metrics.Len())
for _, id := range u.metrics.Keys() {
metric, ok := u.metrics.Get(id)
if !ok {
continue
}
output = append(output, *metric)
if metric.IsFinished() {
metric.handled = true
u.metrics.Remove(id)
}
output := make([]UnitMetrics, 0, len(u.metrics))
for _, metrics := range u.metrics {
output = append(output, *metrics)
}
return output
}
func (u *UnitHook) Close() error {
close(u.finishedMetrics)
return nil
}
type UnitMetrics struct {
Unit SourceUnit `json:"unit,omitempty"`
Parent JobProgressRef `json:"parent,omitempty"`
@ -179,9 +204,6 @@ type UnitMetrics struct {
TotalBytes uint64 `json:"total_bytes"`
// All errors encountered by this unit.
Errors []error `json:"errors"`
// Flag to mark that these metrics were intentionally evicted from
// the cache.
handled bool
}
func (u UnitMetrics) IsFinished() bool {

View file

@ -14,4 +14,11 @@ var (
Help: "Time spent executing hooks (ms)",
Buckets: []float64{5, 50, 500, 1000},
}, nil)
hooksChannelSize = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: common.MetricsNamespace,
Subsystem: common.MetricsSubsystem,
Name: "hooks_channel_size",
Help: "Total number of metrics waiting in the finished channel.",
}, nil)
)

View file

@ -2,6 +2,7 @@ package sources
import (
"fmt"
"io"
"runtime"
"sync"
"sync/atomic"
@ -182,6 +183,11 @@ func (s *SourceManager) Wait() error {
}
close(s.outputChunks)
close(s.firstErr)
for _, hook := range s.hooks {
if hookCloser, ok := hook.(io.Closer); ok {
_ = hookCloser.Close()
}
}
return s.waitErr
}

View file

@ -5,8 +5,8 @@ import (
"fmt"
"sort"
"testing"
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
@ -316,7 +316,7 @@ func TestSourceManagerAvailableCapacity(t *testing.T) {
}
func TestSourceManagerUnitHook(t *testing.T) {
hook := NewUnitHook(context.TODO())
hook, ch := NewUnitHook(context.TODO())
input := []unitChunk{
{unit: "1 one", output: "bar"},
@ -333,9 +333,13 @@ func TestSourceManagerUnitHook(t *testing.T) {
ref, err := mgr.Run(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.NoError(t, mgr.Wait())
metrics := hook.UnitMetrics()
assert.Equal(t, 3, len(metrics))
assert.Equal(t, 0, len(hook.InProgressSnapshot()))
var metrics []UnitMetrics
for metric := range ch {
metrics = append(metrics, metric)
}
sort.Slice(metrics, func(i, j int) bool {
return metrics[i].Unit.SourceUnitID() < metrics[j].Unit.SourceUnitID()
})
@ -366,14 +370,10 @@ func TestSourceManagerUnitHook(t *testing.T) {
assert.Equal(t, 1, len(m2.Errors))
}
// TestSourceManagerUnitHookNoBlock tests that the UnitHook drops metrics if
// they aren't handled fast enough.
func TestSourceManagerUnitHookNoBlock(t *testing.T) {
var evictedKeys []string
cache, _ := lru.NewWithEvict(1, func(key string, _ *UnitMetrics) {
evictedKeys = append(evictedKeys, key)
})
hook := NewUnitHook(context.TODO(), WithUnitHookCache(cache))
// TestSourceManagerUnitHookBackPressure tests that the UnitHook blocks if the
// finished metrics aren't handled fast enough.
func TestSourceManagerUnitHookBackPressure(t *testing.T) {
hook, ch := NewUnitHook(context.TODO(), WithUnitHookFinishBufferSize(0))
input := []unitChunk{
{unit: "one", output: "bar"},
@ -389,18 +389,25 @@ func TestSourceManagerUnitHookNoBlock(t *testing.T) {
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.Equal(t, 2, len(evictedKeys))
metrics := hook.UnitMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "three", metrics[0].Unit.SourceUnitID())
var metrics []UnitMetrics
for i := 0; i < len(input); i++ {
select {
case <-ref.Done():
t.Fatal("job should not finish until metrics have been collected")
case <-time.After(1 * time.Millisecond):
}
metrics = append(metrics, <-ch)
}
assert.NoError(t, mgr.Wait())
assert.Equal(t, 3, len(metrics), metrics)
}
// TestSourceManagerUnitHookNoUnits tests whether the UnitHook works for
// sources that don't support units.
func TestSourceManagerUnitHookNoUnits(t *testing.T) {
hook := NewUnitHook(context.TODO())
hook, ch := NewUnitHook(context.TODO())
mgr := NewManager(
WithBufferedOutput(8),
@ -412,8 +419,12 @@ func TestSourceManagerUnitHookNoUnits(t *testing.T) {
ref, err := mgr.Run(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.NoError(t, mgr.Wait())
metrics := hook.UnitMetrics()
var metrics []UnitMetrics
for metric := range ch {
metrics = append(metrics, metric)
}
assert.Equal(t, 1, len(metrics))
m := metrics[0]