mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-15 01:17:34 +00:00
cleanup
This commit is contained in:
parent
334a4f9ef2
commit
dd206643fe
5 changed files with 487 additions and 466 deletions
|
@ -21,9 +21,7 @@ type ProgressTracker struct {
|
|||
// completedObjects tracks which indices in the current page have been processed.
|
||||
sync.Mutex
|
||||
completedObjects []bool
|
||||
|
||||
baseCompleted int32 // Track completed count from previous pages
|
||||
currentPageSize int32 // Track the current page size to avoid double counting
|
||||
completionOrder []int // Track the order in which objects complete
|
||||
|
||||
// progress holds the scan's overall progress state and enables persistence.
|
||||
progress *sources.Progress // Reference to source's Progress
|
||||
|
@ -34,17 +32,18 @@ const defaultMaxObjectsPerPage = 1000
|
|||
// NewProgressTracker creates a new progress tracker for S3 scanning operations.
|
||||
// The enabled parameter determines if progress tracking is active, and progress
|
||||
// provides the underlying mechanism for persisting scan state. A nil progress is rejected with an error.
|
||||
func NewProgressTracker(ctx context.Context, enabled bool, progress *sources.Progress) *ProgressTracker {
|
||||
func NewProgressTracker(_ context.Context, enabled bool, progress *sources.Progress) (*ProgressTracker, error) {
|
||||
if progress == nil {
|
||||
ctx.Logger().Info("Nil progress provided. Progress initialized.")
|
||||
progress = new(sources.Progress)
|
||||
return nil, errors.New("Nil progress provided; progress is required for tracking")
|
||||
}
|
||||
|
||||
return &ProgressTracker{
|
||||
// We are resuming if we have completed objects from a previous scan.
|
||||
completedObjects: make([]bool, defaultMaxObjectsPerPage),
|
||||
completionOrder: make([]int, 0, defaultMaxObjectsPerPage),
|
||||
enabled: enabled,
|
||||
progress: progress,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Reset prepares the tracker for a new page of objects by clearing the completion state.
|
||||
|
@ -56,9 +55,8 @@ func (p *ProgressTracker) Reset(_ context.Context) {
|
|||
p.Lock()
|
||||
defer p.Unlock()
|
||||
// Store the current completed count before moving to next page.
|
||||
p.baseCompleted = p.progress.SectionsCompleted
|
||||
p.currentPageSize = 0
|
||||
p.completedObjects = make([]bool, defaultMaxObjectsPerPage)
|
||||
p.completionOrder = make([]int, 0, defaultMaxObjectsPerPage)
|
||||
}
|
||||
|
||||
// ResumeInfo represents the state needed to resume an interrupted operation.
|
||||
|
@ -93,19 +91,12 @@ func (p *ProgressTracker) GetResumePoint(ctx context.Context) (ResumeInfo, error
|
|||
return resume, nil
|
||||
}
|
||||
|
||||
return ResumeInfo{
|
||||
CurrentBucket: resumeInfo.CurrentBucket,
|
||||
StartAfter: resumeInfo.StartAfter,
|
||||
}, nil
|
||||
return ResumeInfo{CurrentBucket: resumeInfo.CurrentBucket, StartAfter: resumeInfo.StartAfter}, nil
|
||||
}
|
||||
|
||||
// Complete marks the entire scanning operation as finished and clears the resume state.
|
||||
// This should only be called once all scanning operations are complete.
|
||||
func (p *ProgressTracker) Complete(_ context.Context, message string) error {
|
||||
if !p.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Preserve existing progress counters while clearing resume state.
|
||||
p.progress.SetProgressComplete(
|
||||
int(p.progress.SectionsCompleted),
|
||||
|
@ -126,8 +117,7 @@ func (p *ProgressTracker) Complete(_ context.Context, message string) error {
|
|||
//
|
||||
// This approach ensures scan reliability by only checkpointing consecutively completed
|
||||
// objects. While this may result in re-scanning some objects when resuming, it guarantees
|
||||
// no objects are missed in case of interruption. The linear search through page contents
|
||||
// is efficient given the fixed maximum page size of 1000 objects.
|
||||
// no objects are missed in case of interruption.
|
||||
func (p *ProgressTracker) UpdateObjectProgress(
|
||||
ctx context.Context,
|
||||
completedIdx int,
|
||||
|
@ -148,43 +138,47 @@ func (p *ProgressTracker) UpdateObjectProgress(
|
|||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
// Update remaining count only once per page.
|
||||
pageSize := int32(len(pageContents))
|
||||
if p.currentPageSize == 0 {
|
||||
p.progress.SectionsRemaining += pageSize
|
||||
p.currentPageSize = pageSize
|
||||
// Only track completion if this is the first time this index is marked complete.
|
||||
if !p.completedObjects[completedIdx] {
|
||||
p.completedObjects[completedIdx] = true
|
||||
p.completionOrder = append(p.completionOrder, completedIdx)
|
||||
}
|
||||
|
||||
p.completedObjects[completedIdx] = true
|
||||
// Find the highest safe checkpoint we can create.
|
||||
lastSafeIdx := -1
|
||||
var safeIndices [defaultMaxObjectsPerPage]bool
|
||||
|
||||
// Mark all completed indices.
|
||||
for _, idx := range p.completionOrder {
|
||||
safeIndices[idx] = true
|
||||
}
|
||||
|
||||
// Find the highest consecutive completed index.
|
||||
lastConsecutiveIdx := -1
|
||||
for i := 0; i <= completedIdx; i++ {
|
||||
if !p.completedObjects[i] {
|
||||
for i := range len(p.completedObjects) {
|
||||
if !safeIndices[i] {
|
||||
break
|
||||
}
|
||||
lastConsecutiveIdx = i
|
||||
lastSafeIdx = i
|
||||
}
|
||||
|
||||
// Update progress if we have at least one completed object.
|
||||
if lastConsecutiveIdx < 0 {
|
||||
if lastSafeIdx < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
obj := pageContents[lastConsecutiveIdx]
|
||||
obj := pageContents[lastSafeIdx]
|
||||
info := &ResumeInfo{CurrentBucket: bucket, StartAfter: *obj.Key}
|
||||
encoded, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the total completed as base (from previous pages) plus consecutive completions in this page.
|
||||
completedCount := p.baseCompleted + int32(lastConsecutiveIdx+1)
|
||||
|
||||
// Purposefully avoid updating any progress counts.
|
||||
// Only update resume info.
|
||||
p.progress.SetProgressComplete(
|
||||
int(completedCount),
|
||||
int(p.progress.SectionsCompleted),
|
||||
int(p.progress.SectionsRemaining),
|
||||
fmt.Sprintf("Processing: %s/%s", bucket, *obj.Key),
|
||||
p.progress.Message,
|
||||
string(encoded),
|
||||
)
|
||||
return nil
|
||||
|
|
|
@ -7,11 +7,61 @@ import (
|
|||
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
|
||||
)
|
||||
|
||||
func TestProgressTrackerResumption(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// First scan - process 6 objects then interrupt.
|
||||
initialProgress := &sources.Progress{}
|
||||
tracker, err := NewProgressTracker(ctx, true, initialProgress)
|
||||
require.NoError(t, err)
|
||||
|
||||
firstPage := &s3.ListObjectsV2Output{
|
||||
Contents: make([]*s3.Object, 12), // Total of 12 objects
|
||||
}
|
||||
for i := range 12 {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
firstPage.Contents[i] = &s3.Object{Key: &key}
|
||||
}
|
||||
|
||||
// Process first 6 objects.
|
||||
for i := range 6 {
|
||||
err := tracker.UpdateObjectProgress(ctx, i, "test-bucket", firstPage.Contents)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify resume info is set correctly.
|
||||
resumeInfo, err := tracker.GetResumePoint(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-bucket", resumeInfo.CurrentBucket)
|
||||
assert.Equal(t, "key-5", resumeInfo.StartAfter)
|
||||
|
||||
// Resume scan with existing progress.
|
||||
resumeTracker, err := NewProgressTracker(ctx, true, initialProgress)
|
||||
require.NoError(t, err)
|
||||
|
||||
resumePage := &s3.ListObjectsV2Output{
|
||||
Contents: firstPage.Contents[6:], // Remaining 6 objects
|
||||
}
|
||||
|
||||
// Process remaining objects.
|
||||
for i := range len(resumePage.Contents) {
|
||||
err := resumeTracker.UpdateObjectProgress(ctx, i, "test-bucket", resumePage.Contents)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify final resume info.
|
||||
finalResumeInfo, err := resumeTracker.GetResumePoint(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-bucket", finalResumeInfo.CurrentBucket)
|
||||
assert.Equal(t, "key-11", finalResumeInfo.StartAfter)
|
||||
}
|
||||
|
||||
func TestProgressTrackerReset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -19,7 +69,6 @@ func TestProgressTrackerReset(t *testing.T) {
|
|||
}{
|
||||
{name: "reset with enabled tracker", enabled: true},
|
||||
{name: "reset with disabled tracker", enabled: false},
|
||||
{name: "reset with zero capacity", enabled: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -28,19 +77,28 @@ func TestProgressTrackerReset(t *testing.T) {
|
|||
|
||||
ctx := context.Background()
|
||||
progress := new(sources.Progress)
|
||||
tracker := NewProgressTracker(ctx, tt.enabled, progress)
|
||||
tracker, err := NewProgressTracker(ctx, tt.enabled, progress)
|
||||
require.NoError(t, err)
|
||||
|
||||
tracker.completedObjects[1] = true
|
||||
tracker.completedObjects[2] = true
|
||||
|
||||
tracker.Reset(ctx)
|
||||
|
||||
if !tt.enabled {
|
||||
assert.Equal(t, defaultMaxObjectsPerPage, len(tracker.completedObjects), "Reset did not clear completed objects")
|
||||
return
|
||||
}
|
||||
assert.Equal(t, defaultMaxObjectsPerPage, len(tracker.completedObjects),
|
||||
"Reset changed the length of completed objects")
|
||||
|
||||
assert.Equal(t, 0, len(tracker.completedObjects), "Reset did not clear completed objects")
|
||||
if tt.enabled {
|
||||
// All values should be false after reset.
|
||||
for i, isCompleted := range tracker.completedObjects {
|
||||
assert.False(t, isCompleted,
|
||||
"Reset did not clear completed object at index %d", i)
|
||||
}
|
||||
|
||||
// Completion order should be empty.
|
||||
assert.Equal(t, 0, len(tracker.completionOrder),
|
||||
"Reset did not clear completion order")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -116,239 +174,51 @@ func TestGetResumePoint(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func setupTestTracker(t *testing.T, enabled bool, progress *sources.Progress, pageSize int) (*ProgressTracker, *s3.ListObjectsV2Output) {
|
||||
t.Helper()
|
||||
|
||||
tracker := NewProgressTracker(context.Background(), enabled, progress)
|
||||
page := &s3.ListObjectsV2Output{Contents: make([]*s3.Object, pageSize)}
|
||||
for i := range pageSize {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
page.Contents[i] = &s3.Object{Key: &key}
|
||||
}
|
||||
return tracker, page
|
||||
}
|
||||
|
||||
func TestProgressTrackerUpdateProgressDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
progress := new(sources.Progress)
|
||||
tracker, page := setupTestTracker(t, false, progress, 5)
|
||||
|
||||
err := tracker.UpdateObjectProgress(context.Background(), 1, "test-bucket", page.Contents)
|
||||
assert.NoError(t, err, "Error updating progress when tracker disabled")
|
||||
|
||||
assert.Empty(t, progress.EncodedResumeInfo, "Progress updated when tracker disabled")
|
||||
}
|
||||
|
||||
func TestProgressTrackerUpdateProgressCompletedIdxOOR(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
progress := new(sources.Progress)
|
||||
tracker, page := setupTestTracker(t, true, progress, 5)
|
||||
|
||||
err := tracker.UpdateObjectProgress(context.Background(), 1001, "test-bucket", page.Contents)
|
||||
assert.Error(t, err, "Expected error when completedIdx out of range")
|
||||
|
||||
assert.Empty(t, progress.EncodedResumeInfo, "Progress updated when tracker disabled")
|
||||
}
|
||||
|
||||
func TestProgressTrackerSequence(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
// Each update is a sequence of {completedIdx, expectedCompleted, expectedRemaining}
|
||||
updates [][3]int
|
||||
pageSize int
|
||||
}{
|
||||
{
|
||||
name: "multiple updates same page",
|
||||
description: "Verify remaining count isn't doubled and completed accumulates correctly",
|
||||
pageSize: 5,
|
||||
updates: [][3]int{
|
||||
{0, 1, 5}, // First object - should set remaining to 5
|
||||
{1, 2, 5}, // Second object - remaining should stay 5
|
||||
{2, 3, 5}, // Third object - remaining should stay 5
|
||||
{4, 3, 5}, // Gap at index 3 - completed should stay at 3
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "across page boundaries",
|
||||
description: "Verify completed count accumulates across pages",
|
||||
pageSize: 3,
|
||||
updates: [][3]int{
|
||||
// First page
|
||||
{0, 1, 3},
|
||||
{1, 2, 3},
|
||||
{2, 3, 3},
|
||||
// Reset and start new page.
|
||||
{0, 4, 6}, // baseCompleted(3) + current(1)
|
||||
{1, 5, 6}, // baseCompleted(3) + current(2)
|
||||
{2, 6, 6}, // baseCompleted(3) + current(3)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "incomplete page transition",
|
||||
description: "Verify incomplete page properly sets base completed",
|
||||
pageSize: 4,
|
||||
updates: [][3]int{
|
||||
// First page - only complete first 2.
|
||||
{0, 1, 4},
|
||||
{1, 2, 4},
|
||||
// Skip 2,3 and move to next page.
|
||||
// Reset and start new page.
|
||||
{0, 3, 8}, // baseCompleted(2) + current(1)
|
||||
{1, 4, 8}, // baseCompleted(2) + current(2)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
progress := new(sources.Progress)
|
||||
tracker, page := setupTestTracker(t, true, progress, tt.pageSize)
|
||||
|
||||
pageCount := 0
|
||||
for i, update := range tt.updates {
|
||||
completedIdx, expectedCompleted, expectedRemaining := update[0], update[1], update[2]
|
||||
|
||||
// If this update starts a new page.
|
||||
if completedIdx == 0 && i > 0 {
|
||||
pageCount++
|
||||
tracker.Reset(ctx)
|
||||
// Create new page with same size.
|
||||
page = &s3.ListObjectsV2Output{Contents: make([]*s3.Object, tt.pageSize)}
|
||||
for j := range tt.pageSize {
|
||||
key := fmt.Sprintf("page%d-key-%d", pageCount, j)
|
||||
page.Contents[j] = &s3.Object{Key: &key}
|
||||
}
|
||||
}
|
||||
|
||||
err := tracker.UpdateObjectProgress(ctx, completedIdx, "test-bucket", page.Contents)
|
||||
assert.NoError(t, err, "Unexpected error updating progress")
|
||||
|
||||
assert.Equal(t, expectedCompleted, int(progress.SectionsCompleted),
|
||||
"Incorrect completed count at update %d", i)
|
||||
assert.Equal(t, expectedRemaining, int(progress.SectionsRemaining),
|
||||
"Incorrect remaining count at update %d", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTrackerUpdateProgressWithResume(t *testing.T) {
|
||||
func TestProgressTrackerUpdateProgress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
completedIdx int
|
||||
pageSize int
|
||||
preCompleted map[int]bool
|
||||
|
||||
expectedKey string
|
||||
expectedCompleted int
|
||||
expectedRemaining int
|
||||
preCompleted []int
|
||||
expectedKey string
|
||||
}{
|
||||
{
|
||||
name: "first object completed",
|
||||
description: "Basic case - completing first object",
|
||||
completedIdx: 0,
|
||||
pageSize: 3,
|
||||
expectedKey: "key-0",
|
||||
expectedCompleted: 1, // Only first object completed
|
||||
expectedRemaining: 3, // Total objects in page
|
||||
name: "first object completed",
|
||||
description: "Basic case - completing first object",
|
||||
completedIdx: 0,
|
||||
pageSize: 3,
|
||||
expectedKey: "key-0",
|
||||
},
|
||||
{
|
||||
name: "completing missing middle",
|
||||
description: "Completing object when previous is done",
|
||||
completedIdx: 1,
|
||||
pageSize: 3,
|
||||
preCompleted: map[int]bool{0: true},
|
||||
expectedKey: "key-1",
|
||||
expectedCompleted: 2, // First two objects completed
|
||||
expectedRemaining: 3, // Total objects in page
|
||||
name: "completing missing middle",
|
||||
description: "Completing object when previous is done",
|
||||
completedIdx: 1,
|
||||
pageSize: 3,
|
||||
preCompleted: []int{0},
|
||||
expectedKey: "key-1",
|
||||
},
|
||||
{
|
||||
name: "completing first with last done",
|
||||
description: "Completing first object when last is already done",
|
||||
completedIdx: 0,
|
||||
pageSize: 3,
|
||||
preCompleted: map[int]bool{2: true},
|
||||
expectedKey: "key-0",
|
||||
expectedCompleted: 1, // Only first object counts due to gap
|
||||
expectedRemaining: 3, // Total objects in page
|
||||
},
|
||||
{
|
||||
name: "all objects completed in order",
|
||||
description: "Completing final object in sequence",
|
||||
completedIdx: 2,
|
||||
pageSize: 3,
|
||||
preCompleted: map[int]bool{0: true, 1: true},
|
||||
expectedKey: "key-2",
|
||||
expectedCompleted: 3, // All objects completed
|
||||
expectedRemaining: 3, // Total objects in page
|
||||
},
|
||||
{
|
||||
name: "completing middle gaps",
|
||||
description: "Completing object with gaps in sequence",
|
||||
completedIdx: 5,
|
||||
pageSize: 10,
|
||||
preCompleted: map[int]bool{0: true, 1: true, 2: true, 4: true},
|
||||
expectedKey: "key-2", // Last consecutive completed
|
||||
expectedCompleted: 3, // Only first 3 count due to gap
|
||||
expectedRemaining: 10, // Total objects in page
|
||||
},
|
||||
{
|
||||
name: "zero index with empty pre-completed",
|
||||
description: "Edge case - minimum valid index",
|
||||
completedIdx: 0,
|
||||
pageSize: 1,
|
||||
expectedKey: "key-0",
|
||||
expectedCompleted: 1,
|
||||
expectedRemaining: 1,
|
||||
name: "all objects completed in order",
|
||||
description: "Completing final object in sequence",
|
||||
completedIdx: 2,
|
||||
pageSize: 3,
|
||||
preCompleted: []int{0, 1},
|
||||
expectedKey: "key-2",
|
||||
},
|
||||
{
|
||||
name: "last index in max page",
|
||||
description: "Edge case - maximum page size boundary",
|
||||
completedIdx: 999,
|
||||
pageSize: 1000,
|
||||
preCompleted: func() map[int]bool {
|
||||
m := make(map[int]bool)
|
||||
preCompleted: func() []int {
|
||||
indices := make([]int, 999)
|
||||
for i := range 999 {
|
||||
m[i] = true
|
||||
indices[i] = i
|
||||
}
|
||||
return m
|
||||
return indices
|
||||
}(),
|
||||
expectedKey: "key-999",
|
||||
expectedCompleted: 1000, // All objects completed
|
||||
expectedRemaining: 1000, // Total objects in page
|
||||
},
|
||||
{
|
||||
name: "all previous completed",
|
||||
description: "Edge case - all previous indices completed",
|
||||
completedIdx: 100,
|
||||
pageSize: 101,
|
||||
preCompleted: func() map[int]bool {
|
||||
m := make(map[int]bool)
|
||||
for i := range 100 {
|
||||
m[i] = true
|
||||
}
|
||||
return m
|
||||
}(),
|
||||
expectedKey: "key-100",
|
||||
expectedCompleted: 101, // All objects completed
|
||||
expectedRemaining: 101, // Total objects in page
|
||||
},
|
||||
{
|
||||
name: "large page number completion",
|
||||
description: "Edge case - very large page number",
|
||||
completedIdx: 5,
|
||||
pageSize: 10,
|
||||
preCompleted: map[int]bool{0: true, 1: true, 2: true, 3: true, 4: true},
|
||||
expectedKey: "key-5",
|
||||
expectedCompleted: 6, // First 6 objects completed
|
||||
expectedRemaining: 10, // Total objects in page
|
||||
expectedKey: "key-999",
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -358,108 +228,34 @@ func TestProgressTrackerUpdateProgressWithResume(t *testing.T) {
|
|||
|
||||
ctx := context.Background()
|
||||
progress := new(sources.Progress)
|
||||
tracker, page := setupTestTracker(t, true, progress, tt.pageSize)
|
||||
tracker := &ProgressTracker{
|
||||
enabled: true,
|
||||
progress: progress,
|
||||
completedObjects: make([]bool, tt.pageSize),
|
||||
completionOrder: make([]int, 0, tt.pageSize),
|
||||
}
|
||||
|
||||
page := &s3.ListObjectsV2Output{Contents: make([]*s3.Object, tt.pageSize)}
|
||||
for i := range tt.pageSize {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
page.Contents[i] = &s3.Object{Key: &key}
|
||||
}
|
||||
|
||||
// Apply pre-completed indices in order.
|
||||
if tt.preCompleted != nil {
|
||||
for k, v := range tt.preCompleted {
|
||||
tracker.completedObjects[k] = v
|
||||
for _, idx := range tt.preCompleted {
|
||||
tracker.completedObjects[idx] = true
|
||||
tracker.completionOrder = append(tracker.completionOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
err := tracker.UpdateObjectProgress(ctx, tt.completedIdx, "test-bucket", page.Contents)
|
||||
assert.NoError(t, err, "Unexpected error updating progress")
|
||||
|
||||
// Verify resume info.
|
||||
assert.NotEmpty(t, progress.EncodedResumeInfo, "Expected progress update")
|
||||
var info ResumeInfo
|
||||
err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &info)
|
||||
assert.NoError(t, err, "Failed to decode resume info")
|
||||
assert.Equal(t, tt.expectedKey, info.StartAfter, "Incorrect resume point")
|
||||
|
||||
// Verify progress counts.
|
||||
assert.Equal(t, tt.expectedCompleted, int(progress.SectionsCompleted),
|
||||
"Incorrect completed count")
|
||||
assert.Equal(t, tt.expectedRemaining, int(progress.SectionsRemaining),
|
||||
"Incorrect remaining count")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTrackerUpdateProgressNoResume(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
completedIdx int
|
||||
pageSize int
|
||||
preCompleted map[int]bool
|
||||
}{
|
||||
{
|
||||
name: "middle object completed first",
|
||||
description: "Basic case - completing middle object first",
|
||||
completedIdx: 1,
|
||||
pageSize: 3,
|
||||
},
|
||||
{
|
||||
name: "last object completed first",
|
||||
description: "Basic case - completing last object first",
|
||||
completedIdx: 2,
|
||||
pageSize: 3,
|
||||
},
|
||||
{
|
||||
name: "multiple gaps",
|
||||
description: "Multiple non-consecutive completions",
|
||||
completedIdx: 5,
|
||||
pageSize: 10,
|
||||
preCompleted: map[int]bool{1: true, 3: true, 4: true},
|
||||
},
|
||||
{
|
||||
name: "alternating completion pattern",
|
||||
description: "Edge case - alternating completed/uncompleted pattern",
|
||||
completedIdx: 10,
|
||||
pageSize: 20,
|
||||
preCompleted: map[int]bool{2: true, 4: true, 6: true, 8: true},
|
||||
},
|
||||
{
|
||||
name: "sparse completion pattern",
|
||||
description: "Edge case - scattered completions with regular gaps",
|
||||
completedIdx: 50,
|
||||
pageSize: 100,
|
||||
preCompleted: map[int]bool{10: true, 20: true, 30: true, 40: true},
|
||||
},
|
||||
{
|
||||
name: "single gap breaks sequence",
|
||||
description: "Edge case - single gap prevents resume info",
|
||||
completedIdx: 50,
|
||||
pageSize: 100,
|
||||
preCompleted: func() map[int]bool {
|
||||
m := make(map[int]bool)
|
||||
for i := 1; i <= 49; i++ {
|
||||
m[i] = true
|
||||
}
|
||||
m[49] = false
|
||||
return m
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
progress := new(sources.Progress)
|
||||
enabled := tt.name != "disabled tracker"
|
||||
tracker, page := setupTestTracker(t, enabled, progress, tt.pageSize)
|
||||
|
||||
if tt.preCompleted != nil {
|
||||
for k, v := range tt.preCompleted {
|
||||
tracker.completedObjects[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
err := tracker.UpdateObjectProgress(ctx, tt.completedIdx, "test-bucket", page.Contents)
|
||||
assert.NoError(t, err, "Unexpected error updating progress")
|
||||
assert.Empty(t, progress.EncodedResumeInfo, "Expected no progress update")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -469,98 +265,51 @@ func TestComplete(t *testing.T) {
|
|||
name string
|
||||
enabled bool
|
||||
initialState struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}
|
||||
completeMessage string
|
||||
wantState struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}
|
||||
}{
|
||||
{
|
||||
name: "marks completion with existing progress",
|
||||
name: "marks completion with existing resume info",
|
||||
enabled: true,
|
||||
initialState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 100,
|
||||
sectionsRemaining: 100,
|
||||
resumeInfo: `{"CurrentBucket":"test-bucket","StartAfter":"some-key"}`,
|
||||
message: "In progress",
|
||||
resumeInfo: `{"current_bucket":"test-bucket","start_after":"some-key"}`,
|
||||
message: "In progress",
|
||||
},
|
||||
completeMessage: "Scan complete",
|
||||
wantState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 100, // Should preserve existing progress
|
||||
sectionsRemaining: 100, // Should preserve existing progress
|
||||
resumeInfo: "", // Should clear resume info
|
||||
message: "Scan complete",
|
||||
resumeInfo: "", // Should clear resume info
|
||||
message: "Scan complete",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "disabled tracker",
|
||||
enabled: false,
|
||||
initialState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 50,
|
||||
sectionsRemaining: 100,
|
||||
resumeInfo: `{"CurrentBucket":"test-bucket","StartAfter":"some-key"}`,
|
||||
message: "Should not change",
|
||||
resumeInfo: "",
|
||||
message: "Should not change",
|
||||
},
|
||||
completeMessage: "Completed",
|
||||
wantState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 50,
|
||||
sectionsRemaining: 100,
|
||||
resumeInfo: `{"CurrentBucket":"test-bucket","StartAfter":"some-key"}`,
|
||||
message: "Should not change",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completes with special characters",
|
||||
enabled: true,
|
||||
initialState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 75,
|
||||
sectionsRemaining: 75,
|
||||
resumeInfo: `{"CurrentBucket":"bucket","StartAfter":"key"}`,
|
||||
message: "In progress",
|
||||
},
|
||||
completeMessage: "Completed scanning 特殊字符 & symbols !@#$%",
|
||||
wantState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 75,
|
||||
sectionsRemaining: 75,
|
||||
resumeInfo: "",
|
||||
message: "Completed scanning 特殊字符 & symbols !@#$%",
|
||||
resumeInfo: "",
|
||||
message: "Completed",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -571,18 +320,15 @@ func TestComplete(t *testing.T) {
|
|||
|
||||
ctx := context.Background()
|
||||
progress := &sources.Progress{
|
||||
SectionsCompleted: int32(tt.initialState.sectionsCompleted),
|
||||
SectionsRemaining: int32(tt.initialState.sectionsRemaining),
|
||||
EncodedResumeInfo: tt.initialState.resumeInfo,
|
||||
Message: tt.initialState.message,
|
||||
}
|
||||
tracker := NewProgressTracker(ctx, tt.enabled, progress)
|
||||
|
||||
err := tracker.Complete(ctx, tt.completeMessage)
|
||||
tracker, err := NewProgressTracker(ctx, tt.enabled, progress)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = tracker.Complete(ctx, tt.completeMessage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int32(tt.wantState.sectionsCompleted), progress.SectionsCompleted)
|
||||
assert.Equal(t, int32(tt.wantState.sectionsRemaining), progress.SectionsRemaining)
|
||||
assert.Equal(t, tt.wantState.resumeInfo, progress.EncodedResumeInfo)
|
||||
assert.Equal(t, tt.wantState.message, progress.Message)
|
||||
})
|
||||
|
|
|
@ -2,6 +2,7 @@ package s3
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -43,8 +44,10 @@ type Source struct {
|
|||
jobID sources.JobID
|
||||
verify bool
|
||||
concurrency int
|
||||
conn *sourcespb.S3
|
||||
|
||||
progressTracker *ProgressTracker
|
||||
sources.Progress
|
||||
conn *sourcespb.S3
|
||||
|
||||
errorCount *sync.Map
|
||||
jobPool *errgroup.Group
|
||||
|
@ -67,7 +70,7 @@ func (s *Source) JobID() sources.JobID { return s.jobID }
|
|||
|
||||
// Init initializes the S3 source; it returns an error if initialization fails.
|
||||
func (s *Source) Init(
|
||||
_ context.Context,
|
||||
ctx context.Context,
|
||||
name string,
|
||||
jobID sources.JobID,
|
||||
sourceID sources.SourceID,
|
||||
|
@ -90,6 +93,12 @@ func (s *Source) Init(
|
|||
}
|
||||
s.conn = &conn
|
||||
|
||||
var err error
|
||||
s.progressTracker, err = NewProgressTracker(ctx, conn.GetEnableResumption(), &s.Progress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.setMaxObjectSize(conn.GetMaxObjectSize())
|
||||
|
||||
if len(conn.GetBuckets()) > 0 && len(conn.GetIgnoreBuckets()) > 0 {
|
||||
|
@ -101,11 +110,12 @@ func (s *Source) Init(
|
|||
|
||||
func (s *Source) Validate(ctx context.Context) []error {
|
||||
var errs []error
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
|
||||
roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
|
||||
if len(roleErrs) > 0 {
|
||||
errs = append(errs, roleErrs...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.visitRoles(ctx, visitor); err != nil {
|
||||
|
@ -173,9 +183,16 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) {
|
|||
return s3.New(sess), nil
|
||||
}
|
||||
|
||||
// IAM identity needs s3:ListBuckets permission
|
||||
// getBucketsToScan returns a list of S3 buckets to scan.
|
||||
// If the connection has a list of buckets specified, those are returned.
|
||||
// Otherwise, it lists all buckets the client has access to and filters out the ignored ones.
|
||||
// The list of buckets is sorted lexicographically to ensure consistent ordering,
|
||||
// which allows resuming scanning from the same place if the scan is interrupted.
|
||||
//
|
||||
// Note: The IAM identity needs the s3:ListAllMyBuckets permission (the IAM action required by the ListBuckets API).
|
||||
func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
|
||||
if buckets := s.conn.GetBuckets(); len(buckets) > 0 {
|
||||
slices.Sort(buckets)
|
||||
return buckets, nil
|
||||
}
|
||||
|
||||
|
@ -196,32 +213,91 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
|
|||
bucketsToScan = append(bucketsToScan, name)
|
||||
}
|
||||
}
|
||||
slices.Sort(bucketsToScan)
|
||||
|
||||
return bucketsToScan, nil
|
||||
}
|
||||
|
||||
// workerSignal is a shared, concurrency-safe flag recording that a context
// cancellation was observed while S3 bucket pages were being processed.
//
// Why it exists:
//   - The page callback handed to ListObjectsV2PagesWithContext can only return
//     a bool, never an error, so cancellation has to be reported out-of-band.
//   - pageChunker hands objects to concurrent workers; whichever one notices
//     the cancelled context first needs somewhere to record it.
//   - The bucket-scanning loop reads the flag to stop pagination and to return
//     the context error to its caller.
type workerSignal struct {
	cancelled atomic.Bool
}

// newWorkerSignal returns a workerSignal whose flag is initially unset.
func newWorkerSignal() *workerSignal {
	return &workerSignal{}
}

// MarkCancelled records that a context cancellation was detected.
func (ws *workerSignal) MarkCancelled() {
	ws.cancelled.Store(true)
}

// WasCancelled reports whether a context cancellation has been recorded.
func (ws *workerSignal) WasCancelled() bool {
	return ws.cancelled.Load()
}
|
||||
|
||||
// pageMetadata contains metadata about a single page of S3 objects being scanned.
|
||||
type pageMetadata struct {
|
||||
bucket string // The name of the S3 bucket being scanned
|
||||
pageNumber int // Current page number in the pagination sequence
|
||||
client *s3.S3 // AWS S3 client configured for the appropriate region
|
||||
page *s3.ListObjectsV2Output // Contains the list of S3 objects in this page
|
||||
}
|
||||
|
||||
// processingState tracks the state of concurrent S3 object processing.
|
||||
type processingState struct {
|
||||
errorCount *sync.Map // Thread-safe map tracking errors per prefix
|
||||
objectCount *uint64 // Total number of objects processed
|
||||
workerSignal *workerSignal // Coordinates cancellation across worker goroutines
|
||||
}
|
||||
|
||||
func (s *Source) scanBuckets(
|
||||
ctx context.Context,
|
||||
client *s3.S3,
|
||||
role string,
|
||||
bucketsToScan []string,
|
||||
chunksChan chan *sources.Chunk,
|
||||
) {
|
||||
var objectCount uint64
|
||||
|
||||
) error {
|
||||
if role != "" {
|
||||
ctx = context.WithValue(ctx, "role", role)
|
||||
}
|
||||
var objectCount uint64
|
||||
|
||||
for i, bucket := range bucketsToScan {
|
||||
// Determine starting point for resuming scan.
|
||||
resumePoint, err := s.progressTracker.GetResumePoint(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get resume point :%w", err)
|
||||
}
|
||||
|
||||
startIdx, _ := slices.BinarySearch(bucketsToScan, resumePoint.CurrentBucket)
|
||||
|
||||
// Create worker signal to track cancellation across page processing.
|
||||
workerSignal := newWorkerSignal()
|
||||
|
||||
bucketsToScanCount := len(bucketsToScan)
|
||||
for i := startIdx; i < bucketsToScanCount; i++ {
|
||||
bucket := bucketsToScan[i]
|
||||
ctx := context.WithValue(ctx, "bucket", bucket)
|
||||
|
||||
if common.IsDone(ctx) {
|
||||
return
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
|
||||
ctx.Logger().V(3).Info("Scanning bucket")
|
||||
|
||||
s.SetProgressComplete(
|
||||
i,
|
||||
len(bucketsToScan),
|
||||
fmt.Sprintf("Bucket: %s", bucket),
|
||||
s.Progress.EncodedResumeInfo, // Do not set, resume handled by progressTracker
|
||||
)
|
||||
|
||||
regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
|
||||
if err != nil {
|
||||
ctx.Logger().Error(err, "could not get regional client for bucket")
|
||||
|
@ -230,13 +306,46 @@ func (s *Source) scanBuckets(
|
|||
|
||||
errorCount := sync.Map{}
|
||||
|
||||
input := &s3.ListObjectsV2Input{Bucket: &bucket}
|
||||
if bucket == resumePoint.CurrentBucket && resumePoint.StartAfter != "" {
|
||||
input.StartAfter = &resumePoint.StartAfter
|
||||
ctx.Logger().V(3).Info(
|
||||
"Resuming bucket scan",
|
||||
"start_after", resumePoint.StartAfter,
|
||||
)
|
||||
}
|
||||
|
||||
pageNumber := 1
|
||||
err = regionalClient.ListObjectsV2PagesWithContext(
|
||||
ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
|
||||
ctx,
|
||||
input,
|
||||
func(page *s3.ListObjectsV2Output, _ bool) bool {
|
||||
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
|
||||
pageMetadata := pageMetadata{
|
||||
bucket: bucket,
|
||||
pageNumber: pageNumber,
|
||||
client: regionalClient,
|
||||
page: page,
|
||||
}
|
||||
processingState := processingState{
|
||||
errorCount: &errorCount,
|
||||
objectCount: &objectCount,
|
||||
workerSignal: workerSignal,
|
||||
}
|
||||
s.pageChunker(ctx, pageMetadata, processingState, chunksChan)
|
||||
|
||||
if workerSignal.WasCancelled() {
|
||||
return false // Stop pagination
|
||||
}
|
||||
|
||||
pageNumber++
|
||||
return true
|
||||
})
|
||||
|
||||
// Check if we stopped due to cancellation.
|
||||
if workerSignal.WasCancelled() {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if role == "" {
|
||||
ctx.Logger().Error(err, "could not list objects in bucket")
|
||||
|
@ -249,18 +358,21 @@ func (s *Source) scanBuckets(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.SetProgressComplete(
|
||||
len(bucketsToScan),
|
||||
len(bucketsToScan),
|
||||
fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount),
|
||||
"",
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chunks emits chunks of bytes over a channel.
|
||||
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
|
||||
s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error {
|
||||
return s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
|
||||
}
|
||||
|
||||
return s.visitRoles(ctx, visitor)
|
||||
|
@ -289,60 +401,73 @@ func (s *Source) getRegionalClientForBucket(
|
|||
return regionalClient, nil
|
||||
}
|
||||
|
||||
// pageChunker emits chunks onto the given channel from a page
|
||||
// pageChunker emits chunks onto the given channel from a page.
|
||||
func (s *Source) pageChunker(
|
||||
ctx context.Context,
|
||||
client *s3.S3,
|
||||
metadata pageMetadata,
|
||||
state processingState,
|
||||
chunksChan chan *sources.Chunk,
|
||||
bucket string,
|
||||
page *s3.ListObjectsV2Output,
|
||||
errorCount *sync.Map,
|
||||
pageNumber int,
|
||||
objectCount *uint64,
|
||||
) {
|
||||
for _, obj := range page.Contents {
|
||||
s.progressTracker.Reset(ctx)
|
||||
ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber)
|
||||
|
||||
for objIdx, obj := range metadata.page.Contents {
|
||||
if obj == nil {
|
||||
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
|
||||
ctx.Logger().Error(err, "could not update progress for nil object")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
ctx = context.WithValues(
|
||||
ctx,
|
||||
"key", *obj.Key,
|
||||
"bucket", bucket,
|
||||
"page", pageNumber,
|
||||
"size", *obj.Size,
|
||||
)
|
||||
ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size)
|
||||
|
||||
if common.IsDone(ctx) {
|
||||
state.workerSignal.MarkCancelled()
|
||||
return
|
||||
}
|
||||
|
||||
// Skip GLACIER and GLACIER_IR objects.
|
||||
if obj.StorageClass == nil || strings.Contains(*obj.StorageClass, "GLACIER") {
|
||||
ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", *obj.StorageClass)
|
||||
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
|
||||
ctx.Logger().Error(err, "could not update progress for glacier object")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Ignore large files.
|
||||
if *obj.Size > s.maxObjectSize {
|
||||
ctx.Logger().V(5).Info("Skipping %d byte file (over maxObjectSize limit)")
|
||||
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
|
||||
ctx.Logger().Error(err, "could not update progress for large file")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip empty files.
|
||||
if *obj.Size == 0 {
|
||||
ctx.Logger().V(5).Info("Skipping empty file")
|
||||
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
|
||||
ctx.Logger().Error(err, "could not update progress for empty file")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip incompatible extensions.
|
||||
if common.SkipFile(*obj.Key) {
|
||||
ctx.Logger().V(5).Info("Skipping file with incompatible extension")
|
||||
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
|
||||
ctx.Logger().Error(err, "could not update progress for incompatible file")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
s.jobPool.Go(func() error {
|
||||
defer common.RecoverWithExit(ctx)
|
||||
if common.IsDone(ctx) {
|
||||
state.workerSignal.MarkCancelled()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if strings.HasSuffix(*obj.Key, "/") {
|
||||
ctx.Logger().V(5).Info("Skipping directory")
|
||||
|
@ -352,7 +477,7 @@ func (s *Source) pageChunker(
|
|||
path := strings.Split(*obj.Key, "/")
|
||||
prefix := strings.Join(path[:len(path)-1], "/")
|
||||
|
||||
nErr, ok := errorCount.Load(prefix)
|
||||
nErr, ok := state.errorCount.Load(prefix)
|
||||
if !ok {
|
||||
nErr = 0
|
||||
}
|
||||
|
@ -366,8 +491,8 @@ func (s *Source) pageChunker(
|
|||
objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout)
|
||||
defer cancel()
|
||||
|
||||
res, err := client.GetObjectWithContext(objCtx, &s3.GetObjectInput{
|
||||
Bucket: &bucket,
|
||||
res, err := metadata.client.GetObjectWithContext(objCtx, &s3.GetObjectInput{
|
||||
Bucket: &metadata.bucket,
|
||||
Key: obj.Key,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -382,7 +507,7 @@ func (s *Source) pageChunker(
|
|||
res.Body.Close()
|
||||
}
|
||||
|
||||
nErr, ok := errorCount.Load(prefix)
|
||||
nErr, ok := state.errorCount.Load(prefix)
|
||||
if !ok {
|
||||
nErr = 0
|
||||
}
|
||||
|
@ -391,7 +516,7 @@ func (s *Source) pageChunker(
|
|||
return nil
|
||||
}
|
||||
nErr = nErr.(int) + 1
|
||||
errorCount.Store(prefix, nErr)
|
||||
state.errorCount.Store(prefix, nErr)
|
||||
// too many consecutive errors on this page
|
||||
if nErr.(int) > 3 {
|
||||
ctx.Logger().V(2).Info("Too many consecutive errors, excluding prefix", "prefix", prefix)
|
||||
|
@ -413,9 +538,9 @@ func (s *Source) pageChunker(
|
|||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_S3{
|
||||
S3: &source_metadatapb.S3{
|
||||
Bucket: bucket,
|
||||
Bucket: metadata.bucket,
|
||||
File: sanitizer.UTF8(*obj.Key),
|
||||
Link: sanitizer.UTF8(makeS3Link(bucket, *client.Config.Region, *obj.Key)),
|
||||
Link: sanitizer.UTF8(makeS3Link(metadata.bucket, *metadata.client.Config.Region, *obj.Key)),
|
||||
Email: sanitizer.UTF8(email),
|
||||
Timestamp: sanitizer.UTF8(modified),
|
||||
},
|
||||
|
@ -429,14 +554,19 @@ func (s *Source) pageChunker(
|
|||
return nil
|
||||
}
|
||||
|
||||
atomic.AddUint64(objectCount, 1)
|
||||
ctx.Logger().V(5).Info("S3 object scanned.", "object_count", objectCount)
|
||||
nErr, ok = errorCount.Load(prefix)
|
||||
atomic.AddUint64(state.objectCount, 1)
|
||||
ctx.Logger().V(5).Info("S3 object scanned.", "object_count", state.objectCount)
|
||||
nErr, ok = state.errorCount.Load(prefix)
|
||||
if !ok {
|
||||
nErr = 0
|
||||
}
|
||||
if nErr.(int) > 0 {
|
||||
errorCount.Store(prefix, 0)
|
||||
state.errorCount.Store(prefix, 0)
|
||||
}
|
||||
|
||||
// Update progress after successful processing.
|
||||
if err := s.progressTracker.UpdateObjectProgress(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
|
||||
ctx.Logger().Error(err, "could not update progress for scanned object")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -485,10 +615,13 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
|
|||
// for each role, passing in the default S3 client, the role ARN, and the list of
|
||||
// buckets to scan.
|
||||
//
|
||||
// The provided function parameter typically implements the core scanning logic
|
||||
// and must handle context cancellation appropriately.
|
||||
//
|
||||
// If no roles are configured, it will call the function with an empty role ARN.
|
||||
func (s *Source) visitRoles(
|
||||
ctx context.Context,
|
||||
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string),
|
||||
f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) error,
|
||||
) error {
|
||||
roles := s.conn.GetRoles()
|
||||
if len(roles) == 0 {
|
||||
|
@ -506,7 +639,9 @@ func (s *Source) visitRoles(
|
|||
return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
|
||||
}
|
||||
|
||||
f(ctx, client, role, bucketsToScan)
|
||||
if err := f(ctx, client, role, bucketsToScan); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -4,18 +4,19 @@
|
|||
package s3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
|
||||
)
|
||||
|
@ -82,6 +83,37 @@ func TestSource_ChunksLarge(t *testing.T) {
|
|||
assert.Equal(t, got, wantChunkCount)
|
||||
}
|
||||
|
||||
func TestSourceChunksNoResumption(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
|
||||
s := Source{}
|
||||
connection := &sourcespb.S3{
|
||||
Credential: &sourcespb.S3_Unauthenticated{},
|
||||
Buckets: []string{"trufflesec-ahrav-test-2"},
|
||||
}
|
||||
conn, err := anypb.New(connection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = s.Init(ctx, "test name", 0, 0, false, conn, 1)
|
||||
chunksCh := make(chan *sources.Chunk)
|
||||
go func() {
|
||||
defer close(chunksCh)
|
||||
err = s.Chunks(ctx, chunksCh)
|
||||
assert.Nil(t, err)
|
||||
}()
|
||||
|
||||
wantChunkCount := 19787
|
||||
got := 0
|
||||
|
||||
for range chunksCh {
|
||||
got++
|
||||
}
|
||||
assert.Equal(t, got, wantChunkCount)
|
||||
}
|
||||
|
||||
func TestSource_Validate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
|
||||
defer cancel()
|
||||
|
@ -216,3 +248,116 @@ func TestSource_Validate(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSourceChunksResumption(t *testing.T) {
|
||||
// First scan - simulate interruption.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
src := new(Source)
|
||||
connection := &sourcespb.S3{
|
||||
Credential: &sourcespb.S3_Unauthenticated{},
|
||||
Buckets: []string{"trufflesec-ahrav-test-2"},
|
||||
EnableResumption: true,
|
||||
}
|
||||
conn, err := anypb.New(connection)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = src.Init(ctx, "test name", 0, 0, false, conn, 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
chunksCh := make(chan *sources.Chunk)
|
||||
var firstScanCount int64
|
||||
const cancelAfterChunks = 15_000
|
||||
|
||||
cancelCtx, ctxCancel := context.WithCancel(ctx)
|
||||
defer ctxCancel()
|
||||
|
||||
// Start first scan and collect chunks until chunk limit.
|
||||
go func() {
|
||||
defer close(chunksCh)
|
||||
err = src.Chunks(cancelCtx, chunksCh)
|
||||
assert.Error(t, err, "Expected context cancellation error")
|
||||
}()
|
||||
|
||||
// Process chunks until we hit our limit
|
||||
for range chunksCh {
|
||||
firstScanCount++
|
||||
if firstScanCount >= cancelAfterChunks {
|
||||
ctxCancel() // Cancel context after processing desired number of chunks
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we processed exactly the number of chunks we wanted.
|
||||
assert.Equal(t, int64(cancelAfterChunks), firstScanCount,
|
||||
"Should have processed exactly %d chunks in first scan", cancelAfterChunks)
|
||||
|
||||
// Verify we have processed some chunks and have resumption info.
|
||||
assert.Greater(t, firstScanCount, int64(0), "Should have processed some chunks in first scan")
|
||||
|
||||
progress := src.GetProgress()
|
||||
assert.NotEmpty(t, progress.EncodedResumeInfo, "Progress.EncodedResumeInfo should not be empty")
|
||||
|
||||
firstScanCompletedIndex := progress.SectionsCompleted
|
||||
|
||||
var resumeInfo ResumeInfo
|
||||
err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &resumeInfo)
|
||||
require.NoError(t, err, "Should be able to decode resume info")
|
||||
|
||||
// Verify resume info contains expected fields.
|
||||
assert.Equal(t, "trufflesec-ahrav-test-2", resumeInfo.CurrentBucket, "Resume info should contain correct bucket")
|
||||
assert.NotEmpty(t, resumeInfo.StartAfter, "Resume info should contain a StartAfter key")
|
||||
|
||||
// Store the key where first scan stopped.
|
||||
firstScanLastKey := resumeInfo.StartAfter
|
||||
|
||||
// Second scan - should resume from where first scan left off.
|
||||
ctx2 := context.Background()
|
||||
src2 := &Source{Progress: *src.GetProgress()}
|
||||
err = src2.Init(ctx2, "test name", 0, 0, false, conn, 4)
|
||||
require.NoError(t, err)
|
||||
|
||||
chunksCh2 := make(chan *sources.Chunk)
|
||||
var secondScanCount int64
|
||||
|
||||
go func() {
|
||||
defer close(chunksCh2)
|
||||
err = src2.Chunks(ctx2, chunksCh2)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Process second scan chunks and verify progress.
|
||||
for range chunksCh2 {
|
||||
secondScanCount++
|
||||
|
||||
// Get current progress during scan.
|
||||
currentProgress := src2.GetProgress()
|
||||
assert.GreaterOrEqual(t, currentProgress.SectionsCompleted, firstScanCompletedIndex,
|
||||
"Progress should be greater or equal to first scan")
|
||||
if currentProgress.EncodedResumeInfo != "" {
|
||||
var currentResumeInfo ResumeInfo
|
||||
err := json.Unmarshal([]byte(currentProgress.EncodedResumeInfo), ¤tResumeInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that we're always scanning forward from where we left off.
|
||||
assert.GreaterOrEqual(t, currentResumeInfo.StartAfter, firstScanLastKey,
|
||||
"Second scan should never process keys before where first scan ended")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify total coverage.
|
||||
expectedTotal := int64(19787)
|
||||
actualTotal := firstScanCount + secondScanCount
|
||||
|
||||
// Because of our resumption logic favoring completeness over speed, we can
|
||||
// re-scan some objects.
|
||||
assert.GreaterOrEqual(t, actualTotal, expectedTotal,
|
||||
"Total processed chunks should meet or exceed expected count")
|
||||
assert.Less(t, actualTotal, 2*expectedTotal,
|
||||
"Total processed chunks should not be more than double expected count")
|
||||
|
||||
finalProgress := src2.GetProgress()
|
||||
assert.Equal(t, 1, int(finalProgress.SectionsCompleted), "Should have completed sections")
|
||||
assert.Equal(t, 1, int(finalProgress.SectionsRemaining), "Should have remaining sections")
|
||||
}
|
||||
|
|
|
@ -10,12 +10,13 @@ import (
|
|||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
func TestSource_Init_IncludeAndIgnoreBucketsError(t *testing.T) {
|
||||
|
|
Loading…
Reference in a new issue