trufflehog/pkg/sources/source_manager_test.go
Miccah 0b16142d4f
Add UnitHook and NoopHook implementations (#1930)
* Add UnitHook and NoopHook implementations

The UnitHook tracks metrics per unit of a job, and emits them on a
channel once finished. It should work even if the Source does not
support source units.

* Refactor channel to use an LRU cache instead

An LRU cache has a more favorable failure mode than the channel. With
the channel, if the consumer stopped consuming metrics, scanning would
block. With the LRU cache, metrics will be dropped when space runs out
and a log message emitted.
2023-10-23 14:27:01 -07:00

427 lines
13 KiB
Go

package sources
import (
"errors"
"fmt"
"sort"
"testing"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
)
// DummySource is a minimal Source used to drive a SourceManager in tests.
// It only records the IDs handed to Init; all chunking behavior comes from
// the embedded chunker, so each test plugs in the strategy it needs.
type DummySource struct {
	sourceID SourceID
	jobID    JobID
	chunker
}

// Type reports a fixed placeholder source type (1337).
func (s *DummySource) Type() sourcespb.SourceType {
	return sourcespb.SourceType(1337)
}

// SourceID returns the source ID recorded by Init.
func (s *DummySource) SourceID() SourceID {
	return s.sourceID
}

// JobID returns the job ID recorded by Init.
func (s *DummySource) JobID() JobID {
	return s.jobID
}

// Init stores jobID and sourceID and ignores every other argument. It never
// returns an error.
func (s *DummySource) Init(_ context.Context, _ string, jobID JobID, sourceID SourceID, _ bool, _ *anypb.Any, _ int) error {
	s.jobID, s.sourceID = jobID, sourceID
	return nil
}

// GetProgress always returns nil; the dummy does not track progress.
func (s *DummySource) GetProgress() *Progress {
	return nil
}
// chunker is the set of chunking entry points a DummySource delegates to.
// Swapping implementations lets tests exercise the SourceManager with
// different chunking behaviors.
type chunker interface {
	// Chunks writes the source's chunks to ch.
	Chunks(ctx context.Context, ch chan *Chunk, targets ...ChunkingTarget) error
	// ChunkUnit reports the chunks of a single unit through reporter.
	ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error
	// Enumerate reports the source's units through reporter.
	Enumerate(ctx context.Context, reporter UnitReporter) error
}
// counterChunker is a chunker whose Chunks method sends count single-byte
// chunks. The byte value keeps increasing across calls (wrapping after 255),
// so repeated runs of the same source produce distinct data.
type counterChunker struct {
	chunkCounter byte
	count        int
}

// Chunks sends c.count chunks, each holding the current counter value, and
// advances the counter after every successful send. If ctx is cancelled
// while waiting to send, it returns ctx.Err().
func (c *counterChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
	for sent := 0; sent < c.count; sent++ {
		chunk := &Chunk{Data: []byte{c.chunkCounter}}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case ch <- chunk:
			c.chunkCounter++
		}
	}
	return nil
}
// countChunk is a SourceUnit identified only by a byte value. It is the unit
// type reported by counterChunker.Enumerate.
type countChunk byte

// SourceUnitID renders the unit as "countChunk(<n>)".
func (c countChunk) SourceUnitID() string {
	id := fmt.Sprintf("countChunk(%d)", byte(c))
	return id
}
// Enumerate reports count units, countChunk(0) through countChunk(count-1),
// stopping at the first error returned by the reporter. Indices above 255
// wrap, since countChunk is a byte.
func (c *counterChunker) Enumerate(ctx context.Context, reporter UnitReporter) error {
	for i := 0; i < c.count; i++ {
		if err := reporter.UnitOk(ctx, countChunk(byte(i))); err != nil {
			return err
		}
	}
	return nil
}

// ChunkUnit reports a single one-byte chunk whose data is the unit's value.
// A unit of any other type is returned as an error instead of panicking on
// the type assertion, so a misrouted unit fails the test cleanly.
func (c *counterChunker) ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error {
	cc, ok := unit.(countChunk)
	if !ok {
		return fmt.Errorf("counterChunker: unexpected unit type %T", unit)
	}
	return reporter.ChunkOk(ctx, Chunk{Data: []byte{byte(cc)}})
}
// errorChunker is a chunker whose every method immediately fails. It embeds
// the error to report, and the chunker value itself is what gets returned.
type errorChunker struct{ error }

