mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
Refactor UnitHook to block the scan if finished metrics aren't handled (#2309)
* Refactor UnitHook to block the scan if finished metrics aren't handled * Log once when back-pressure is detected * Add hook channel size metric * Use plural "metrics" for consistency * Replace LRU cache with map
This commit is contained in:
parent
adc09c0533
commit
dd4d4a8a96
5 changed files with 132 additions and 85 deletions
|
@ -188,7 +188,8 @@ func (jp *JobProgress) executeHooks(todo func(hook JobProgressHook)) {
|
|||
hooksExecTime.WithLabelValues().Observe(float64(elapsed))
|
||||
}(time.Now())
|
||||
for _, hook := range jp.hooks {
|
||||
// TODO: Non-blocking?
|
||||
// Execute hooks synchronously so they can provide
|
||||
// back-pressure to the source.
|
||||
todo(hook)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,44 +3,61 @@ package sources
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
)
|
||||
|
||||
// UnitHook implements JobProgressHook for tracking the progress of each
|
||||
// individual unit.
|
||||
type UnitHook struct {
|
||||
metrics *lru.Cache[string, *UnitMetrics]
|
||||
mu sync.Mutex
|
||||
metrics map[string]*UnitMetrics
|
||||
mu sync.Mutex
|
||||
finishedMetrics chan UnitMetrics
|
||||
logBackPressure func()
|
||||
NoopHook
|
||||
}
|
||||
|
||||
type UnitHookOpt func(*UnitHook)
|
||||
|
||||
func WithUnitHookCache(cache *lru.Cache[string, *UnitMetrics]) UnitHookOpt {
|
||||
return func(hook *UnitHook) { hook.metrics = cache }
|
||||
// WithUnitHookFinishBufferSize sets the buffer size for handling finished
|
||||
// metrics (default is 1024). If the buffer fills, then scanning will stop
|
||||
// until there is room.
|
||||
func WithUnitHookFinishBufferSize(buf int) UnitHookOpt {
|
||||
return func(hook *UnitHook) {
|
||||
hook.finishedMetrics = make(chan UnitMetrics, buf)
|
||||
}
|
||||
}
|
||||
|
||||
func NewUnitHook(ctx context.Context, opts ...UnitHookOpt) *UnitHook {
|
||||
// lru.NewWithEvict can only fail if the size is < 0.
|
||||
cache, _ := lru.NewWithEvict(1024, func(key string, value *UnitMetrics) {
|
||||
if value.handled {
|
||||
return
|
||||
}
|
||||
ctx.Logger().Error(fmt.Errorf("eviction"), "dropping unit metric",
|
||||
"id", key,
|
||||
"metric", value,
|
||||
)
|
||||
})
|
||||
hook := UnitHook{metrics: cache}
|
||||
func NewUnitHook(ctx context.Context, opts ...UnitHookOpt) (*UnitHook, <-chan UnitMetrics) {
|
||||
var once sync.Once
|
||||
hook := UnitHook{
|
||||
metrics: make(map[string]*UnitMetrics, runtime.NumCPU()),
|
||||
finishedMetrics: make(chan UnitMetrics, 1024),
|
||||
logBackPressure: func() {
|
||||
once.Do(func() {
|
||||
ctx.Logger().Info("back pressure detected in unit hook")
|
||||
})
|
||||
},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&hook)
|
||||
}
|
||||
return &hook
|
||||
go func() {
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hooksChannelSize.WithLabelValues().Set(float64(len(hook.finishedMetrics)))
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &hook, hook.finishedMetrics
|
||||
}
|
||||
|
||||
// id is a helper method to generate an ID for the given job and unit.
|
||||
|
@ -52,28 +69,51 @@ func (u *UnitHook) id(ref JobProgressRef, unit SourceUnit) string {
|
|||
return fmt.Sprintf("%d/%d/%s", ref.SourceID, ref.JobID, unitID)
|
||||
}
|
||||
|
||||
func (u *UnitHook) ejectFinishedMetrics(metrics UnitMetrics) {
|
||||
// Intentionally block the hook from returning to supply back-pressure
|
||||
// to the source.
|
||||
select {
|
||||
case u.finishedMetrics <- metrics:
|
||||
return
|
||||
default:
|
||||
u.logBackPressure()
|
||||
}
|
||||
u.finishedMetrics <- metrics
|
||||
}
|
||||
|
||||
func (u *UnitHook) StartUnitChunking(ref JobProgressRef, unit SourceUnit, start time.Time) {
|
||||
id := u.id(ref, unit)
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
u.metrics.Add(id, &UnitMetrics{
|
||||
u.metrics[id] = &UnitMetrics{
|
||||
Unit: unit,
|
||||
Parent: ref,
|
||||
StartTime: &start,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UnitHook) EndUnitChunking(ref JobProgressRef, unit SourceUnit, end time.Time) {
|
||||
id := u.id(ref, unit)
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
metrics, ok := u.metrics.Get(id)
|
||||
metrics, ok := u.finishUnit(id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
metrics.EndTime = &end
|
||||
u.ejectFinishedMetrics(*metrics)
|
||||
}
|
||||
|
||||
func (u *UnitHook) finishUnit(id string) (*UnitMetrics, bool) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
metrics, ok := u.metrics[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
delete(u.metrics, id)
|
||||
return metrics, true
|
||||
}
|
||||
|
||||
func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk) {
|
||||
|
@ -81,7 +121,7 @@ func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk
|
|||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
metrics, ok := u.metrics.Get(id)
|
||||
metrics, ok := u.metrics[id]
|
||||
if !ok && unit != nil {
|
||||
// The unit has been evicted.
|
||||
return
|
||||
|
@ -92,7 +132,7 @@ func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk
|
|||
Parent: ref,
|
||||
StartTime: ref.Snapshot().StartTime,
|
||||
}
|
||||
u.metrics.Add(id, metrics)
|
||||
u.metrics[id] = metrics
|
||||
}
|
||||
metrics.TotalChunks++
|
||||
metrics.TotalBytes += uint64(len(chunk.Data))
|
||||
|
@ -103,7 +143,7 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) {
|
|||
defer u.mu.Unlock()
|
||||
|
||||
// Always add the error to the nil unit if it exists.
|
||||
if metrics, ok := u.metrics.Get(u.id(ref, nil)); ok {
|
||||
if metrics, ok := u.metrics[u.id(ref, nil)]; ok {
|
||||
metrics.Errors = append(metrics.Errors, err)
|
||||
}
|
||||
|
||||
|
@ -114,7 +154,7 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) {
|
|||
}
|
||||
id := u.id(ref, chunkErr.Unit)
|
||||
|
||||
metrics, ok := u.metrics.Get(id)
|
||||
metrics, ok := u.metrics[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
@ -122,51 +162,36 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) {
|
|||
}
|
||||
|
||||
func (u *UnitHook) Finish(ref JobProgressRef) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
// Clear out any metrics on this job. This covers the case for the
|
||||
// source running without unit support.
|
||||
prefix := u.id(ref, nil)
|
||||
for _, id := range u.metrics.Keys() {
|
||||
if !strings.HasPrefix(id, prefix) {
|
||||
continue
|
||||
}
|
||||
metric, ok := u.metrics.Get(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// If the unit is nil, the source does not support units.
|
||||
// Use the overall job metrics instead.
|
||||
if metric.Unit == nil {
|
||||
snap := ref.Snapshot()
|
||||
metric.StartTime = snap.StartTime
|
||||
metric.EndTime = snap.EndTime
|
||||
metric.Errors = snap.Errors
|
||||
}
|
||||
id := u.id(ref, nil)
|
||||
metrics, ok := u.finishUnit(id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
snap := ref.Snapshot()
|
||||
metrics.StartTime = snap.StartTime
|
||||
metrics.EndTime = snap.EndTime
|
||||
metrics.Errors = snap.Errors
|
||||
u.ejectFinishedMetrics(*metrics)
|
||||
}
|
||||
|
||||
// UnitMetrics gets all the currently active or newly finished metrics for this
|
||||
// job. If a unit returned from this method has finished, it will be removed
|
||||
// from the cache and no longer returned in successive calls to UnitMetrics().
|
||||
func (u *UnitHook) UnitMetrics() []UnitMetrics {
|
||||
// InProgressSnapshot gets all the currently active metrics across all jobs.
|
||||
func (u *UnitHook) InProgressSnapshot() []UnitMetrics {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
output := make([]UnitMetrics, 0, u.metrics.Len())
|
||||
for _, id := range u.metrics.Keys() {
|
||||
metric, ok := u.metrics.Get(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
output = append(output, *metric)
|
||||
if metric.IsFinished() {
|
||||
metric.handled = true
|
||||
u.metrics.Remove(id)
|
||||
}
|
||||
output := make([]UnitMetrics, 0, len(u.metrics))
|
||||
for _, metrics := range u.metrics {
|
||||
output = append(output, *metrics)
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func (u *UnitHook) Close() error {
|
||||
close(u.finishedMetrics)
|
||||
return nil
|
||||
}
|
||||
|
||||
type UnitMetrics struct {
|
||||
Unit SourceUnit `json:"unit,omitempty"`
|
||||
Parent JobProgressRef `json:"parent,omitempty"`
|
||||
|
@ -179,9 +204,6 @@ type UnitMetrics struct {
|
|||
TotalBytes uint64 `json:"total_bytes"`
|
||||
// All errors encountered by this unit.
|
||||
Errors []error `json:"errors"`
|
||||
// Flag to mark that these metrics were intentionally evicted from
|
||||
// the cache.
|
||||
handled bool
|
||||
}
|
||||
|
||||
func (u UnitMetrics) IsFinished() bool {
|
||||
|
|
|
@ -14,4 +14,11 @@ var (
|
|||
Help: "Time spent executing hooks (ms)",
|
||||
Buckets: []float64{5, 50, 500, 1000},
|
||||
}, nil)
|
||||
|
||||
hooksChannelSize = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: common.MetricsNamespace,
|
||||
Subsystem: common.MetricsSubsystem,
|
||||
Name: "hooks_channel_size",
|
||||
Help: "Total number of metrics waiting in the finished channel.",
|
||||
}, nil)
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@ package sources
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -182,6 +183,11 @@ func (s *SourceManager) Wait() error {
|
|||
}
|
||||
close(s.outputChunks)
|
||||
close(s.firstErr)
|
||||
for _, hook := range s.hooks {
|
||||
if hookCloser, ok := hook.(io.Closer); ok {
|
||||
_ = hookCloser.Close()
|
||||
}
|
||||
}
|
||||
return s.waitErr
|
||||
}
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@ import (
|
|||
"fmt"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
|
@ -316,7 +316,7 @@ func TestSourceManagerAvailableCapacity(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSourceManagerUnitHook(t *testing.T) {
|
||||
hook := NewUnitHook(context.TODO())
|
||||
hook, ch := NewUnitHook(context.TODO())
|
||||
|
||||
input := []unitChunk{
|
||||
{unit: "1 one", output: "bar"},
|
||||
|
@ -333,9 +333,13 @@ func TestSourceManagerUnitHook(t *testing.T) {
|
|||
ref, err := mgr.Run(context.Background(), "dummy", source)
|
||||
assert.NoError(t, err)
|
||||
<-ref.Done()
|
||||
assert.NoError(t, mgr.Wait())
|
||||
|
||||
metrics := hook.UnitMetrics()
|
||||
assert.Equal(t, 3, len(metrics))
|
||||
assert.Equal(t, 0, len(hook.InProgressSnapshot()))
|
||||
var metrics []UnitMetrics
|
||||
for metric := range ch {
|
||||
metrics = append(metrics, metric)
|
||||
}
|
||||
sort.Slice(metrics, func(i, j int) bool {
|
||||
return metrics[i].Unit.SourceUnitID() < metrics[j].Unit.SourceUnitID()
|
||||
})
|
||||
|
@ -366,14 +370,10 @@ func TestSourceManagerUnitHook(t *testing.T) {
|
|||
assert.Equal(t, 1, len(m2.Errors))
|
||||
}
|
||||
|
||||
// TestSourceManagerUnitHookNoBlock tests that the UnitHook drops metrics if
|
||||
// they aren't handled fast enough.
|
||||
func TestSourceManagerUnitHookNoBlock(t *testing.T) {
|
||||
var evictedKeys []string
|
||||
cache, _ := lru.NewWithEvict(1, func(key string, _ *UnitMetrics) {
|
||||
evictedKeys = append(evictedKeys, key)
|
||||
})
|
||||
hook := NewUnitHook(context.TODO(), WithUnitHookCache(cache))
|
||||
// TestSourceManagerUnitHookBackPressure tests that the UnitHook blocks if the
|
||||
// finished metrics aren't handled fast enough.
|
||||
func TestSourceManagerUnitHookBackPressure(t *testing.T) {
|
||||
hook, ch := NewUnitHook(context.TODO(), WithUnitHookFinishBufferSize(0))
|
||||
|
||||
input := []unitChunk{
|
||||
{unit: "one", output: "bar"},
|
||||
|
@ -389,18 +389,25 @@ func TestSourceManagerUnitHookNoBlock(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
ref, err := mgr.Run(context.Background(), "dummy", source)
|
||||
assert.NoError(t, err)
|
||||
<-ref.Done()
|
||||
|
||||
assert.Equal(t, 2, len(evictedKeys))
|
||||
metrics := hook.UnitMetrics()
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
assert.Equal(t, "three", metrics[0].Unit.SourceUnitID())
|
||||
var metrics []UnitMetrics
|
||||
for i := 0; i < len(input); i++ {
|
||||
select {
|
||||
case <-ref.Done():
|
||||
t.Fatal("job should not finish until metrics have been collected")
|
||||
case <-time.After(1 * time.Millisecond):
|
||||
}
|
||||
metrics = append(metrics, <-ch)
|
||||
}
|
||||
|
||||
assert.NoError(t, mgr.Wait())
|
||||
assert.Equal(t, 3, len(metrics), metrics)
|
||||
}
|
||||
|
||||
// TestSourceManagerUnitHookNoUnits tests whether the UnitHook works for
|
||||
// sources that don't support units.
|
||||
func TestSourceManagerUnitHookNoUnits(t *testing.T) {
|
||||
hook := NewUnitHook(context.TODO())
|
||||
hook, ch := NewUnitHook(context.TODO())
|
||||
|
||||
mgr := NewManager(
|
||||
WithBufferedOutput(8),
|
||||
|
@ -412,8 +419,12 @@ func TestSourceManagerUnitHookNoUnits(t *testing.T) {
|
|||
ref, err := mgr.Run(context.Background(), "dummy", source)
|
||||
assert.NoError(t, err)
|
||||
<-ref.Done()
|
||||
assert.NoError(t, mgr.Wait())
|
||||
|
||||
metrics := hook.UnitMetrics()
|
||||
var metrics []UnitMetrics
|
||||
for metric := range ch {
|
||||
metrics = append(metrics, metric)
|
||||
}
|
||||
assert.Equal(t, 1, len(metrics))
|
||||
|
||||
m := metrics[0]
|
||||
|
|
Loading…
Reference in a new issue