mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 15:14:38 +00:00
0b16142d4f
* Add UnitHook and NoopHook implementations

  The UnitHook tracks metrics per unit of a job, and emits them on a channel once finished. It should work even if the Source does not support source units.

* Refactor channel to use an LRU cache instead

  An LRU cache has a more favorable failure mode than the channel. With the channel, if the consumer stopped consuming metrics, scanning would block. With the LRU cache, metrics will be dropped when space runs out and a log message emitted.
427 lines
13 KiB
Go
427 lines
13 KiB
Go
package sources
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"testing"
|
|
|
|
lru "github.com/hashicorp/golang-lru/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
|
|
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
|
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
|
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
|
)
|
|
|
|
// DummySource implements Source and is used for testing a SourceManager.
|
|
type DummySource struct {
|
|
sourceID SourceID
|
|
jobID JobID
|
|
chunker
|
|
}
|
|
|
|
func (d *DummySource) Type() sourcespb.SourceType { return 1337 }
|
|
func (d *DummySource) SourceID() SourceID { return d.sourceID }
|
|
func (d *DummySource) JobID() JobID { return d.jobID }
|
|
func (d *DummySource) Init(_ context.Context, _ string, jobID JobID, sourceID SourceID, _ bool, _ *anypb.Any, _ int) error {
|
|
d.sourceID = sourceID
|
|
d.jobID = jobID
|
|
return nil
|
|
}
|
|
// GetProgress always returns nil; DummySource keeps no progress state.
func (d *DummySource) GetProgress() *Progress {
	return nil
}
|
|
|
|
// Interface to easily test different chunking methods.
|
|
type chunker interface {
|
|
Chunks(context.Context, chan *Chunk, ...ChunkingTarget) error
|
|
ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error
|
|
Enumerate(ctx context.Context, reporter UnitReporter) error
|
|
}
|
|
|
|
// counterChunker is a chunker whose Chunks method emits count single-byte
// chunks, each carrying the current value of a running counter.
type counterChunker struct {
	// chunkCounter is the byte value placed in the next chunk sent by Chunks;
	// it is incremented after every successful send.
	chunkCounter byte
	// count is how many chunks (Chunks) or units (Enumerate) one call produces.
	count int
}
|
|
|
|
func (c *counterChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
|
|
for i := 0; i < c.count; i++ {
|
|
select {
|
|
case ch <- &Chunk{Data: []byte{c.chunkCounter}}:
|
|
c.chunkCounter++
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// countChunk is a one-byte value that satisfies SourceUnit; counterChunker
// uses it as its unit type.
type countChunk byte

// SourceUnitID renders the unit as "countChunk(N)", where N is the decimal
// value of the underlying byte.
func (c countChunk) SourceUnitID() string {
	return fmt.Sprintf("countChunk(%d)", byte(c))
}
|
|
|
|
func (c *counterChunker) Enumerate(ctx context.Context, reporter UnitReporter) error {
|
|
for i := 0; i < c.count; i++ {
|
|
if err := reporter.UnitOk(ctx, countChunk(byte(i))); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *counterChunker) ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error {
|
|
return reporter.ChunkOk(ctx, Chunk{Data: []byte{byte(unit.(countChunk))}})
|
|
}
|
|
|
|
// errorChunker is a chunker whose every method fails. It embeds an error, so
// the errorChunker value itself satisfies the error interface.
type errorChunker struct {
	error
}
|
|
|
|
func (c errorChunker) Chunks(context.Context, chan *Chunk, ...ChunkingTarget) error { return c }
|
|
func (c errorChunker) Enumerate(context.Context, UnitReporter) error { return c }
|
|
func (c errorChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return c }
|
|
|
|
// buildDummy is a helper function to enroll a DummySource with a SourceManager.
|
|
func buildDummy(chunkMethod chunker) (Source, error) {
|
|
source := &DummySource{chunker: chunkMethod}
|
|
if err := source.Init(context.Background(), "dummy", 123, 456, true, nil, 42); err != nil {
|
|
return nil, err
|
|
}
|
|
return source, nil
|
|
}
|
|
|
|
// tryRead is a helper function that will try to read from a channel and return
|
|
// an error if it cannot.
|
|
func tryRead(ch <-chan *Chunk) (*Chunk, error) {
|
|
select {
|
|
case chunk := <-ch:
|
|
return chunk, nil
|
|
default:
|
|
return nil, fmt.Errorf("no chunk available")
|
|
}
|
|
}
|
|
|
|
func TestSourceManagerRun(t *testing.T) {
|
|
mgr := NewManager(WithBufferedOutput(8))
|
|
source, err := buildDummy(&counterChunker{count: 1})
|
|
assert.NoError(t, err)
|
|
for i := 0; i < 3; i++ {
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
<-ref.Done()
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, ref.Snapshot().FatalError())
|
|
chunk, err := tryRead(mgr.Chunks())
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte{byte(i)}, chunk.Data)
|
|
// The Chunks channel should be empty now.
|
|
_, err = tryRead(mgr.Chunks())
|
|
assert.Error(t, err)
|
|
}
|
|
}
|
|
|
|
func TestSourceManagerWait(t *testing.T) {
|
|
mgr := NewManager()
|
|
source, err := buildDummy(&counterChunker{count: 1})
|
|
assert.NoError(t, err)
|
|
// Asynchronously run the source.
|
|
_, err = mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
// Read the 1 chunk we're expecting so Waiting completes.
|
|
<-mgr.Chunks()
|
|
// Wait for all resources to complete.
|
|
assert.NoError(t, mgr.Wait())
|
|
// Run should return an error now.
|
|
_, err = buildDummy(&counterChunker{count: 1})
|
|
assert.NoError(t, err)
|
|
_, err = mgr.Run(context.Background(), "dummy", source)
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestSourceManagerError(t *testing.T) {
|
|
mgr := NewManager()
|
|
source, err := buildDummy(errorChunker{fmt.Errorf("oops")})
|
|
assert.NoError(t, err)
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
<-ref.Done()
|
|
assert.Error(t, ref.Snapshot().FatalError())
|
|
assert.Error(t, mgr.Wait())
|
|
}
|
|
|
|
func TestSourceManagerReport(t *testing.T) {
|
|
for _, opts := range [][]func(*SourceManager){
|
|
{WithBufferedOutput(8)},
|
|
{WithBufferedOutput(8), WithSourceUnits()},
|
|
{WithBufferedOutput(8), WithSourceUnits(), WithConcurrentUnits(1)},
|
|
} {
|
|
mgr := NewManager(opts...)
|
|
source, err := buildDummy(&counterChunker{count: 4})
|
|
assert.NoError(t, err)
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
<-ref.Done()
|
|
assert.Equal(t, 0, len(ref.Snapshot().Errors))
|
|
assert.Equal(t, uint64(4), ref.Snapshot().TotalChunks)
|
|
}
|
|
}
|
|
|
|
// unitChunk describes one scripted step for unitChunker.
type unitChunk struct {
	// unit is the ID of the source unit this step belongs to.
	unit string
	// output is the chunk data to emit; ChunkUnit emits nothing when it is empty.
	output string
	// err, when non-empty, is reported as a chunk error by ChunkUnit and
	// causes Chunks to skip the step.
	err string
}
|
|
|
|
// unitChunker is a chunker driven by a fixed script of unitChunk steps.
type unitChunker struct {
	steps []unitChunk
}
|
|
|
|
func (c *unitChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
|
|
for _, step := range c.steps {
|
|
if step.err != "" {
|
|
continue
|
|
}
|
|
if err := common.CancellableWrite(ctx, ch, &Chunk{Data: []byte(step.output)}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
func (c *unitChunker) Enumerate(ctx context.Context, rep UnitReporter) error {
|
|
for _, step := range c.steps {
|
|
if err := rep.UnitOk(ctx, CommonSourceUnit{step.unit}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
func (c *unitChunker) ChunkUnit(ctx context.Context, unit SourceUnit, rep ChunkReporter) error {
|
|
for _, step := range c.steps {
|
|
if unit.SourceUnitID() != step.unit {
|
|
continue
|
|
}
|
|
if step.err != "" {
|
|
if err := rep.ChunkErr(ctx, fmt.Errorf(step.err)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if step.output == "" {
|
|
continue
|
|
}
|
|
if err := rep.ChunkOk(ctx, Chunk{Data: []byte(step.output)}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestSourceManagerNonFatalError(t *testing.T) {
|
|
input := []unitChunk{
|
|
{unit: "one", output: "bar"},
|
|
{unit: "two", err: "oh no"},
|
|
{unit: "three", err: "not again"},
|
|
}
|
|
mgr := NewManager(WithBufferedOutput(8), WithSourceUnits())
|
|
source, err := buildDummy(&unitChunker{input})
|
|
assert.NoError(t, err)
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
<-ref.Done()
|
|
report := ref.Snapshot()
|
|
assert.Equal(t, len(input), int(report.TotalUnits))
|
|
assert.Equal(t, len(input), int(report.FinishedUnits))
|
|
assert.Equal(t, 1, int(report.TotalChunks))
|
|
assert.Equal(t, 2, len(report.Errors))
|
|
assert.True(t, report.DoneEnumerating)
|
|
}
|
|
|
|
func TestSourceManagerContextCancelled(t *testing.T) {
|
|
mgr := NewManager(WithBufferedOutput(8))
|
|
source, err := buildDummy(&counterChunker{count: 100})
|
|
assert.NoError(t, err)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
ref, err := mgr.Run(ctx, "dummy", source)
|
|
assert.NoError(t, err)
|
|
|
|
cancel()
|
|
<-ref.Done()
|
|
report := ref.Snapshot()
|
|
assert.Error(t, report.FatalError())
|
|
}
|
|
|
|
type DummyAPI struct {
|
|
registerSource func(context.Context, string, sourcespb.SourceType) (SourceID, error)
|
|
getJobID func(context.Context, SourceID) (JobID, error)
|
|
}
|
|
|
|
func (api DummyAPI) RegisterSource(ctx context.Context, name string, kind sourcespb.SourceType) (SourceID, error) {
|
|
return api.registerSource(ctx, name, kind)
|
|
}
|
|
|
|
func (api DummyAPI) GetJobID(ctx context.Context, id SourceID) (JobID, error) {
|
|
return api.getJobID(ctx, id)
|
|
}
|
|
|
|
// Chunk method that has a custom callback for the Chunks method.
|
|
type callbackChunker struct {
|
|
cb func(context.Context, chan *Chunk) error
|
|
}
|
|
|
|
func (c callbackChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
|
|
return c.cb(ctx, ch)
|
|
}
|
|
func (c callbackChunker) Enumerate(context.Context, UnitReporter) error { return nil }
|
|
func (c callbackChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return nil }
|
|
|
|
func TestSourceManagerCancelRun(t *testing.T) {
|
|
mgr := NewManager(WithBufferedOutput(8))
|
|
var returnedErr error
|
|
source, err := buildDummy(callbackChunker{func(ctx context.Context, _ chan *Chunk) error {
|
|
// The context passed to Chunks should get cancelled when ref.CancelRun() is called.
|
|
<-ctx.Done()
|
|
returnedErr = fmt.Errorf("oh no: %w", ctx.Err())
|
|
return returnedErr
|
|
}})
|
|
assert.NoError(t, err)
|
|
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
|
|
cancelErr := fmt.Errorf("abort! abort!")
|
|
ref.CancelRun(cancelErr)
|
|
<-ref.Done()
|
|
assert.Error(t, ref.Snapshot().FatalError())
|
|
assert.True(t, errors.Is(ref.Snapshot().FatalError(), returnedErr))
|
|
assert.True(t, errors.Is(ref.Snapshot().FatalErrors(), cancelErr))
|
|
}
|
|
|
|
func TestSourceManagerAvailableCapacity(t *testing.T) {
|
|
mgr := NewManager(WithConcurrentSources(1337))
|
|
start, end := make(chan struct{}), make(chan struct{})
|
|
source, err := buildDummy(callbackChunker{func(context.Context, chan *Chunk) error {
|
|
start <- struct{}{} // Send start signal.
|
|
<-end // Wait for end signal.
|
|
return nil
|
|
}})
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, 1337, mgr.AvailableCapacity())
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
|
|
<-start // Wait for start signal.
|
|
assert.Equal(t, 1336, mgr.AvailableCapacity())
|
|
end <- struct{}{} // Send end signal.
|
|
<-ref.Done() // Wait for the job to finish.
|
|
assert.Equal(t, 1337, mgr.AvailableCapacity())
|
|
}
|
|
|
|
func TestSourceManagerUnitHook(t *testing.T) {
|
|
hook := NewUnitHook(context.TODO())
|
|
|
|
input := []unitChunk{
|
|
{unit: "one", output: "bar"},
|
|
{unit: "two", err: "oh no"},
|
|
{unit: "three", err: "not again"},
|
|
}
|
|
mgr := NewManager(
|
|
WithBufferedOutput(8),
|
|
WithSourceUnits(), WithConcurrentUnits(1),
|
|
WithReportHook(hook),
|
|
)
|
|
source, err := buildDummy(&unitChunker{input})
|
|
assert.NoError(t, err)
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
<-ref.Done()
|
|
|
|
metrics := hook.UnitMetrics()
|
|
assert.Equal(t, 3, len(metrics))
|
|
sort.Slice(metrics, func(i, j int) bool {
|
|
return metrics[i].EndTime.Before(metrics[j].EndTime)
|
|
})
|
|
m0, m1, m2 := metrics[0], metrics[1], metrics[2]
|
|
|
|
assert.Equal(t, "one", m0.Unit.SourceUnitID())
|
|
assert.Equal(t, uint64(1), m0.TotalChunks)
|
|
assert.Equal(t, uint64(3), m0.TotalBytes)
|
|
assert.NotZero(t, m0.StartTime)
|
|
assert.NotZero(t, m0.EndTime)
|
|
assert.NotZero(t, m0.ElapsedTime())
|
|
assert.Equal(t, 0, len(m0.Errors))
|
|
|
|
assert.Equal(t, "two", m1.Unit.SourceUnitID())
|
|
assert.Equal(t, uint64(0), m1.TotalChunks)
|
|
assert.Equal(t, uint64(0), m1.TotalBytes)
|
|
assert.NotZero(t, m1.StartTime)
|
|
assert.NotZero(t, m1.EndTime)
|
|
assert.NotZero(t, m1.ElapsedTime())
|
|
assert.Equal(t, 1, len(m1.Errors))
|
|
|
|
assert.Equal(t, "three", m2.Unit.SourceUnitID())
|
|
assert.Equal(t, uint64(0), m2.TotalChunks)
|
|
assert.Equal(t, uint64(0), m2.TotalBytes)
|
|
assert.NotZero(t, m2.StartTime)
|
|
assert.NotZero(t, m2.EndTime)
|
|
assert.NotZero(t, m2.ElapsedTime())
|
|
assert.Equal(t, 1, len(m2.Errors))
|
|
}
|
|
|
|
// TestSourceManagerUnitHookNoBlock tests that the UnitHook drops metrics if
|
|
// they aren't handled fast enough.
|
|
func TestSourceManagerUnitHookNoBlock(t *testing.T) {
|
|
var evictedKeys []string
|
|
cache, _ := lru.NewWithEvict(1, func(key string, _ *UnitMetrics) {
|
|
evictedKeys = append(evictedKeys, key)
|
|
})
|
|
hook := NewUnitHook(context.TODO(), WithUnitHookCache(cache))
|
|
|
|
input := []unitChunk{
|
|
{unit: "one", output: "bar"},
|
|
{unit: "two", err: "oh no"},
|
|
{unit: "three", err: "not again"},
|
|
}
|
|
mgr := NewManager(
|
|
WithBufferedOutput(8),
|
|
WithSourceUnits(), WithConcurrentUnits(1),
|
|
WithReportHook(hook),
|
|
)
|
|
source, err := buildDummy(&unitChunker{input})
|
|
assert.NoError(t, err)
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
<-ref.Done()
|
|
|
|
assert.Equal(t, 2, len(evictedKeys))
|
|
metrics := hook.UnitMetrics()
|
|
assert.Equal(t, 1, len(metrics))
|
|
assert.Equal(t, "three", metrics[0].Unit.SourceUnitID())
|
|
}
|
|
|
|
// TestSourceManagerUnitHookNoUnits tests whether the UnitHook works for
|
|
// sources that don't support units.
|
|
func TestSourceManagerUnitHookNoUnits(t *testing.T) {
|
|
hook := NewUnitHook(context.TODO())
|
|
|
|
mgr := NewManager(
|
|
WithBufferedOutput(8),
|
|
WithReportHook(hook),
|
|
)
|
|
source, err := buildDummy(&counterChunker{count: 5})
|
|
assert.NoError(t, err)
|
|
|
|
ref, err := mgr.Run(context.Background(), "dummy", source)
|
|
assert.NoError(t, err)
|
|
<-ref.Done()
|
|
|
|
metrics := hook.UnitMetrics()
|
|
assert.Equal(t, 1, len(metrics))
|
|
|
|
m := metrics[0]
|
|
assert.Equal(t, nil, m.Unit)
|
|
assert.Equal(t, uint64(5), m.TotalChunks)
|
|
assert.Equal(t, uint64(5), m.TotalBytes)
|
|
assert.NotZero(t, m.StartTime)
|
|
assert.NotZero(t, m.EndTime)
|
|
assert.NotZero(t, m.ElapsedTime())
|
|
assert.Equal(t, 0, len(m.Errors))
|
|
}
|