add cancellation to s3 source

This commit is contained in:
Ahrav Dutta 2024-11-08 09:57:59 -08:00
parent 781157ae36
commit 518e112bf5
2 changed files with 86 additions and 9 deletions

View file

@ -101,11 +101,12 @@ func (s *Source) Init(
func (s *Source) Validate(ctx context.Context) []error {
var errs []error
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
if len(roleErrs) > 0 {
errs = append(errs, roleErrs...)
}
return nil
}
if err := s.visitRoles(ctx, visitor); err != nil {
@ -199,24 +200,51 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
return bucketsToScan, nil
}
// workerSignal provides thread-safe tracking of cancellation state across multiple
// goroutines processing S3 bucket pages. It ensures graceful shutdown when the context
// is cancelled during bucket scanning operations.
//
// This type serves several key purposes:
// 1. AWS ListObjectsV2PagesWithContext requires a callback that can only return bool,
// not error. workerSignal bridges this gap by providing a way to communicate
// cancellation back to the caller.
// 2. The pageChunker spawns multiple concurrent workers to process objects within
// each page. workerSignal enables these workers to detect and respond to
// cancellation signals.
// 3. Ensures proper progress tracking by allowing the main scanning loop to detect
// when workers have been cancelled and handle cleanup appropriately.
type workerSignal struct{ cancelled atomic.Bool }
// newWorkerSignal creates a new workerSignal
func newWorkerSignal() *workerSignal { return new(workerSignal) }
// MarkCancelled marks that a context cancellation was detected.
func (ws *workerSignal) MarkCancelled() { ws.cancelled.Store(true) }
// WasCancelled returns true if context cancellation was detected.
func (ws *workerSignal) WasCancelled() bool { return ws.cancelled.Load() }
func (s *Source) scanBuckets(
ctx context.Context,
client *s3.S3,
role string,
bucketsToScan []string,
chunksChan chan *sources.Chunk,
) {
) error {
var objectCount uint64
if role != "" {
ctx = context.WithValue(ctx, "role", role)
}
// Create worker signal to track cancellation across page processing.
workerSignal := newWorkerSignal()
for i, bucket := range bucketsToScan {
ctx := context.WithValue(ctx, "bucket", bucket)
if common.IsDone(ctx) {
return
return ctx.Err()
}
s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
@ -233,10 +261,15 @@ func (s *Source) scanBuckets(
err = regionalClient.ListObjectsV2PagesWithContext(
ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
func(page *s3.ListObjectsV2Output, _ bool) bool {
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
return true
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount, workerSignal)
return !workerSignal.WasCancelled()
})
// Check if we stopped due to cancellation.
if workerSignal.WasCancelled() {
return ctx.Err()
}
if err != nil {
if role == "" {
ctx.Logger().Error(err, "could not list objects in bucket")
@ -255,12 +288,14 @@ func (s *Source) scanBuckets(
fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount),
"",
)
return nil
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
return s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
}
return s.visitRoles(ctx, visitor)
@ -299,6 +334,7 @@ func (s *Source) pageChunker(
errorCount *sync.Map,
pageNumber int,
objectCount *uint64,
workerSignal *workerSignal,
) {
for _, obj := range page.Contents {
if obj == nil {
@ -314,6 +350,7 @@ func (s *Source) pageChunker(
)
if common.IsDone(ctx) {
workerSignal.MarkCancelled()
return
}
@ -343,6 +380,10 @@ func (s *Source) pageChunker(
s.jobPool.Go(func() error {
defer common.RecoverWithExit(ctx)
if common.IsDone(ctx) {
workerSignal.MarkCancelled()
return ctx.Err()
}
if strings.HasSuffix(*obj.Key, "/") {
ctx.Logger().V(5).Info("Skipping directory")
@ -488,7 +529,7 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
// If no roles are configured, it will call the function with an empty role ARN.
func (s *Source) visitRoles(
ctx context.Context,
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error,
) error {
roles := s.conn.GetRoles()
if len(roles) == 0 {
@ -506,7 +547,9 @@ func (s *Source) visitRoles(
return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
}
f(ctx, client, role, bucketsToScan)
if err := f(ctx, client, role, bucketsToScan); err != nil {
return err
}
}
return nil

View file

@ -10,6 +10,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
@ -216,3 +217,36 @@ func TestSource_Validate(t *testing.T) {
})
}
}
// TestSourceCancellation tests that the source can be cancelled and that it does not complete
// when the context is cancelled.
func TestSourceCancellation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
defer cancel()
src := Source{}
connection := &sourcespb.S3{
Credential: &sourcespb.S3_Unauthenticated{},
Buckets: []string{"trufflesec-ahrav-test"},
}
conn, err := anypb.New(connection)
require.NoError(t, err)
err = src.Init(ctx, "test name", 0, 0, false, conn, 1)
chunksCh := make(chan *sources.Chunk)
go func() {
defer close(chunksCh)
err = src.Chunks(ctx, chunksCh)
assert.Error(t, err, "expected context.Cancelled error")
}()
wantChunkCount := 9637
got := 0
for range chunksCh {
got++
}
assert.Less(t, int(src.PercentComplete), 100, "source should not have completed")
assert.Less(t, got, wantChunkCount,
"more chunks than expected were received")
}