// Chunks fails without sending anything.
func (e errorChunker) Chunks(context.Context, chan *Chunk, ...ChunkingTarget) error {
	return e
}

// Enumerate fails without reporting any unit.
func (e errorChunker) Enumerate(context.Context, UnitReporter) error {
	return e
}

// ChunkUnit fails without reporting any chunk.
func (e errorChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error {
	return e
}
// buildDummy wraps chunkMethod in a DummySource, initializes it with fixed
// placeholder arguments (job ID 123, source ID 456), and returns it as a
// Source ready to hand to a SourceManager.
func buildDummy(chunkMethod chunker) (Source, error) {
	src := &DummySource{chunker: chunkMethod}
	err := src.Init(context.Background(), "dummy", 123, 456, true, nil, 42)
	if err != nil {
		return nil, err
	}
	return src, nil
}
// tryRead is a helper function that will try to read from a channel and return
// an error if it cannot.
func tryRead(ch <-chan *Chunk) (*Chunk, error) {
select {
case chunk := <-ch:
return chunk, nil
default:
return nil, fmt.Errorf("no chunk available")
}
}
// TestSourceManagerRun runs the same source three times through one manager
// and checks that each run yields exactly one chunk. The chunk's byte equals
// the run index because counterChunker's counter carries over between runs.
func TestSourceManagerRun(t *testing.T) {
	mgr := NewManager(WithBufferedOutput(8))
	source, err := buildDummy(&counterChunker{count: 1})
	if !assert.NoError(t, err) {
		return
	}
	for i := 0; i < 3; i++ {
		ref, err := mgr.Run(context.Background(), "dummy", source)
		// Check the Run error before touching ref, which may not be usable
		// when Run fails.
		if !assert.NoError(t, err) {
			return
		}
		<-ref.Done()
		assert.NoError(t, ref.Snapshot().FatalError())

		chunk, err := tryRead(mgr.Chunks())
		// Only inspect the chunk if one was actually read; otherwise chunk is
		// nil and dereferencing it would panic.
		if assert.NoError(t, err) && assert.NotNil(t, chunk) {
			assert.Equal(t, []byte{byte(i)}, chunk.Data)
		}
		// The Chunks channel should be empty now.
		_, err = tryRead(mgr.Chunks())
		assert.Error(t, err)
	}
}
// TestSourceManagerWait checks that Wait returns once the running source has
// finished and that the manager then rejects further runs, even of a freshly
// built source.
func TestSourceManagerWait(t *testing.T) {
	mgr := NewManager()
	source, err := buildDummy(&counterChunker{count: 1})
	if !assert.NoError(t, err) {
		return
	}
	// Asynchronously run the source. If Run fails, nothing will ever be sent
	// on Chunks and the read below would hang, so stop here.
	_, err = mgr.Run(context.Background(), "dummy", source)
	if !assert.NoError(t, err) {
		return
	}
	// Read the 1 chunk we're expecting so Wait can complete.
	<-mgr.Chunks()
	// Wait for all resources to complete.
	assert.NoError(t, mgr.Wait())

	// Run should return an error now. Use a brand-new source so the
	// rejection can't be attributed to reusing the previous one.
	source, err = buildDummy(&counterChunker{count: 1})
	assert.NoError(t, err)
	_, err = mgr.Run(context.Background(), "dummy", source)
	assert.Error(t, err)
}
// TestSourceManagerError checks that a source whose chunking methods fail
// ends with a fatal error and that the failure is also reported by Wait.
func TestSourceManagerError(t *testing.T) {
	src, err := buildDummy(errorChunker{fmt.Errorf("oops")})
	assert.NoError(t, err)

	mgr := NewManager()
	ref, err := mgr.Run(context.Background(), "dummy", src)
	assert.NoError(t, err)

	<-ref.Done()
	report := ref.Snapshot()
	assert.Error(t, report.FatalError())
	assert.Error(t, mgr.Wait())
}
// TestSourceManagerReport runs a four-chunk source under three manager
// configurations (plain, with source units, and with one concurrent unit)
// and checks that each final report counts four chunks and no errors.
func TestSourceManagerReport(t *testing.T) {
	configs := [][]func(*SourceManager){
		{WithBufferedOutput(8)},
		{WithBufferedOutput(8), WithSourceUnits()},
		{WithBufferedOutput(8), WithSourceUnits(), WithConcurrentUnits(1)},
	}
	for _, cfg := range configs {
		src, err := buildDummy(&counterChunker{count: 4})
		assert.NoError(t, err)

		mgr := NewManager(cfg...)
		ref, err := mgr.Run(context.Background(), "dummy", src)
		assert.NoError(t, err)
		<-ref.Done()

		report := ref.Snapshot()
		assert.Equal(t, 0, len(report.Errors))
		assert.Equal(t, uint64(4), report.TotalChunks)
	}
}
type (
	// unitChunk is one step of a unitChunker script: the unit it belongs
	// to, the chunk data to emit, and an optional error message to report.
	unitChunk struct {
		unit   string
		output string
		err    string
	}

	// unitChunker is a scripted chunker that replays its steps either as a
	// flat stream (Chunks) or grouped by unit (Enumerate/ChunkUnit).
	unitChunker struct{ steps []unitChunk }
)
// Chunks writes the output of every step that has no error, in order. It
// returns the first error from common.CancellableWrite (presumably a
// cancelled ctx).
func (c *unitChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
	for i := range c.steps {
		step := c.steps[i]
		if step.err == "" {
			chunk := &Chunk{Data: []byte(step.output)}
			if err := common.CancellableWrite(ctx, ch, chunk); err != nil {
				return err
			}
		}
	}
	return nil
}

// Enumerate reports one CommonSourceUnit per step, built from the step's unit
// name. A unit name that appears in several steps is reported once per step.
// It stops at the first error returned by the reporter.
func (c *unitChunker) Enumerate(ctx context.Context, rep UnitReporter) error {
	for i := range c.steps {
		unit := CommonSourceUnit{c.steps[i].unit}
		if err := rep.UnitOk(ctx, unit); err != nil {
			return err
		}
	}
	return nil
}
// ChunkUnit replays every step whose unit matches unit.SourceUnitID().
// A step's err, if set, is reported via ChunkErr, and its output, if
// non-empty, via ChunkOk. It stops at the first error the reporter returns.
func (c *unitChunker) ChunkUnit(ctx context.Context, unit SourceUnit, rep ChunkReporter) error {
	id := unit.SourceUnitID()
	for _, step := range c.steps {
		if step.unit != id {
			continue
		}
		if step.err != "" {
			// errors.New rather than fmt.Errorf: the message is data, so a
			// '%' in it must not be interpreted as a formatting verb.
			if err := rep.ChunkErr(ctx, errors.New(step.err)); err != nil {
				return err
			}
		}
		if step.output == "" {
			continue
		}
		if err := rep.ChunkOk(ctx, Chunk{Data: []byte(step.output)}); err != nil {
			return err
		}
	}
	return nil
}
// TestSourceManagerNonFatalError checks that errors reported for individual
// units land in the report's Errors while every unit still finishes and
// enumeration completes.
func TestSourceManagerNonFatalError(t *testing.T) {
	steps := []unitChunk{
		{unit: "one", output: "bar"},
		{unit: "two", err: "oh no"},
		{unit: "three", err: "not again"},
	}
	src, err := buildDummy(&unitChunker{steps})
	assert.NoError(t, err)

	mgr := NewManager(WithBufferedOutput(8), WithSourceUnits())
	ref, err := mgr.Run(context.Background(), "dummy", src)
	assert.NoError(t, err)
	<-ref.Done()

	got := ref.Snapshot()
	assert.True(t, got.DoneEnumerating)
	assert.Equal(t, len(steps), int(got.TotalUnits))
	assert.Equal(t, len(steps), int(got.FinishedUnits))
	assert.Equal(t, 1, int(got.TotalChunks))
	assert.Equal(t, 2, len(got.Errors))
}
// TestSourceManagerContextCancelled starts a source with more chunks (100)
// than the output buffer holds (8), cancels the run's context, and expects
// the run to end with a fatal error.
func TestSourceManagerContextCancelled(t *testing.T) {
	src, err := buildDummy(&counterChunker{count: 100})
	assert.NoError(t, err)

	mgr := NewManager(WithBufferedOutput(8))
	ctx, cancel := context.WithCancel(context.Background())
	ref, err := mgr.Run(ctx, "dummy", src)
	assert.NoError(t, err)

	cancel()
	<-ref.Done()
	assert.Error(t, ref.Snapshot().FatalError())
}
// DummyAPI is a test double whose methods forward to its function fields,
// letting each test script the registration and job-ID responses it needs.
// Calling a method whose field is nil panics.
type DummyAPI struct {
	registerSource func(context.Context, string, sourcespb.SourceType) (SourceID, error)
	getJobID       func(context.Context, SourceID) (JobID, error)
}

// RegisterSource forwards its arguments to the registerSource field.
func (d DummyAPI) RegisterSource(ctx context.Context, name string, kind sourcespb.SourceType) (SourceID, error) {
	fn := d.registerSource
	return fn(ctx, name, kind)
}

// GetJobID forwards its arguments to the getJobID field.
func (d DummyAPI) GetJobID(ctx context.Context, id SourceID) (JobID, error) {
	fn := d.getJobID
	return fn(ctx, id)
}
// callbackChunker is a chunker whose Chunks behavior is supplied by the test
// through cb. Enumerate and ChunkUnit do nothing and succeed.
type callbackChunker struct {
	cb func(context.Context, chan *Chunk) error
}

// Chunks delegates to the callback and ignores any chunking targets.
func (c callbackChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
	callback := c.cb
	return callback(ctx, ch)
}

// Enumerate reports no units.
func (callbackChunker) Enumerate(context.Context, UnitReporter) error { return nil }

// ChunkUnit reports nothing for any unit.
func (callbackChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return nil }
// TestSourceManagerCancelRun checks that CancelRun cancels the context given
// to Chunks. The run's FatalError must wrap the error Chunks returned, and
// its FatalErrors must wrap the cancellation cause.
func TestSourceManagerCancelRun(t *testing.T) {
	var chunksErr error
	src, err := buildDummy(callbackChunker{func(ctx context.Context, _ chan *Chunk) error {
		// Block until ref.CancelRun() cancels this context.
		<-ctx.Done()
		chunksErr = fmt.Errorf("oh no: %w", ctx.Err())
		return chunksErr
	}})
	assert.NoError(t, err)

	mgr := NewManager(WithBufferedOutput(8))
	ref, err := mgr.Run(context.Background(), "dummy", src)
	assert.NoError(t, err)

	abortErr := errors.New("abort! abort!")
	ref.CancelRun(abortErr)
	<-ref.Done()

	assert.Error(t, ref.Snapshot().FatalError())
	assert.True(t, errors.Is(ref.Snapshot().FatalError(), chunksErr))
	assert.True(t, errors.Is(ref.Snapshot().FatalErrors(), abortErr))
}
// TestSourceManagerAvailableCapacity checks that a running source occupies
// one of the manager's concurrent-source slots and releases it when it
// finishes.
func TestSourceManagerAvailableCapacity(t *testing.T) {
	mgr := NewManager(WithConcurrentSources(1337))
	start, end := make(chan struct{}), make(chan struct{})
	source, err := buildDummy(callbackChunker{func(context.Context, chan *Chunk) error {
		start <- struct{}{} // Send start signal.
		<-end               // Wait for end signal.
		return nil
	}})
	if !assert.NoError(t, err) {
		return
	}

	assert.Equal(t, 1337, mgr.AvailableCapacity())
	ref, err := mgr.Run(context.Background(), "dummy", source)
	// If Run failed the callback may never execute, and waiting on start
	// below would hang the test until its timeout.
	if !assert.NoError(t, err) {
		return
	}
	<-start // Wait for start signal.
	assert.Equal(t, 1336, mgr.AvailableCapacity())
	end <- struct{}{} // Send end signal.
	<-ref.Done()      // Wait for the job to finish.
	assert.Equal(t, 1337, mgr.AvailableCapacity())
}
// TestSourceManagerUnitHook runs three units, one at a time, through a
// manager with a UnitHook attached and checks the per-unit metrics it
// collected.
func TestSourceManagerUnitHook(t *testing.T) {
	hook := NewUnitHook(context.TODO())
	input := []unitChunk{
		{unit: "one", output: "bar"},
		{unit: "two", err: "oh no"},
		{unit: "three", err: "not again"},
	}
	mgr := NewManager(
		WithBufferedOutput(8),
		WithSourceUnits(), WithConcurrentUnits(1),
		WithReportHook(hook),
	)
	source, err := buildDummy(&unitChunker{input})
	if !assert.NoError(t, err) {
		return
	}
	ref, err := mgr.Run(context.Background(), "dummy", source)
	if !assert.NoError(t, err) {
		return
	}
	<-ref.Done()

	metrics := hook.UnitMetrics()
	// Stop here rather than panic on the indexing below if metrics are
	// missing.
	if !assert.Equal(t, len(input), len(metrics)) {
		return
	}
	// With WithConcurrentUnits(1) units finish one after another, so sorting
	// by end time gives processing order; the expectations below assume
	// that matches the input order.
	sort.Slice(metrics, func(i, j int) bool {
		return metrics[i].EndTime.Before(metrics[j].EndTime)
	})

	expected := []struct {
		unit   string
		chunks uint64
		bytes  uint64
		errs   int
	}{
		{unit: "one", chunks: 1, bytes: 3},
		{unit: "two", errs: 1},
		{unit: "three", errs: 1},
	}
	for i, want := range expected {
		m := metrics[i]
		if assert.NotNil(t, m.Unit) {
			assert.Equal(t, want.unit, m.Unit.SourceUnitID())
		}
		assert.Equal(t, want.chunks, m.TotalChunks)
		assert.Equal(t, want.bytes, m.TotalBytes)
		assert.NotZero(t, m.StartTime)
		assert.NotZero(t, m.EndTime)
		assert.NotZero(t, m.ElapsedTime())
		assert.Equal(t, want.errs, len(m.Errors))
	}
}
// TestSourceManagerUnitHookNoBlock tests that the UnitHook drops metrics,
// instead of blocking the scan, when they aren't consumed fast enough. The
// hook's cache holds a single entry, so of three finished units only the last
// is retained and the other two are evicted.
func TestSourceManagerUnitHookNoBlock(t *testing.T) {
	var evictedKeys []string
	cache, err := lru.NewWithEvict(1, func(key string, _ *UnitMetrics) {
		evictedKeys = append(evictedKeys, key)
	})
	if !assert.NoError(t, err) {
		return
	}
	hook := NewUnitHook(context.TODO(), WithUnitHookCache(cache))
	input := []unitChunk{
		{unit: "one", output: "bar"},
		{unit: "two", err: "oh no"},
		{unit: "three", err: "not again"},
	}
	mgr := NewManager(
		WithBufferedOutput(8),
		WithSourceUnits(), WithConcurrentUnits(1),
		WithReportHook(hook),
	)
	source, err := buildDummy(&unitChunker{input})
	if !assert.NoError(t, err) {
		return
	}
	ref, err := mgr.Run(context.Background(), "dummy", source)
	if !assert.NoError(t, err) {
		return
	}
	<-ref.Done()

	assert.Equal(t, 2, len(evictedKeys))
	metrics := hook.UnitMetrics()
	// Guard the index below so a missing entry fails instead of panicking.
	if !assert.Equal(t, 1, len(metrics)) {
		return
	}
	if assert.NotNil(t, metrics[0].Unit) {
		assert.Equal(t, "three", metrics[0].Unit.SourceUnitID())
	}
}
// TestSourceManagerUnitHookNoUnits tests whether the UnitHook works for
// sources that don't support units: the whole run is expected to produce a
// single metrics entry with a nil Unit.
func TestSourceManagerUnitHookNoUnits(t *testing.T) {
	hook := NewUnitHook(context.TODO())
	mgr := NewManager(
		WithBufferedOutput(8),
		WithReportHook(hook),
	)
	source, err := buildDummy(&counterChunker{count: 5})
	if !assert.NoError(t, err) {
		return
	}
	ref, err := mgr.Run(context.Background(), "dummy", source)
	if !assert.NoError(t, err) {
		return
	}
	<-ref.Done()

	metrics := hook.UnitMetrics()
	// Guard the index below so a missing entry fails instead of panicking.
	if !assert.Equal(t, 1, len(metrics)) {
		return
	}
	m := metrics[0]
	assert.Nil(t, m.Unit)
	assert.Equal(t, uint64(5), m.TotalChunks)
	assert.Equal(t, uint64(5), m.TotalBytes)
	assert.NotZero(t, m.StartTime)
	assert.NotZero(t, m.EndTime)
	assert.NotZero(t, m.ElapsedTime())
	assert.Equal(t, 0, len(m.Errors))
}