mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-13 00:17:18 +00:00
add cancellation to s3 source
This commit is contained in:
parent
781157ae36
commit
518e112bf5
2 changed files with 86 additions and 9 deletions
|
@ -101,11 +101,12 @@ func (s *Source) Init(
|
|||
|
||||
func (s *Source) Validate(ctx context.Context) []error {
|
||||
var errs []error
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
|
||||
roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
|
||||
if len(roleErrs) > 0 {
|
||||
errs = append(errs, roleErrs...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.visitRoles(ctx, visitor); err != nil {
|
||||
|
@ -199,24 +200,51 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
|
|||
return bucketsToScan, nil
|
||||
}
|
||||
|
||||
// workerSignal provides thread-safe tracking of cancellation state across multiple
|
||||
// goroutines processing S3 bucket pages. It ensures graceful shutdown when the context
|
||||
// is cancelled during bucket scanning operations.
|
||||
//
|
||||
// This type serves several key purposes:
|
||||
// 1. AWS ListObjectsV2PagesWithContext requires a callback that can only return bool,
|
||||
// not error. workerSignal bridges this gap by providing a way to communicate
|
||||
// cancellation back to the caller.
|
||||
// 2. The pageChunker spawns multiple concurrent workers to process objects within
|
||||
// each page. workerSignal enables these workers to detect and respond to
|
||||
// cancellation signals.
|
||||
// 3. Ensures proper progress tracking by allowing the main scanning loop to detect
|
||||
// when workers have been cancelled and handle cleanup appropriately.
|
||||
type workerSignal struct{ cancelled atomic.Bool }
|
||||
|
||||
// newWorkerSignal creates a new workerSignal
|
||||
func newWorkerSignal() *workerSignal { return new(workerSignal) }
|
||||
|
||||
// MarkCancelled marks that a context cancellation was detected.
|
||||
func (ws *workerSignal) MarkCancelled() { ws.cancelled.Store(true) }
|
||||
|
||||
// WasCancelled returns true if context cancellation was detected.
|
||||
func (ws *workerSignal) WasCancelled() bool { return ws.cancelled.Load() }
|
||||
|
||||
func (s *Source) scanBuckets(
|
||||
ctx context.Context,
|
||||
client *s3.S3,
|
||||
role string,
|
||||
bucketsToScan []string,
|
||||
chunksChan chan *sources.Chunk,
|
||||
) {
|
||||
) error {
|
||||
var objectCount uint64
|
||||
|
||||
if role != "" {
|
||||
ctx = context.WithValue(ctx, "role", role)
|
||||
}
|
||||
|
||||
// Create worker signal to track cancellation across page processing.
|
||||
workerSignal := newWorkerSignal()
|
||||
|
||||
for i, bucket := range bucketsToScan {
|
||||
ctx := context.WithValue(ctx, "bucket", bucket)
|
||||
|
||||
if common.IsDone(ctx) {
|
||||
return
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
|
||||
|
@ -233,10 +261,15 @@ func (s *Source) scanBuckets(
|
|||
err = regionalClient.ListObjectsV2PagesWithContext(
|
||||
ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
|
||||
func(page *s3.ListObjectsV2Output, _ bool) bool {
|
||||
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
|
||||
return true
|
||||
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount, workerSignal)
|
||||
return !workerSignal.WasCancelled()
|
||||
})
|
||||
|
||||
// Check if we stopped due to cancellation.
|
||||
if workerSignal.WasCancelled() {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if role == "" {
|
||||
ctx.Logger().Error(err, "could not list objects in bucket")
|
||||
|
@ -255,12 +288,14 @@ func (s *Source) scanBuckets(
|
|||
fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount),
|
||||
"",
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chunks emits chunks of bytes over a channel.
|
||||
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
|
||||
s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
|
||||
return s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
|
||||
}
|
||||
|
||||
return s.visitRoles(ctx, visitor)
|
||||
|
@ -299,6 +334,7 @@ func (s *Source) pageChunker(
|
|||
errorCount *sync.Map,
|
||||
pageNumber int,
|
||||
objectCount *uint64,
|
||||
workerSignal *workerSignal,
|
||||
) {
|
||||
for _, obj := range page.Contents {
|
||||
if obj == nil {
|
||||
|
@ -314,6 +350,7 @@ func (s *Source) pageChunker(
|
|||
)
|
||||
|
||||
if common.IsDone(ctx) {
|
||||
workerSignal.MarkCancelled()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -343,6 +380,10 @@ func (s *Source) pageChunker(
|
|||
|
||||
s.jobPool.Go(func() error {
|
||||
defer common.RecoverWithExit(ctx)
|
||||
if common.IsDone(ctx) {
|
||||
workerSignal.MarkCancelled()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if strings.HasSuffix(*obj.Key, "/") {
|
||||
ctx.Logger().V(5).Info("Skipping directory")
|
||||
|
@ -488,7 +529,7 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
|
|||
// If no roles are configured, it will call the function with an empty role ARN.
|
||||
func (s *Source) visitRoles(
|
||||
ctx context.Context,
|
||||
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
|
||||
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error,
|
||||
) error {
|
||||
roles := s.conn.GetRoles()
|
||||
if len(roles) == 0 {
|
||||
|
@ -506,7 +547,9 @@ func (s *Source) visitRoles(
|
|||
return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
|
||||
}
|
||||
|
||||
f(ctx, client, role, bucketsToScan)
|
||||
if err := f(ctx, client, role, bucketsToScan); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
|
@ -216,3 +217,36 @@ func TestSource_Validate(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSourceCancellation tests that the source can be cancelled and that it does not complete
|
||||
// when the context is cancelled.
|
||||
func TestSourceCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
|
||||
defer cancel()
|
||||
|
||||
src := Source{}
|
||||
connection := &sourcespb.S3{
|
||||
Credential: &sourcespb.S3_Unauthenticated{},
|
||||
Buckets: []string{"trufflesec-ahrav-test"},
|
||||
}
|
||||
conn, err := anypb.New(connection)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = src.Init(ctx, "test name", 0, 0, false, conn, 1)
|
||||
chunksCh := make(chan *sources.Chunk)
|
||||
go func() {
|
||||
defer close(chunksCh)
|
||||
err = src.Chunks(ctx, chunksCh)
|
||||
assert.Error(t, err, "expected context.Cancelled error")
|
||||
}()
|
||||
|
||||
wantChunkCount := 9637
|
||||
got := 0
|
||||
for range chunksCh {
|
||||
got++
|
||||
}
|
||||
|
||||
assert.Less(t, int(src.PercentComplete), 100, "source should not have completed")
|
||||
assert.Less(t, got, wantChunkCount,
|
||||
"more chunks than expected were received")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue