diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go
index 27c9e9b4e..0f610586e 100644
--- a/pkg/sources/s3/s3.go
+++ b/pkg/sources/s3/s3.go
@@ -101,11 +101,12 @@ func (s *Source) Init(
 func (s *Source) Validate(ctx context.Context) []error {
 	var errs []error
 
-	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
+	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
 		roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
 		if len(roleErrs) > 0 {
 			errs = append(errs, roleErrs...)
 		}
+		return nil
 	}
 
 	if err := s.visitRoles(ctx, visitor); err != nil {
@@ -199,24 +200,51 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
 	return bucketsToScan, nil
 }
 
+// workerSignal provides thread-safe tracking of cancellation state across multiple
+// goroutines processing S3 bucket pages. It ensures graceful shutdown when the context
+// is cancelled during bucket scanning operations.
+//
+// This type serves several key purposes:
+//  1. AWS ListObjectsV2PagesWithContext requires a callback that can only return bool,
+//     not error. workerSignal bridges this gap by providing a way to communicate
+//     cancellation back to the caller.
+//  2. The pageChunker spawns multiple concurrent workers to process objects within
+//     each page. workerSignal enables these workers to detect and respond to
+//     cancellation signals.
+//  3. Ensures proper progress tracking by allowing the main scanning loop to detect
+//     when workers have been cancelled and handle cleanup appropriately.
+type workerSignal struct{ cancelled atomic.Bool }
+
+// newWorkerSignal creates a new workerSignal
+func newWorkerSignal() *workerSignal { return new(workerSignal) }
+
+// MarkCancelled marks that a context cancellation was detected.
+func (ws *workerSignal) MarkCancelled() { ws.cancelled.Store(true) }
+
+// WasCancelled returns true if context cancellation was detected.
+func (ws *workerSignal) WasCancelled() bool { return ws.cancelled.Load() }
+
 func (s *Source) scanBuckets(
 	ctx context.Context,
 	client *s3.S3,
 	role string,
 	bucketsToScan []string,
 	chunksChan chan *sources.Chunk,
-) {
+) error {
 	var objectCount uint64
 
 	if role != "" {
 		ctx = context.WithValue(ctx, "role", role)
 	}
 
+	// Create worker signal to track cancellation across page processing.
+	workerSignal := newWorkerSignal()
+
 	for i, bucket := range bucketsToScan {
 		ctx := context.WithValue(ctx, "bucket", bucket)
 
 		if common.IsDone(ctx) {
-			return
+			return ctx.Err()
 		}
 
 		s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
@@ -233,10 +261,15 @@ func (s *Source) scanBuckets(
 		err = regionalClient.ListObjectsV2PagesWithContext(
 			ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
 			func(page *s3.ListObjectsV2Output, _ bool) bool {
-				s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
-				return true
+				s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount, workerSignal)
+				return !workerSignal.WasCancelled()
 			})
 
+		// Check if we stopped due to cancellation.
+		if workerSignal.WasCancelled() {
+			return ctx.Err()
+		}
+
 		if err != nil {
 			if role == "" {
 				ctx.Logger().Error(err, "could not list objects in bucket")
@@ -255,12 +288,14 @@ func (s *Source) scanBuckets(
 		fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount),
 		"",
 	)
+
+	return nil
 }
 
 // Chunks emits chunks of bytes over a channel.
 func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
-	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
-		s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
+	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
+		return s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
 	}
 
 	return s.visitRoles(ctx, visitor)
@@ -299,6 +334,7 @@ func (s *Source) pageChunker(
 	errorCount *sync.Map,
 	pageNumber int,
 	objectCount *uint64,
+	workerSignal *workerSignal,
 ) {
 	for _, obj := range page.Contents {
 		if obj == nil {
@@ -314,6 +350,7 @@ func (s *Source) pageChunker(
 		)
 
 		if common.IsDone(ctx) {
+			workerSignal.MarkCancelled()
 			return
 		}
 
@@ -343,6 +380,10 @@ func (s *Source) pageChunker(
 		s.jobPool.Go(func() error {
 			defer common.RecoverWithExit(ctx)
 
+			if common.IsDone(ctx) {
+				workerSignal.MarkCancelled()
+				return ctx.Err()
+			}
 
 			if strings.HasSuffix(*obj.Key, "/") {
 				ctx.Logger().V(5).Info("Skipping directory")
@@ -488,7 +529,7 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
 // If no roles are configured, it will call the function with an empty role ARN.
 func (s *Source) visitRoles(
 	ctx context.Context,
-	f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
+	f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error,
 ) error {
 	roles := s.conn.GetRoles()
 	if len(roles) == 0 {
@@ -506,7 +547,9 @@ func (s *Source) visitRoles(
 			return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
 		}
 
-		f(ctx, client, role, bucketsToScan)
+		if err := f(ctx, client, role, bucketsToScan); err != nil {
+			return err
+		}
 	}
 
 	return nil
diff --git a/pkg/sources/s3/s3_integration_test.go b/pkg/sources/s3/s3_integration_test.go
index 1832eeb30..ad68f696d 100644
--- a/pkg/sources/s3/s3_integration_test.go
+++ b/pkg/sources/s3/s3_integration_test.go
@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"google.golang.org/protobuf/types/known/anypb"
 
 	"github.com/trufflesecurity/trufflehog/v3/pkg/common"
@@ -216,3 +217,36 @@ func TestSource_Validate(t *testing.T) {
 		})
 	}
 }
+
+// TestSourceCancellation verifies that scanning halts early and does not report
+// completion when the context is cancelled mid-scan.
+func TestSourceCancellation(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
+	defer cancel()
+
+	src := Source{}
+	connection := &sourcespb.S3{
+		Credential: &sourcespb.S3_Unauthenticated{},
+		Buckets:    []string{"trufflesec-ahrav-test"},
+	}
+	conn, err := anypb.New(connection)
+	require.NoError(t, err)
+
+	err = src.Init(ctx, "test name", 0, 0, false, conn, 1)
+	require.NoError(t, err)
+	chunksCh := make(chan *sources.Chunk)
+	go func() {
+		defer close(chunksCh)
+		err := src.Chunks(ctx, chunksCh)
+		assert.Error(t, err, "expected context cancellation error")
+	}()
+
+	wantChunkCount := 9637
+	got := 0
+	for range chunksCh {
+		got++
+	}
+
+	assert.Less(t, int(src.PercentComplete), 100, "source should not have completed")
+	assert.Less(t, got, wantChunkCount, "scan should have been cancelled before all chunks were received")
+}
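
Reviewer note: below is a minimal, self-contained sketch of the callback-to-error bridge this patch introduces, for anyone unfamiliar with the pattern. The fakeListPages helper and its callback signature are hypothetical stand-ins for s3.ListObjectsV2PagesWithContext; the real pageChunker also fans work out to a goroutine pool, which this sketch collapses into a single select for brevity.

package main

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"
)

// workerSignal mirrors the type added in this diff: an atomic flag that lets
// page-processing code report context cancellation through a pagination
// callback that can only return bool, not error.
type workerSignal struct{ cancelled atomic.Bool }

func (ws *workerSignal) MarkCancelled()     { ws.cancelled.Store(true) }
func (ws *workerSignal) WasCancelled() bool { return ws.cancelled.Load() }

// fakeListPages is a hypothetical stand-in for ListObjectsV2PagesWithContext:
// it invokes fn once per page and stops as soon as fn returns false.
func fakeListPages(pages []string, fn func(page string) bool) {
	for _, p := range pages {
		if !fn(p) {
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
	defer cancel()

	ws := new(workerSignal)
	fakeListPages([]string{"page-1", "page-2", "page-3"}, func(page string) bool {
		// Page processing notices cancellation and records it in the signal;
		// the callback then returns false so pagination halts immediately.
		select {
		case <-ctx.Done():
			ws.MarkCancelled()
		case <-time.After(20 * time.Millisecond): // simulate per-page work
			fmt.Println("processed", page)
		}
		return !ws.WasCancelled()
	})

	// Like scanBuckets after the page walk, convert the recorded signal back
	// into the context's error for the caller.
	if ws.WasCancelled() {
		fmt.Println("pagination stopped early:", ctx.Err())
	}
}

Running this prints "processed page-1" and then stops with context.DeadlineExceeded once the deadline passes, which is the same conversion scanBuckets performs when it returns ctx.Err() after a cancelled page walk.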