This commit is contained in:
Ahrav Dutta 2024-11-07 12:41:36 -08:00
parent 334a4f9ef2
commit dd206643fe
5 changed files with 487 additions and 466 deletions

View file

@ -21,9 +21,7 @@ type ProgressTracker struct {
// completedObjects tracks which indices in the current page have been processed.
sync.Mutex
completedObjects []bool
baseCompleted int32 // Track completed count from previous pages
currentPageSize int32 // Track the current page size to avoid double counting
completionOrder []int // Track the order in which objects complete
// progress holds the scan's overall progress state and enables persistence.
progress *sources.Progress // Reference to source's Progress
@ -34,17 +32,18 @@ const defaultMaxObjectsPerPage = 1000
// NewProgressTracker creates a new progress tracker for S3 scanning operations.
// The enabled parameter determines if progress tracking is active, and progress
// provides the underlying mechanism for persisting scan state.
func NewProgressTracker(ctx context.Context, enabled bool, progress *sources.Progress) *ProgressTracker {
func NewProgressTracker(_ context.Context, enabled bool, progress *sources.Progress) (*ProgressTracker, error) {
if progress == nil {
ctx.Logger().Info("Nil progress provided. Progress initialized.")
progress = new(sources.Progress)
return nil, errors.New("Nil progress provided; progress is required for tracking")
}
return &ProgressTracker{
// We are resuming if we have completed objects from a previous scan.
completedObjects: make([]bool, defaultMaxObjectsPerPage),
completionOrder: make([]int, 0, defaultMaxObjectsPerPage),
enabled: enabled,
progress: progress,
}
}, nil
}
// Reset prepares the tracker for a new page of objects by clearing the completion state.
@ -56,9 +55,8 @@ func (p *ProgressTracker) Reset(_ context.Context) {
p.Lock()
defer p.Unlock()
// Store the current completed count before moving to next page.
p.baseCompleted = p.progress.SectionsCompleted
p.currentPageSize = 0
p.completedObjects = make([]bool, defaultMaxObjectsPerPage)
p.completionOrder = make([]int, 0, defaultMaxObjectsPerPage)
}
// ResumeInfo represents the state needed to resume an interrupted operation.
@ -93,19 +91,12 @@ func (p *ProgressTracker) GetResumePoint(ctx context.Context) (ResumeInfo, error
return resume, nil
}
return ResumeInfo{
CurrentBucket: resumeInfo.CurrentBucket,
StartAfter: resumeInfo.StartAfter,
}, nil
return ResumeInfo{CurrentBucket: resumeInfo.CurrentBucket, StartAfter: resumeInfo.StartAfter}, nil
}
// Complete marks the entire scanning operation as finished and clears the resume state.
// This should only be called once all scanning operations are complete.
func (p *ProgressTracker) Complete(_ context.Context, message string) error {
if !p.enabled {
return nil
}
// Preserve existing progress counters while clearing resume state.
p.progress.SetProgressComplete(
int(p.progress.SectionsCompleted),
@ -126,8 +117,7 @@ func (p *ProgressTracker) Complete(_ context.Context, message string) error {
//
// This approach ensures scan reliability by only checkpointing consecutively completed
// objects. While this may result in re-scanning some objects when resuming, it guarantees
// no objects are missed in case of interruption. The linear search through page contents
// is efficient given the fixed maximum page size of 1000 objects.
// no objects are missed in case of interruption.
func (p *ProgressTracker) UpdateObjectProgress(
ctx context.Context,
completedIdx int,
@ -148,43 +138,47 @@ func (p *ProgressTracker) UpdateObjectProgress(
p.Lock()
defer p.Unlock()
// Update remaining count only once per page.
pageSize := int32(len(pageContents))
if p.currentPageSize == 0 {
p.progress.SectionsRemaining += pageSize
p.currentPageSize = pageSize
// Only track completion if this is the first time this index is marked complete.
if !p.completedObjects[completedIdx] {
p.completedObjects[completedIdx] = true
p.completionOrder = append(p.completionOrder, completedIdx)
}
p.completedObjects[completedIdx] = true
// Find the highest safe checkpoint we can create.
lastSafeIdx := -1
var safeIndices [defaultMaxObjectsPerPage]bool
// Mark all completed indices.
for _, idx := range p.completionOrder {
safeIndices[idx] = true
}
// Find the highest consecutive completed index.
lastConsecutiveIdx := -1
for i := 0; i <= completedIdx; i++ {
if !p.completedObjects[i] {
for i := range len(p.completedObjects) {
if !safeIndices[i] {
break
}
lastConsecutiveIdx = i
lastSafeIdx = i
}
// Update progress if we have at least one completed object.
if lastConsecutiveIdx < 0 {
if lastSafeIdx < 0 {
return nil
}
obj := pageContents[lastConsecutiveIdx]
obj := pageContents[lastSafeIdx]
info := &ResumeInfo{CurrentBucket: bucket, StartAfter: *obj.Key}
encoded, err := json.Marshal(info)
if err != nil {
return err
}
// Set the total completed as base (from previous pages) plus consecutive completions in this page.
completedCount := p.baseCompleted + int32(lastConsecutiveIdx+1)
// Purposefully avoid updating any progress counts.
// Only update resume info.
p.progress.SetProgressComplete(
int(completedCount),
int(p.progress.SectionsCompleted),
int(p.progress.SectionsRemaining),
fmt.Sprintf("Processing: %s/%s", bucket, *obj.Key),
p.progress.Message,
string(encoded),
)
return nil
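For illustration, a minimal standalone sketch of the checkpoint rule described above, not part of this change: only the leading run of consecutively completed objects is safe to checkpoint, so a single gap caps the resume point.

package main

import "fmt"

// highestSafeIdx returns the last index of the leading run of completed
// objects, or -1 if the first object is still pending. Objects after the
// first gap are ignored, so resuming from the returned index never skips
// an unscanned object.
func highestSafeIdx(completed []bool) int {
	lastSafe := -1
	for i, done := range completed {
		if !done {
			break
		}
		lastSafe = i
	}
	return lastSafe
}

func main() {
	// Indices 0-2 and 4 finished; index 3 is still in flight.
	completed := []bool{true, true, true, false, true}
	fmt.Println(highestSafeIdx(completed)) // prints 2, so resume after "key-2"
}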

View file

@ -7,11 +7,61 @@ import (
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)
func TestProgressTrackerResumption(t *testing.T) {
ctx := context.Background()
// First scan - process 6 objects then interrupt.
initialProgress := &sources.Progress{}
tracker, err := NewProgressTracker(ctx, true, initialProgress)
require.NoError(t, err)
firstPage := &s3.ListObjectsV2Output{
Contents: make([]*s3.Object, 12), // Total of 12 objects
}
for i := range 12 {
key := fmt.Sprintf("key-%d", i)
firstPage.Contents[i] = &s3.Object{Key: &key}
}
// Process first 6 objects.
for i := range 6 {
err := tracker.UpdateObjectProgress(ctx, i, "test-bucket", firstPage.Contents)
assert.NoError(t, err)
}
// Verify resume info is set correctly.
resumeInfo, err := tracker.GetResumePoint(ctx)
require.NoError(t, err)
assert.Equal(t, "test-bucket", resumeInfo.CurrentBucket)
assert.Equal(t, "key-5", resumeInfo.StartAfter)
// Resume scan with existing progress.
resumeTracker, err := NewProgressTracker(ctx, true, initialProgress)
require.NoError(t, err)
resumePage := &s3.ListObjectsV2Output{
Contents: firstPage.Contents[6:], // Remaining 6 objects
}
// Process remaining objects.
for i := range len(resumePage.Contents) {
err := resumeTracker.UpdateObjectProgress(ctx, i, "test-bucket", resumePage.Contents)
assert.NoError(t, err)
}
// Verify final resume info.
finalResumeInfo, err := resumeTracker.GetResumePoint(ctx)
require.NoError(t, err)
assert.Equal(t, "test-bucket", finalResumeInfo.CurrentBucket)
assert.Equal(t, "key-11", finalResumeInfo.StartAfter)
}
func TestProgressTrackerReset(t *testing.T) {
tests := []struct {
name string
@ -19,7 +69,6 @@ func TestProgressTrackerReset(t *testing.T) {
}{
{name: "reset with enabled tracker", enabled: true},
{name: "reset with disabled tracker", enabled: false},
{name: "reset with zero capacity", enabled: true},
}
for _, tt := range tests {
@ -28,19 +77,28 @@ func TestProgressTrackerReset(t *testing.T) {
ctx := context.Background()
progress := new(sources.Progress)
tracker := NewProgressTracker(ctx, tt.enabled, progress)
tracker, err := NewProgressTracker(ctx, tt.enabled, progress)
require.NoError(t, err)
tracker.completedObjects[1] = true
tracker.completedObjects[2] = true
tracker.Reset(ctx)
if !tt.enabled {
assert.Equal(t, defaultMaxObjectsPerPage, len(tracker.completedObjects), "Reset did not clear completed objects")
return
}
assert.Equal(t, defaultMaxObjectsPerPage, len(tracker.completedObjects),
"Reset changed the length of completed objects")
assert.Equal(t, 0, len(tracker.completedObjects), "Reset did not clear completed objects")
if tt.enabled {
// All values should be false after reset.
for i, isCompleted := range tracker.completedObjects {
assert.False(t, isCompleted,
"Reset did not clear completed object at index %d", i)
}
// Completion order should be empty.
assert.Equal(t, 0, len(tracker.completionOrder),
"Reset did not clear completion order")
}
})
}
}
@ -116,239 +174,51 @@ func TestGetResumePoint(t *testing.T) {
}
}
func setupTestTracker(t *testing.T, enabled bool, progress *sources.Progress, pageSize int) (*ProgressTracker, *s3.ListObjectsV2Output) {
t.Helper()
tracker := NewProgressTracker(context.Background(), enabled, progress)
page := &s3.ListObjectsV2Output{Contents: make([]*s3.Object, pageSize)}
for i := range pageSize {
key := fmt.Sprintf("key-%d", i)
page.Contents[i] = &s3.Object{Key: &key}
}
return tracker, page
}
func TestProgressTrackerUpdateProgressDisabled(t *testing.T) {
t.Parallel()
progress := new(sources.Progress)
tracker, page := setupTestTracker(t, false, progress, 5)
err := tracker.UpdateObjectProgress(context.Background(), 1, "test-bucket", page.Contents)
assert.NoError(t, err, "Error updating progress when tracker disabled")
assert.Empty(t, progress.EncodedResumeInfo, "Progress updated when tracker disabled")
}
func TestProgressTrackerUpdateProgressCompletedIdxOOR(t *testing.T) {
t.Parallel()
progress := new(sources.Progress)
tracker, page := setupTestTracker(t, true, progress, 5)
err := tracker.UpdateObjectProgress(context.Background(), 1001, "test-bucket", page.Contents)
assert.Error(t, err, "Expected error when completedIdx out of range")
assert.Empty(t, progress.EncodedResumeInfo, "Progress updated when tracker disabled")
}
func TestProgressTrackerSequence(t *testing.T) {
tests := []struct {
name string
description string
// Each update is a sequence of {completedIdx, expectedCompleted, expectedRemaining}
updates [][3]int
pageSize int
}{
{
name: "multiple updates same page",
description: "Verify remaining count isn't doubled and completed accumulates correctly",
pageSize: 5,
updates: [][3]int{
{0, 1, 5}, // First object - should set remaining to 5
{1, 2, 5}, // Second object - remaining should stay 5
{2, 3, 5}, // Third object - remaining should stay 5
{4, 3, 5}, // Gap at index 3 - completed should stay at 3
},
},
{
name: "across page boundaries",
description: "Verify completed count accumulates across pages",
pageSize: 3,
updates: [][3]int{
// First page
{0, 1, 3},
{1, 2, 3},
{2, 3, 3},
// Reset and start new page.
{0, 4, 6}, // baseCompleted(3) + current(1)
{1, 5, 6}, // baseCompleted(3) + current(2)
{2, 6, 6}, // baseCompleted(3) + current(3)
},
},
{
name: "incomplete page transition",
description: "Verify incomplete page properly sets base completed",
pageSize: 4,
updates: [][3]int{
// First page - only complete first 2.
{0, 1, 4},
{1, 2, 4},
// Skip 2,3 and move to next page.
// Reset and start new page.
{0, 3, 8}, // baseCompleted(2) + current(1)
{1, 4, 8}, // baseCompleted(2) + current(2)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
progress := new(sources.Progress)
tracker, page := setupTestTracker(t, true, progress, tt.pageSize)
pageCount := 0
for i, update := range tt.updates {
completedIdx, expectedCompleted, expectedRemaining := update[0], update[1], update[2]
// If this update starts a new page.
if completedIdx == 0 && i > 0 {
pageCount++
tracker.Reset(ctx)
// Create new page with same size.
page = &s3.ListObjectsV2Output{Contents: make([]*s3.Object, tt.pageSize)}
for j := range tt.pageSize {
key := fmt.Sprintf("page%d-key-%d", pageCount, j)
page.Contents[j] = &s3.Object{Key: &key}
}
}
err := tracker.UpdateObjectProgress(ctx, completedIdx, "test-bucket", page.Contents)
assert.NoError(t, err, "Unexpected error updating progress")
assert.Equal(t, expectedCompleted, int(progress.SectionsCompleted),
"Incorrect completed count at update %d", i)
assert.Equal(t, expectedRemaining, int(progress.SectionsRemaining),
"Incorrect remaining count at update %d", i)
}
})
}
}
func TestProgressTrackerUpdateProgressWithResume(t *testing.T) {
func TestProgressTrackerUpdateProgress(t *testing.T) {
tests := []struct {
name string
description string
completedIdx int
pageSize int
preCompleted map[int]bool
expectedKey string
expectedCompleted int
expectedRemaining int
preCompleted []int
expectedKey string
}{
{
name: "first object completed",
description: "Basic case - completing first object",
completedIdx: 0,
pageSize: 3,
expectedKey: "key-0",
expectedCompleted: 1, // Only first object completed
expectedRemaining: 3, // Total objects in page
name: "first object completed",
description: "Basic case - completing first object",
completedIdx: 0,
pageSize: 3,
expectedKey: "key-0",
},
{
name: "completing missing middle",
description: "Completing object when previous is done",
completedIdx: 1,
pageSize: 3,
preCompleted: map[int]bool{0: true},
expectedKey: "key-1",
expectedCompleted: 2, // First two objects completed
expectedRemaining: 3, // Total objects in page
name: "completing missing middle",
description: "Completing object when previous is done",
completedIdx: 1,
pageSize: 3,
preCompleted: []int{0},
expectedKey: "key-1",
},
{
name: "completing first with last done",
description: "Completing first object when last is already done",
completedIdx: 0,
pageSize: 3,
preCompleted: map[int]bool{2: true},
expectedKey: "key-0",
expectedCompleted: 1, // Only first object counts due to gap
expectedRemaining: 3, // Total objects in page
},
{
name: "all objects completed in order",
description: "Completing final object in sequence",
completedIdx: 2,
pageSize: 3,
preCompleted: map[int]bool{0: true, 1: true},
expectedKey: "key-2",
expectedCompleted: 3, // All objects completed
expectedRemaining: 3, // Total objects in page
},
{
name: "completing middle gaps",
description: "Completing object with gaps in sequence",
completedIdx: 5,
pageSize: 10,
preCompleted: map[int]bool{0: true, 1: true, 2: true, 4: true},
expectedKey: "key-2", // Last consecutive completed
expectedCompleted: 3, // Only first 3 count due to gap
expectedRemaining: 10, // Total objects in page
},
{
name: "zero index with empty pre-completed",
description: "Edge case - minimum valid index",
completedIdx: 0,
pageSize: 1,
expectedKey: "key-0",
expectedCompleted: 1,
expectedRemaining: 1,
name: "all objects completed in order",
description: "Completing final object in sequence",
completedIdx: 2,
pageSize: 3,
preCompleted: []int{0, 1},
expectedKey: "key-2",
},
{
name: "last index in max page",
description: "Edge case - maximum page size boundary",
completedIdx: 999,
pageSize: 1000,
preCompleted: func() map[int]bool {
m := make(map[int]bool)
preCompleted: func() []int {
indices := make([]int, 999)
for i := range 999 {
m[i] = true
indices[i] = i
}
return m
return indices
}(),
expectedKey: "key-999",
expectedCompleted: 1000, // All objects completed
expectedRemaining: 1000, // Total objects in page
},
{
name: "all previous completed",
description: "Edge case - all previous indices completed",
completedIdx: 100,
pageSize: 101,
preCompleted: func() map[int]bool {
m := make(map[int]bool)
for i := range 100 {
m[i] = true
}
return m
}(),
expectedKey: "key-100",
expectedCompleted: 101, // All objects completed
expectedRemaining: 101, // Total objects in page
},
{
name: "large page number completion",
description: "Edge case - very large page number",
completedIdx: 5,
pageSize: 10,
preCompleted: map[int]bool{0: true, 1: true, 2: true, 3: true, 4: true},
expectedKey: "key-5",
expectedCompleted: 6, // First 6 objects completed
expectedRemaining: 10, // Total objects in page
expectedKey: "key-999",
},
}
@ -358,108 +228,34 @@ func TestProgressTrackerUpdateProgressWithResume(t *testing.T) {
ctx := context.Background()
progress := new(sources.Progress)
tracker, page := setupTestTracker(t, true, progress, tt.pageSize)
tracker := &ProgressTracker{
enabled: true,
progress: progress,
completedObjects: make([]bool, tt.pageSize),
completionOrder: make([]int, 0, tt.pageSize),
}
page := &s3.ListObjectsV2Output{Contents: make([]*s3.Object, tt.pageSize)}
for i := range tt.pageSize {
key := fmt.Sprintf("key-%d", i)
page.Contents[i] = &s3.Object{Key: &key}
}
// Apply pre-completed indices in order.
if tt.preCompleted != nil {
for k, v := range tt.preCompleted {
tracker.completedObjects[k] = v
for _, idx := range tt.preCompleted {
tracker.completedObjects[idx] = true
tracker.completionOrder = append(tracker.completionOrder, idx)
}
}
err := tracker.UpdateObjectProgress(ctx, tt.completedIdx, "test-bucket", page.Contents)
assert.NoError(t, err, "Unexpected error updating progress")
// Verify resume info.
assert.NotEmpty(t, progress.EncodedResumeInfo, "Expected progress update")
var info ResumeInfo
err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &info)
assert.NoError(t, err, "Failed to decode resume info")
assert.Equal(t, tt.expectedKey, info.StartAfter, "Incorrect resume point")
// Verify progress counts.
assert.Equal(t, tt.expectedCompleted, int(progress.SectionsCompleted),
"Incorrect completed count")
assert.Equal(t, tt.expectedRemaining, int(progress.SectionsRemaining),
"Incorrect remaining count")
})
}
}
func TestProgressTrackerUpdateProgressNoResume(t *testing.T) {
tests := []struct {
name string
description string
completedIdx int
pageSize int
preCompleted map[int]bool
}{
{
name: "middle object completed first",
description: "Basic case - completing middle object first",
completedIdx: 1,
pageSize: 3,
},
{
name: "last object completed first",
description: "Basic case - completing last object first",
completedIdx: 2,
pageSize: 3,
},
{
name: "multiple gaps",
description: "Multiple non-consecutive completions",
completedIdx: 5,
pageSize: 10,
preCompleted: map[int]bool{1: true, 3: true, 4: true},
},
{
name: "alternating completion pattern",
description: "Edge case - alternating completed/uncompleted pattern",
completedIdx: 10,
pageSize: 20,
preCompleted: map[int]bool{2: true, 4: true, 6: true, 8: true},
},
{
name: "sparse completion pattern",
description: "Edge case - scattered completions with regular gaps",
completedIdx: 50,
pageSize: 100,
preCompleted: map[int]bool{10: true, 20: true, 30: true, 40: true},
},
{
name: "single gap breaks sequence",
description: "Edge case - single gap prevents resume info",
completedIdx: 50,
pageSize: 100,
preCompleted: func() map[int]bool {
m := make(map[int]bool)
for i := 1; i <= 49; i++ {
m[i] = true
}
m[49] = false
return m
}(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
progress := new(sources.Progress)
enabled := tt.name != "disabled tracker"
tracker, page := setupTestTracker(t, enabled, progress, tt.pageSize)
if tt.preCompleted != nil {
for k, v := range tt.preCompleted {
tracker.completedObjects[k] = v
}
}
err := tracker.UpdateObjectProgress(ctx, tt.completedIdx, "test-bucket", page.Contents)
assert.NoError(t, err, "Unexpected error updating progress")
assert.Empty(t, progress.EncodedResumeInfo, "Expected no progress update")
})
}
}
@ -469,98 +265,51 @@ func TestComplete(t *testing.T) {
name string
enabled bool
initialState struct {
sectionsCompleted uint64
sectionsRemaining uint64
resumeInfo string
message string
resumeInfo string
message string
}
completeMessage string
wantState struct {
sectionsCompleted uint64
sectionsRemaining uint64
resumeInfo string
message string
resumeInfo string
message string
}
}{
{
name: "marks completion with existing progress",
name: "marks completion with existing resume info",
enabled: true,
initialState: struct {
sectionsCompleted uint64
sectionsRemaining uint64
resumeInfo string
message string
resumeInfo string
message string
}{
sectionsCompleted: 100,
sectionsRemaining: 100,
resumeInfo: `{"CurrentBucket":"test-bucket","StartAfter":"some-key"}`,
message: "In progress",
resumeInfo: `{"current_bucket":"test-bucket","start_after":"some-key"}`,
message: "In progress",
},
completeMessage: "Scan complete",
wantState: struct {
sectionsCompleted uint64
sectionsRemaining uint64
resumeInfo string
message string
resumeInfo string
message string
}{
sectionsCompleted: 100, // Should preserve existing progress
sectionsRemaining: 100, // Should preserve existing progress
resumeInfo: "", // Should clear resume info
message: "Scan complete",
resumeInfo: "", // Should clear resume info
message: "Scan complete",
},
},
{
name: "disabled tracker",
enabled: false,
initialState: struct {
sectionsCompleted uint64
sectionsRemaining uint64
resumeInfo string
message string
resumeInfo string
message string
}{
sectionsCompleted: 50,
sectionsRemaining: 100,
resumeInfo: `{"CurrentBucket":"test-bucket","StartAfter":"some-key"}`,
message: "Should not change",
resumeInfo: "",
message: "Should not change",
},
completeMessage: "Completed",
wantState: struct {
sectionsCompleted uint64
sectionsRemaining uint64
resumeInfo string
message string
resumeInfo string
message string
}{
sectionsCompleted: 50,
sectionsRemaining: 100,
resumeInfo: `{"CurrentBucket":"test-bucket","StartAfter":"some-key"}`,
message: "Should not change",
},
},
{
name: "completes with special characters",
enabled: true,
initialState: struct {
sectionsCompleted uint64
sectionsRemaining uint64
resumeInfo string
message string
}{
sectionsCompleted: 75,
sectionsRemaining: 75,
resumeInfo: `{"CurrentBucket":"bucket","StartAfter":"key"}`,
message: "In progress",
},
completeMessage: "Completed scanning 特殊字符 & symbols !@#$%",
wantState: struct {
sectionsCompleted uint64
sectionsRemaining uint64
resumeInfo string
message string
}{
sectionsCompleted: 75,
sectionsRemaining: 75,
resumeInfo: "",
message: "Completed scanning 特殊字符 & symbols !@#$%",
resumeInfo: "",
message: "Completed",
},
},
}
@ -571,18 +320,15 @@ func TestComplete(t *testing.T) {
ctx := context.Background()
progress := &sources.Progress{
SectionsCompleted: int32(tt.initialState.sectionsCompleted),
SectionsRemaining: int32(tt.initialState.sectionsRemaining),
EncodedResumeInfo: tt.initialState.resumeInfo,
Message: tt.initialState.message,
}
tracker := NewProgressTracker(ctx, tt.enabled, progress)
err := tracker.Complete(ctx, tt.completeMessage)
tracker, err := NewProgressTracker(ctx, tt.enabled, progress)
assert.NoError(t, err)
err = tracker.Complete(ctx, tt.completeMessage)
assert.NoError(t, err)
assert.Equal(t, int32(tt.wantState.sectionsCompleted), progress.SectionsCompleted)
assert.Equal(t, int32(tt.wantState.sectionsRemaining), progress.SectionsRemaining)
assert.Equal(t, tt.wantState.resumeInfo, progress.EncodedResumeInfo)
assert.Equal(t, tt.wantState.message, progress.Message)
})

View file

@ -2,6 +2,7 @@ package s3
import (
"fmt"
"slices"
"strings"
"sync"
"sync/atomic"
@ -43,8 +44,10 @@ type Source struct {
jobID sources.JobID
verify bool
concurrency int
conn *sourcespb.S3
progressTracker *ProgressTracker
sources.Progress
conn *sourcespb.S3
errorCount *sync.Map
jobPool *errgroup.Group
@ -67,7 +70,7 @@ func (s *Source) JobID() sources.JobID { return s.jobID }
// Init returns an initialized AWS source
func (s *Source) Init(
_ context.Context,
ctx context.Context,
name string,
jobID sources.JobID,
sourceID sources.SourceID,
@ -90,6 +93,12 @@ func (s *Source) Init(
}
s.conn = &conn
var err error
s.progressTracker, err = NewProgressTracker(ctx, conn.GetEnableResumption(), &s.Progress)
if err != nil {
return err
}
s.setMaxObjectSize(conn.GetMaxObjectSize())
if len(conn.GetBuckets()) > 0 && len(conn.GetIgnoreBuckets()) > 0 {
@ -101,11 +110,12 @@ func (s *Source) Init(
func (s *Source) Validate(ctx context.Context) []error {
var errs []error
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
if len(roleErrs) > 0 {
errs = append(errs, roleErrs...)
}
return nil
}
if err := s.visitRoles(ctx, visitor); err != nil {
@ -173,9 +183,16 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) {
return s3.New(sess), nil
}
// IAM identity needs s3:ListBuckets permission
// getBucketsToScan returns a list of S3 buckets to scan.
// If the connection has a list of buckets specified, those are returned.
// Otherwise, it lists all buckets the client has access to and filters out the ignored ones.
// The list of buckets is sorted lexicographically to ensure consistent ordering,
// which allows resuming scanning from the same place if the scan is interrupted.
//
// Note: The IAM identity needs the s3:ListBuckets permission.
func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
if buckets := s.conn.GetBuckets(); len(buckets) > 0 {
slices.Sort(buckets)
return buckets, nil
}
@ -196,32 +213,91 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
bucketsToScan = append(bucketsToScan, name)
}
}
slices.Sort(bucketsToScan)
return bucketsToScan, nil
}
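To show why the lexicographic sort matters for resumption, here is a small hypothetical example, separate from the diff: with a sorted bucket list, the bucket stored in the resume info can be located by binary search and scanning restarts from that index.

package main

import (
	"fmt"
	"slices"
)

func main() {
	buckets := []string{"bravo", "delta", "alpha", "charlie"}
	slices.Sort(buckets) // consistent ordering across scans

	// Suppose the resume info recorded "charlie" as the bucket in progress.
	startIdx, _ := slices.BinarySearch(buckets, "charlie")
	for _, b := range buckets[startIdx:] {
		fmt.Println("scan", b) // scans charlie, then delta
	}
}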
// workerSignal provides thread-safe tracking of cancellation state across multiple
// goroutines processing S3 bucket pages. It ensures graceful shutdown when the context
// is cancelled during bucket scanning operations.
//
// This type serves several key purposes:
// 1. AWS ListObjectsV2PagesWithContext requires a callback that can only return bool,
// not error. workerSignal bridges this gap by providing a way to communicate
// cancellation back to the caller.
// 2. The pageChunker spawns multiple concurrent workers to process objects within
// each page. workerSignal enables these workers to detect and respond to
// cancellation signals.
// 3. Ensures proper progress tracking by allowing the main scanning loop to detect
// when workers have been cancelled and handle cleanup appropriately.
type workerSignal struct{ cancelled atomic.Bool }
// newWorkerSignal creates a new workerSignal
func newWorkerSignal() *workerSignal { return new(workerSignal) }
// MarkCancelled marks that a context cancellation was detected.
func (ws *workerSignal) MarkCancelled() { ws.cancelled.Store(true) }
// WasCancelled returns true if context cancellation was detected.
func (ws *workerSignal) WasCancelled() bool { return ws.cancelled.Load() }
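A condensed sketch of the pattern described above, assuming the aws-sdk-go v1 types used elsewhere in this file; the listAllObjects helper is hypothetical and shown only to illustrate how the bool-only pagination callback records cancellation on the signal so the caller can convert it back into an error.

func listAllObjects(ctx context.Context, client *s3.S3, bucket string) error {
	signal := newWorkerSignal()
	input := &s3.ListObjectsV2Input{Bucket: &bucket}
	err := client.ListObjectsV2PagesWithContext(ctx, input,
		func(page *s3.ListObjectsV2Output, _ bool) bool {
			if ctx.Err() != nil {
				signal.MarkCancelled()
			}
			return !signal.WasCancelled() // returning false stops pagination
		})
	if signal.WasCancelled() {
		return ctx.Err()
	}
	return err
}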
// pageMetadata contains metadata about a single page of S3 objects being scanned.
type pageMetadata struct {
bucket string // The name of the S3 bucket being scanned
pageNumber int // Current page number in the pagination sequence
client *s3.S3 // AWS S3 client configured for the appropriate region
page *s3.ListObjectsV2Output // Contains the list of S3 objects in this page
}
// processingState tracks the state of concurrent S3 object processing.
type processingState struct {
errorCount *sync.Map // Thread-safe map tracking errors per prefix
objectCount *uint64 // Total number of objects processed
workerSignal *workerSignal // Coordinates cancellation across worker goroutines
}
func (s *Source) scanBuckets(
ctx context.Context,
client *s3.S3,
role string,
bucketsToScan []string,
chunksChan chan *sources.Chunk,
) {
var objectCount uint64
) error {
if role != "" {
ctx = context.WithValue(ctx, "role", role)
}
var objectCount uint64
for i, bucket := range bucketsToScan {
// Determine starting point for resuming scan.
resumePoint, err := s.progressTracker.GetResumePoint(ctx)
if err != nil {
return fmt.Errorf("failed to get resume point :%w", err)
}
startIdx, _ := slices.BinarySearch(bucketsToScan, resumePoint.CurrentBucket)
// Create worker signal to track cancellation across page processing.
workerSignal := newWorkerSignal()
bucketsToScanCount := len(bucketsToScan)
for i := startIdx; i < bucketsToScanCount; i++ {
bucket := bucketsToScan[i]
ctx := context.WithValue(ctx, "bucket", bucket)
if common.IsDone(ctx) {
return
return ctx.Err()
}
s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
ctx.Logger().V(3).Info("Scanning bucket")
s.SetProgressComplete(
i,
len(bucketsToScan),
fmt.Sprintf("Bucket: %s", bucket),
s.Progress.EncodedResumeInfo, // Preserve existing resume info; resumption is handled by the progressTracker
)
regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
if err != nil {
ctx.Logger().Error(err, "could not get regional client for bucket")
@ -230,13 +306,46 @@ func (s *Source) scanBuckets(
errorCount := sync.Map{}
input := &s3.ListObjectsV2Input{Bucket: &bucket}
if bucket == resumePoint.CurrentBucket && resumePoint.StartAfter != "" {
input.StartAfter = &resumePoint.StartAfter
ctx.Logger().V(3).Info(
"Resuming bucket scan",
"start_after", resumePoint.StartAfter,
)
}
pageNumber := 1
err = regionalClient.ListObjectsV2PagesWithContext(
ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
ctx,
input,
func(page *s3.ListObjectsV2Output, _ bool) bool {
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
pageMetadata := pageMetadata{
bucket: bucket,
pageNumber: pageNumber,
client: regionalClient,
page: page,
}
processingState := processingState{
errorCount: &errorCount,
objectCount: &objectCount,
workerSignal: workerSignal,
}
s.pageChunker(ctx, pageMetadata, processingState, chunksChan)
if workerSignal.WasCancelled() {
return false // Stop pagination
}
pageNumber++
return true
})
// Check if we stopped due to cancellation.
if workerSignal.WasCancelled() {
return ctx.Err()
}
if err != nil {
if role == "" {
ctx.Logger().Error(err, "could not list objects in bucket")
@ -249,18 +358,21 @@ func (s *Source) scanBuckets(
}
}
}
s.SetProgressComplete(
len(bucketsToScan),
len(bucketsToScan),
fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount),
"",
)
return nil
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
return s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
}
return s.visitRoles(ctx, visitor)
@ -289,60 +401,73 @@ func (s *Source) getRegionalClientForBucket(
return regionalClient, nil
}
// pageChunker emits chunks onto the given channel from a page
// pageChunker emits chunks onto the given channel from a page.
func (s *Source) pageChunker(
ctx context.Context,
client *s3.S3,
metadata pageMetadata,
state processingState,
chunksChan chan *sources.Chunk,
bucket string,
page *s3.ListObjectsV2Output,
errorCount *sync.Map,
pageNumber int,
objectCount *uint64,
) {
for _, obj := range page.Contents {
s.progressTracker.Reset(ctx)
ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber)
for objIdx, obj := range metadata.page.Contents {
if obj == nil {
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for nil object")
}
continue
}
ctx = context.WithValues(
ctx,
"key", *obj.Key,
"bucket", bucket,
"page", pageNumber,
"size", *obj.Size,
)
ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size)
if common.IsDone(ctx) {
state.workerSignal.MarkCancelled()
return
}
// Skip GLACIER and GLACIER_IR objects.
if obj.StorageClass == nil || strings.Contains(*obj.StorageClass, "GLACIER") {
ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", *obj.StorageClass)
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for glacier object")
}
continue
}
// Ignore large files.
if *obj.Size > s.maxObjectSize {
ctx.Logger().V(5).Info("Skipping %d byte file (over maxObjectSize limit)")
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for large file")
}
continue
}
// Skip empty files.
if *obj.Size == 0 {
ctx.Logger().V(5).Info("Skipping empty file")
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for empty file")
}
continue
}
// Skip incompatible extensions.
if common.SkipFile(*obj.Key) {
ctx.Logger().V(5).Info("Skipping file with incompatible extension")
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for incompatible file")
}
continue
}
s.jobPool.Go(func() error {
defer common.RecoverWithExit(ctx)
if common.IsDone(ctx) {
state.workerSignal.MarkCancelled()
return ctx.Err()
}
if strings.HasSuffix(*obj.Key, "/") {
ctx.Logger().V(5).Info("Skipping directory")
@ -352,7 +477,7 @@ func (s *Source) pageChunker(
path := strings.Split(*obj.Key, "/")
prefix := strings.Join(path[:len(path)-1], "/")
nErr, ok := errorCount.Load(prefix)
nErr, ok := state.errorCount.Load(prefix)
if !ok {
nErr = 0
}
@ -366,8 +491,8 @@ func (s *Source) pageChunker(
objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout)
defer cancel()
res, err := client.GetObjectWithContext(objCtx, &s3.GetObjectInput{
Bucket: &bucket,
res, err := metadata.client.GetObjectWithContext(objCtx, &s3.GetObjectInput{
Bucket: &metadata.bucket,
Key: obj.Key,
})
if err != nil {
@ -382,7 +507,7 @@ func (s *Source) pageChunker(
res.Body.Close()
}
nErr, ok := errorCount.Load(prefix)
nErr, ok := state.errorCount.Load(prefix)
if !ok {
nErr = 0
}
@ -391,7 +516,7 @@ func (s *Source) pageChunker(
return nil
}
nErr = nErr.(int) + 1
errorCount.Store(prefix, nErr)
state.errorCount.Store(prefix, nErr)
// too many consecutive errors on this page
if nErr.(int) > 3 {
ctx.Logger().V(2).Info("Too many consecutive errors, excluding prefix", "prefix", prefix)
@ -413,9 +538,9 @@ func (s *Source) pageChunker(
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_S3{
S3: &source_metadatapb.S3{
Bucket: bucket,
Bucket: metadata.bucket,
File: sanitizer.UTF8(*obj.Key),
Link: sanitizer.UTF8(makeS3Link(bucket, *client.Config.Region, *obj.Key)),
Link: sanitizer.UTF8(makeS3Link(metadata.bucket, *metadata.client.Config.Region, *obj.Key)),
Email: sanitizer.UTF8(email),
Timestamp: sanitizer.UTF8(modified),
},
@ -429,14 +554,19 @@ func (s *Source) pageChunker(
return nil
}
atomic.AddUint64(objectCount, 1)
ctx.Logger().V(5).Info("S3 object scanned.", "object_count", objectCount)
nErr, ok = errorCount.Load(prefix)
atomic.AddUint64(state.objectCount, 1)
ctx.Logger().V(5).Info("S3 object scanned.", "object_count", state.objectCount)
nErr, ok = state.errorCount.Load(prefix)
if !ok {
nErr = 0
}
if nErr.(int) > 0 {
errorCount.Store(prefix, 0)
state.errorCount.Store(prefix, 0)
}
// Update progress after successful processing.
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for scanned object")
}
return nil
@ -485,10 +615,13 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
// for each role, passing in the default S3 client, the role ARN, and the list of
// buckets to scan.
//
// The provided function parameter typically implements the core scanning logic
// and must handle context cancellation appropriately.
//
// If no roles are configured, it will call the function with an empty role ARN.
func (s *Source) visitRoles(
ctx context.Context,
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error,
) error {
roles := s.conn.GetRoles()
if len(roles) == 0 {
@ -506,7 +639,9 @@ func (s *Source) visitRoles(
return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
}
f(ctx, client, role, bucketsToScan)
if err := f(ctx, client, role, bucketsToScan); err != nil {
return err
}
}
return nil

View file

@ -4,18 +4,19 @@
package s3
import (
"encoding/json"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)
@ -82,6 +83,37 @@ func TestSource_ChunksLarge(t *testing.T) {
assert.Equal(t, got, wantChunkCount)
}
func TestSourceChunksNoResumption(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
s := Source{}
connection := &sourcespb.S3{
Credential: &sourcespb.S3_Unauthenticated{},
Buckets: []string{"trufflesec-ahrav-test-2"},
}
conn, err := anypb.New(connection)
if err != nil {
t.Fatal(err)
}
err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
chunksCh := make(chan *sources.Chunk)
go func() {
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
assert.Nil(t, err)
}()
wantChunkCount := 19787
got := 0
for range chunksCh {
got++
}
assert.Equal(t, got, wantChunkCount)
}
func TestSource_Validate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
@ -216,3 +248,116 @@ func TestSource_Validate(t *testing.T) {
})
}
}
func TestSourceChunksResumption(t *testing.T) {
// First scan - simulate interruption.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
src := new(Source)
connection := &sourcespb.S3{
Credential: &sourcespb.S3_Unauthenticated{},
Buckets: []string{"trufflesec-ahrav-test-2"},
EnableResumption: true,
}
conn, err := anypb.New(connection)
require.NoError(t, err)
err = src.Init(ctx, "test name", 0, 0, false, conn, 2)
require.NoError(t, err)
chunksCh := make(chan *sources.Chunk)
var firstScanCount int64
const cancelAfterChunks = 15_000
cancelCtx, ctxCancel := context.WithCancel(ctx)
defer ctxCancel()
// Start first scan and collect chunks until chunk limit.
go func() {
defer close(chunksCh)
err = src.Chunks(cancelCtx, chunksCh)
assert.Error(t, err, "Expected context cancellation error")
}()
// Process chunks until we hit our limit
for range chunksCh {
firstScanCount++
if firstScanCount >= cancelAfterChunks {
ctxCancel() // Cancel context after processing desired number of chunks
break
}
}
// Verify we processed exactly the number of chunks we wanted.
assert.Equal(t, int64(cancelAfterChunks), firstScanCount,
"Should have processed exactly %d chunks in first scan", cancelAfterChunks)
// Verify we have processed some chunks and have resumption info.
assert.Greater(t, firstScanCount, int64(0), "Should have processed some chunks in first scan")
progress := src.GetProgress()
assert.NotEmpty(t, progress.EncodedResumeInfo, "Progress.EncodedResumeInfo should not be empty")
firstScanCompletedIndex := progress.SectionsCompleted
var resumeInfo ResumeInfo
err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &resumeInfo)
require.NoError(t, err, "Should be able to decode resume info")
// Verify resume info contains expected fields.
assert.Equal(t, "trufflesec-ahrav-test-2", resumeInfo.CurrentBucket, "Resume info should contain correct bucket")
assert.NotEmpty(t, resumeInfo.StartAfter, "Resume info should contain a StartAfter key")
// Store the key where first scan stopped.
firstScanLastKey := resumeInfo.StartAfter
// Second scan - should resume from where first scan left off.
ctx2 := context.Background()
src2 := &Source{Progress: *src.GetProgress()}
err = src2.Init(ctx2, "test name", 0, 0, false, conn, 4)
require.NoError(t, err)
chunksCh2 := make(chan *sources.Chunk)
var secondScanCount int64
go func() {
defer close(chunksCh2)
err = src2.Chunks(ctx2, chunksCh2)
assert.NoError(t, err)
}()
// Process second scan chunks and verify progress.
for range chunksCh2 {
secondScanCount++
// Get current progress during scan.
currentProgress := src2.GetProgress()
assert.GreaterOrEqual(t, currentProgress.SectionsCompleted, firstScanCompletedIndex,
"Progress should be greater or equal to first scan")
if currentProgress.EncodedResumeInfo != "" {
var currentResumeInfo ResumeInfo
err := json.Unmarshal([]byte(currentProgress.EncodedResumeInfo), &currentResumeInfo)
require.NoError(t, err)
// Verify that we're always scanning forward from where we left off.
assert.GreaterOrEqual(t, currentResumeInfo.StartAfter, firstScanLastKey,
"Second scan should never process keys before where first scan ended")
}
}
// Verify total coverage.
expectedTotal := int64(19787)
actualTotal := firstScanCount + secondScanCount
// Because of our resumption logic favoring completeness over speed, we can
// re-scan some objects.
assert.GreaterOrEqual(t, actualTotal, expectedTotal,
"Total processed chunks should meet or exceed expected count")
assert.Less(t, actualTotal, 2*expectedTotal,
"Total processed chunks should not be more than double expected count")
finalProgress := src2.GetProgress()
assert.Equal(t, 1, int(finalProgress.SectionsCompleted), "Should have completed sections")
assert.Equal(t, 1, int(finalProgress.SectionsRemaining), "Should have remaining sections")
}

View file

@ -10,12 +10,13 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
"google.golang.org/protobuf/types/known/anypb"
)
func TestSource_Init_IncludeAndIgnoreBucketsError(t *testing.T) {