mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-13 00:17:18 +00:00
remove progress tracking
This commit is contained in:
parent
627ece0bc8
commit
0047c62439
2 changed files with 84 additions and 418 deletions
|
@ -18,13 +18,10 @@ import (
|
|||
type ProgressTracker struct {
|
||||
enabled bool
|
||||
|
||||
isResuming bool // Indicates if the tracker is resuming a previous scan
|
||||
// completedObjects tracks which indices in the current page have been processed.
|
||||
sync.Mutex
|
||||
completedObjects []bool
|
||||
completionOrder []int // Track the order in which objects complete
|
||||
baseCompleted int32 // Track completed count from previous pages
|
||||
currentPageSize int32 // Track the current page size to avoid double counting
|
||||
|
||||
// progress holds the scan's overall progress state and enables persistence.
|
||||
progress *sources.Progress // Reference to source's Progress
|
||||
|
@ -35,21 +32,18 @@ const defaultMaxObjectsPerPage = 1000
|
|||
// NewProgressTracker creates a new progress tracker for S3 scanning operations.
|
||||
// The enabled parameter determines if progress tracking is active, and progress
|
||||
// provides the underlying mechanism for persisting scan state.
|
||||
func NewProgressTracker(ctx context.Context, enabled bool, progress *sources.Progress) *ProgressTracker {
|
||||
func NewProgressTracker(_ context.Context, enabled bool, progress *sources.Progress) (*ProgressTracker, error) {
|
||||
if progress == nil {
|
||||
ctx.Logger().Info("Nil progress provided. Progress initialized.")
|
||||
progress = new(sources.Progress)
|
||||
return nil, errors.New("Nil progress provided; progress is required for tracking")
|
||||
}
|
||||
|
||||
return &ProgressTracker{
|
||||
// We are resuming if we have completed objects from a previous scan.
|
||||
isResuming: progress.SectionsCompleted > 0 || progress.SectionsRemaining > 0,
|
||||
completedObjects: make([]bool, defaultMaxObjectsPerPage),
|
||||
completionOrder: make([]int, 0, defaultMaxObjectsPerPage),
|
||||
enabled: enabled,
|
||||
progress: progress,
|
||||
baseCompleted: progress.SectionsCompleted,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Reset prepares the tracker for a new page of objects by clearing the completion state.
|
||||
|
@ -61,8 +55,6 @@ func (p *ProgressTracker) Reset(_ context.Context) {
|
|||
p.Lock()
|
||||
defer p.Unlock()
|
||||
// Store the current completed count before moving to next page.
|
||||
p.baseCompleted = p.progress.SectionsCompleted
|
||||
p.currentPageSize = 0
|
||||
p.completedObjects = make([]bool, defaultMaxObjectsPerPage)
|
||||
p.completionOrder = make([]int, 0, defaultMaxObjectsPerPage)
|
||||
}
|
||||
|
@ -99,10 +91,7 @@ func (p *ProgressTracker) GetResumePoint(ctx context.Context) (ResumeInfo, error
|
|||
return resume, nil
|
||||
}
|
||||
|
||||
return ResumeInfo{
|
||||
CurrentBucket: resumeInfo.CurrentBucket,
|
||||
StartAfter: resumeInfo.StartAfter,
|
||||
}, nil
|
||||
return ResumeInfo{CurrentBucket: resumeInfo.CurrentBucket, StartAfter: resumeInfo.StartAfter}, nil
|
||||
}
|
||||
|
||||
// Complete marks the entire scanning operation as finished and clears the resume state.
|
||||
|
@ -128,8 +117,7 @@ func (p *ProgressTracker) Complete(_ context.Context, message string) error {
|
|||
//
|
||||
// This approach ensures scan reliability by only checkpointing consecutively completed
|
||||
// objects. While this may result in re-scanning some objects when resuming, it guarantees
|
||||
// no objects are missed in case of interruption. The linear search through page contents
|
||||
// is efficient given the fixed maximum page size of 1000 objects.
|
||||
// no objects are missed in case of interruption.
|
||||
func (p *ProgressTracker) UpdateObjectProgress(
|
||||
ctx context.Context,
|
||||
completedIdx int,
|
||||
|
@ -150,16 +138,6 @@ func (p *ProgressTracker) UpdateObjectProgress(
|
|||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
// Update remaining count only once per page.
|
||||
pageSize := int32(len(pageContents))
|
||||
if p.currentPageSize == 0 {
|
||||
if !p.isResuming {
|
||||
// Only add to the total if this is a fresh scan.
|
||||
p.progress.SectionsRemaining += pageSize
|
||||
}
|
||||
p.currentPageSize = pageSize
|
||||
}
|
||||
|
||||
// Only track completion if this is the first time this index is marked complete.
|
||||
if !p.completedObjects[completedIdx] {
|
||||
p.completedObjects[completedIdx] = true
|
||||
|
@ -195,13 +173,12 @@ func (p *ProgressTracker) UpdateObjectProgress(
|
|||
return err
|
||||
}
|
||||
|
||||
// Set the total completed as base plus consecutive completions.
|
||||
completedCount := p.baseCompleted + int32(lastSafeIdx+1)
|
||||
|
||||
// Purposefully avoid updating any progress counts.
|
||||
// Only update resume info.
|
||||
p.progress.SetProgressComplete(
|
||||
int(completedCount),
|
||||
int(p.progress.SectionsCompleted),
|
||||
int(p.progress.SectionsRemaining),
|
||||
fmt.Sprintf("Processing: %s/%s", bucket, *obj.Key),
|
||||
p.progress.Message,
|
||||
string(encoded),
|
||||
)
|
||||
return nil
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
|
||||
|
@ -17,7 +18,8 @@ func TestProgressTrackerResumption(t *testing.T) {
|
|||
|
||||
// First scan - process 6 objects then interrupt.
|
||||
initialProgress := &sources.Progress{}
|
||||
tracker := NewProgressTracker(ctx, true, initialProgress)
|
||||
tracker, err := NewProgressTracker(ctx, true, initialProgress)
|
||||
require.NoError(t, err)
|
||||
|
||||
firstPage := &s3.ListObjectsV2Output{
|
||||
Contents: make([]*s3.Object, 12), // Total of 12 objects
|
||||
|
@ -33,12 +35,16 @@ func TestProgressTrackerResumption(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify state after the intial scan.
|
||||
assert.Equal(t, 6, int(initialProgress.SectionsCompleted), "Should have 6 completed")
|
||||
assert.Equal(t, 12, int(initialProgress.SectionsRemaining), "Should have 12 total")
|
||||
// Verify resume info is set correctly.
|
||||
resumeInfo, err := tracker.GetResumePoint(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-bucket", resumeInfo.CurrentBucket)
|
||||
assert.Equal(t, "key-5", resumeInfo.StartAfter)
|
||||
|
||||
// Resume scan with existing progress.
|
||||
resumeTracker := NewProgressTracker(ctx, true, initialProgress)
|
||||
resumeTracker, err := NewProgressTracker(ctx, true, initialProgress)
|
||||
require.NoError(t, err)
|
||||
|
||||
resumePage := &s3.ListObjectsV2Output{
|
||||
Contents: firstPage.Contents[6:], // Remaining 6 objects
|
||||
}
|
||||
|
@ -49,10 +55,11 @@ func TestProgressTrackerResumption(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, 12, int(initialProgress.SectionsCompleted),
|
||||
"Should have 12 total completed sections")
|
||||
assert.Equal(t, 12, int(initialProgress.SectionsRemaining),
|
||||
"Should have 12 total sections")
|
||||
// Verify final resume info.
|
||||
finalResumeInfo, err := resumeTracker.GetResumePoint(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-bucket", finalResumeInfo.CurrentBucket)
|
||||
assert.Equal(t, "key-11", finalResumeInfo.StartAfter)
|
||||
}
|
||||
|
||||
func TestProgressTrackerReset(t *testing.T) {
|
||||
|
@ -62,7 +69,6 @@ func TestProgressTrackerReset(t *testing.T) {
|
|||
}{
|
||||
{name: "reset with enabled tracker", enabled: true},
|
||||
{name: "reset with disabled tracker", enabled: false},
|
||||
{name: "reset with zero capacity", enabled: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -71,14 +77,14 @@ func TestProgressTrackerReset(t *testing.T) {
|
|||
|
||||
ctx := context.Background()
|
||||
progress := new(sources.Progress)
|
||||
tracker := NewProgressTracker(ctx, tt.enabled, progress)
|
||||
tracker, err := NewProgressTracker(ctx, tt.enabled, progress)
|
||||
require.NoError(t, err)
|
||||
|
||||
tracker.completedObjects[1] = true
|
||||
tracker.completedObjects[2] = true
|
||||
|
||||
tracker.Reset(ctx)
|
||||
|
||||
// Length should always be defaultMaxObjectsPerPage.
|
||||
assert.Equal(t, defaultMaxObjectsPerPage, len(tracker.completedObjects),
|
||||
"Reset changed the length of completed objects")
|
||||
|
||||
|
@ -168,202 +174,37 @@ func TestGetResumePoint(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func setupTestTracker(t *testing.T, enabled bool, progress *sources.Progress, pageSize int) (*ProgressTracker, *s3.ListObjectsV2Output) {
|
||||
t.Helper()
|
||||
|
||||
tracker := &ProgressTracker{
|
||||
enabled: enabled,
|
||||
progress: progress,
|
||||
completedObjects: make([]bool, pageSize),
|
||||
completionOrder: make([]int, 0, pageSize),
|
||||
}
|
||||
page := &s3.ListObjectsV2Output{Contents: make([]*s3.Object, pageSize)}
|
||||
for i := range pageSize {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
page.Contents[i] = &s3.Object{Key: &key}
|
||||
}
|
||||
return tracker, page
|
||||
}
|
||||
|
||||
func TestProgressTrackerUpdateProgressDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
progress := new(sources.Progress)
|
||||
tracker, page := setupTestTracker(t, false, progress, 5)
|
||||
|
||||
err := tracker.UpdateObjectProgress(context.Background(), 1, "test-bucket", page.Contents)
|
||||
assert.NoError(t, err, "Error updating progress when tracker disabled")
|
||||
|
||||
assert.Empty(t, progress.EncodedResumeInfo, "Progress updated when tracker disabled")
|
||||
}
|
||||
|
||||
func TestProgressTrackerUpdateProgressCompletedIdxOOR(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
progress := new(sources.Progress)
|
||||
tracker, page := setupTestTracker(t, true, progress, 5)
|
||||
|
||||
err := tracker.UpdateObjectProgress(context.Background(), 1001, "test-bucket", page.Contents)
|
||||
assert.Error(t, err, "Expected error when completedIdx out of range")
|
||||
|
||||
assert.Empty(t, progress.EncodedResumeInfo, "Progress updated when tracker disabled")
|
||||
}
|
||||
|
||||
func TestProgressTrackerSequence(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
// Each update is a sequence of {completedIdx, expectedCompleted, expectedRemaining}
|
||||
updates [][3]int
|
||||
pageSize int
|
||||
}{
|
||||
{
|
||||
name: "multiple updates same page",
|
||||
description: "Verify remaining count isn't doubled and completed accumulates correctly",
|
||||
pageSize: 5,
|
||||
updates: [][3]int{
|
||||
{0, 1, 5}, // First object - should set remaining to 5
|
||||
{1, 2, 5}, // Second object - remaining should stay 5
|
||||
{2, 3, 5}, // Third object - remaining should stay 5
|
||||
{4, 3, 5}, // Gap at index 3 - completed should stay at 3
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "across page boundaries",
|
||||
description: "Verify completed count accumulates across pages",
|
||||
pageSize: 3,
|
||||
updates: [][3]int{
|
||||
// First page
|
||||
{0, 1, 3},
|
||||
{1, 2, 3},
|
||||
{2, 3, 3},
|
||||
// Reset and start new page.
|
||||
{0, 4, 6}, // baseCompleted(3) + current(1)
|
||||
{1, 5, 6}, // baseCompleted(3) + current(2)
|
||||
{2, 6, 6}, // baseCompleted(3) + current(3)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "incomplete page transition",
|
||||
description: "Verify incomplete page properly sets base completed",
|
||||
pageSize: 4,
|
||||
updates: [][3]int{
|
||||
// First page - only complete first 2.
|
||||
{0, 1, 4},
|
||||
{1, 2, 4},
|
||||
// Skip 2,3 and move to next page.
|
||||
// Reset and start new page.
|
||||
{0, 3, 8}, // baseCompleted(2) + current(1)
|
||||
{1, 4, 8}, // baseCompleted(2) + current(2)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
progress := new(sources.Progress)
|
||||
tracker, page := setupTestTracker(t, true, progress, tt.pageSize)
|
||||
|
||||
pageCount := 0
|
||||
for i, update := range tt.updates {
|
||||
completedIdx, expectedCompleted, expectedRemaining := update[0], update[1], update[2]
|
||||
|
||||
// If this update starts a new page.
|
||||
if completedIdx == 0 && i > 0 {
|
||||
pageCount++
|
||||
tracker.Reset(ctx)
|
||||
// Create new page with same size.
|
||||
page = &s3.ListObjectsV2Output{Contents: make([]*s3.Object, tt.pageSize)}
|
||||
for j := range tt.pageSize {
|
||||
key := fmt.Sprintf("page%d-key-%d", pageCount, j)
|
||||
page.Contents[j] = &s3.Object{Key: &key}
|
||||
}
|
||||
}
|
||||
|
||||
err := tracker.UpdateObjectProgress(ctx, completedIdx, "test-bucket", page.Contents)
|
||||
assert.NoError(t, err, "Unexpected error updating progress")
|
||||
|
||||
assert.Equal(t, expectedCompleted, int(progress.SectionsCompleted),
|
||||
"Incorrect completed count at update %d", i)
|
||||
assert.Equal(t, expectedRemaining, int(progress.SectionsRemaining),
|
||||
"Incorrect remaining count at update %d", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTrackerUpdateProgressWithResume(t *testing.T) {
|
||||
func TestProgressTrackerUpdateProgress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
completedIdx int
|
||||
pageSize int
|
||||
preCompleted []int
|
||||
|
||||
expectedKey string
|
||||
expectedCompleted int
|
||||
expectedRemaining int
|
||||
expectedKey string
|
||||
}{
|
||||
{
|
||||
name: "first object completed",
|
||||
description: "Basic case - completing first object",
|
||||
completedIdx: 0,
|
||||
pageSize: 3,
|
||||
expectedKey: "key-0",
|
||||
expectedCompleted: 1,
|
||||
expectedRemaining: 3,
|
||||
name: "first object completed",
|
||||
description: "Basic case - completing first object",
|
||||
completedIdx: 0,
|
||||
pageSize: 3,
|
||||
expectedKey: "key-0",
|
||||
},
|
||||
{
|
||||
name: "completing missing middle",
|
||||
description: "Completing object when previous is done",
|
||||
completedIdx: 1,
|
||||
pageSize: 3,
|
||||
preCompleted: []int{0},
|
||||
expectedKey: "key-1",
|
||||
expectedCompleted: 2,
|
||||
expectedRemaining: 3,
|
||||
name: "completing missing middle",
|
||||
description: "Completing object when previous is done",
|
||||
completedIdx: 1,
|
||||
pageSize: 3,
|
||||
preCompleted: []int{0},
|
||||
expectedKey: "key-1",
|
||||
},
|
||||
{
|
||||
name: "completing first with last done",
|
||||
description: "Completing first object when last is already done",
|
||||
completedIdx: 0,
|
||||
pageSize: 3,
|
||||
preCompleted: []int{2},
|
||||
expectedKey: "key-0",
|
||||
expectedCompleted: 1,
|
||||
expectedRemaining: 3,
|
||||
},
|
||||
{
|
||||
name: "all objects completed in order",
|
||||
description: "Completing final object in sequence",
|
||||
completedIdx: 2,
|
||||
pageSize: 3,
|
||||
preCompleted: []int{0, 1},
|
||||
expectedKey: "key-2",
|
||||
expectedCompleted: 3,
|
||||
expectedRemaining: 3,
|
||||
},
|
||||
{
|
||||
name: "completing middle gaps",
|
||||
description: "Completing object with gaps in sequence",
|
||||
completedIdx: 5,
|
||||
pageSize: 10,
|
||||
preCompleted: []int{0, 1, 2, 4},
|
||||
expectedKey: "key-2",
|
||||
expectedCompleted: 3,
|
||||
expectedRemaining: 10,
|
||||
},
|
||||
{
|
||||
name: "zero index with empty pre-completed",
|
||||
description: "Edge case - minimum valid index",
|
||||
completedIdx: 0,
|
||||
pageSize: 1,
|
||||
expectedKey: "key-0",
|
||||
expectedCompleted: 1,
|
||||
expectedRemaining: 1,
|
||||
name: "all objects completed in order",
|
||||
description: "Completing final object in sequence",
|
||||
completedIdx: 2,
|
||||
pageSize: 3,
|
||||
preCompleted: []int{0, 1},
|
||||
expectedKey: "key-2",
|
||||
},
|
||||
{
|
||||
name: "last index in max page",
|
||||
|
@ -377,35 +218,7 @@ func TestProgressTrackerUpdateProgressWithResume(t *testing.T) {
|
|||
}
|
||||
return indices
|
||||
}(),
|
||||
expectedKey: "key-999",
|
||||
expectedCompleted: 1000,
|
||||
expectedRemaining: 1000,
|
||||
},
|
||||
{
|
||||
name: "all previous completed",
|
||||
description: "Edge case - all previous indices completed",
|
||||
completedIdx: 100,
|
||||
pageSize: 101,
|
||||
preCompleted: func() []int {
|
||||
indices := make([]int, 100)
|
||||
for i := range 100 {
|
||||
indices[i] = i
|
||||
}
|
||||
return indices
|
||||
}(),
|
||||
expectedKey: "key-100",
|
||||
expectedCompleted: 101,
|
||||
expectedRemaining: 101,
|
||||
},
|
||||
{
|
||||
name: "large page number completion",
|
||||
description: "Edge case - very large page number",
|
||||
completedIdx: 5,
|
||||
pageSize: 10,
|
||||
preCompleted: []int{0, 1, 2, 3, 4},
|
||||
expectedKey: "key-5",
|
||||
expectedCompleted: 6,
|
||||
expectedRemaining: 10,
|
||||
expectedKey: "key-999",
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -415,7 +228,18 @@ func TestProgressTrackerUpdateProgressWithResume(t *testing.T) {
|
|||
|
||||
ctx := context.Background()
|
||||
progress := new(sources.Progress)
|
||||
tracker, page := setupTestTracker(t, true, progress, tt.pageSize)
|
||||
tracker := &ProgressTracker{
|
||||
enabled: true,
|
||||
progress: progress,
|
||||
completedObjects: make([]bool, tt.pageSize),
|
||||
completionOrder: make([]int, 0, tt.pageSize),
|
||||
}
|
||||
|
||||
page := &s3.ListObjectsV2Output{Contents: make([]*s3.Object, tt.pageSize)}
|
||||
for i := range tt.pageSize {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
page.Contents[i] = &s3.Object{Key: &key}
|
||||
}
|
||||
|
||||
// Apply pre-completed indices in order.
|
||||
if tt.preCompleted != nil {
|
||||
|
@ -428,95 +252,10 @@ func TestProgressTrackerUpdateProgressWithResume(t *testing.T) {
|
|||
err := tracker.UpdateObjectProgress(ctx, tt.completedIdx, "test-bucket", page.Contents)
|
||||
assert.NoError(t, err, "Unexpected error updating progress")
|
||||
|
||||
assert.NotEmpty(t, progress.EncodedResumeInfo, "Expected progress update")
|
||||
var info ResumeInfo
|
||||
err = json.Unmarshal([]byte(progress.EncodedResumeInfo), &info)
|
||||
assert.NoError(t, err, "Failed to decode resume info")
|
||||
assert.Equal(t, tt.expectedKey, info.StartAfter, "Incorrect resume point")
|
||||
|
||||
assert.Equal(t, tt.expectedCompleted, int(progress.SectionsCompleted),
|
||||
"Incorrect completed count")
|
||||
assert.Equal(t, tt.expectedRemaining, int(progress.SectionsRemaining),
|
||||
"Incorrect remaining count")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgressTrackerUpdateProgressNoResume(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
completedIdx int
|
||||
pageSize int
|
||||
preCompleted map[int]bool
|
||||
}{
|
||||
{
|
||||
name: "middle object completed first",
|
||||
description: "Basic case - completing middle object first",
|
||||
completedIdx: 1,
|
||||
pageSize: 3,
|
||||
},
|
||||
{
|
||||
name: "last object completed first",
|
||||
description: "Basic case - completing last object first",
|
||||
completedIdx: 2,
|
||||
pageSize: 3,
|
||||
},
|
||||
{
|
||||
name: "multiple gaps",
|
||||
description: "Multiple non-consecutive completions",
|
||||
completedIdx: 5,
|
||||
pageSize: 10,
|
||||
preCompleted: map[int]bool{1: true, 3: true, 4: true},
|
||||
},
|
||||
{
|
||||
name: "alternating completion pattern",
|
||||
description: "Edge case - alternating completed/uncompleted pattern",
|
||||
completedIdx: 10,
|
||||
pageSize: 20,
|
||||
preCompleted: map[int]bool{2: true, 4: true, 6: true, 8: true},
|
||||
},
|
||||
{
|
||||
name: "sparse completion pattern",
|
||||
description: "Edge case - scattered completions with regular gaps",
|
||||
completedIdx: 50,
|
||||
pageSize: 100,
|
||||
preCompleted: map[int]bool{10: true, 20: true, 30: true, 40: true},
|
||||
},
|
||||
{
|
||||
name: "single gap breaks sequence",
|
||||
description: "Edge case - single gap prevents resume info",
|
||||
completedIdx: 50,
|
||||
pageSize: 100,
|
||||
preCompleted: func() map[int]bool {
|
||||
m := make(map[int]bool)
|
||||
for i := 1; i <= 49; i++ {
|
||||
m[i] = true
|
||||
}
|
||||
m[49] = false
|
||||
return m
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
progress := new(sources.Progress)
|
||||
enabled := tt.name != "disabled tracker"
|
||||
tracker, page := setupTestTracker(t, enabled, progress, tt.pageSize)
|
||||
|
||||
if tt.preCompleted != nil {
|
||||
for k, v := range tt.preCompleted {
|
||||
tracker.completedObjects[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
err := tracker.UpdateObjectProgress(ctx, tt.completedIdx, "test-bucket", page.Contents)
|
||||
assert.NoError(t, err, "Unexpected error updating progress")
|
||||
assert.Empty(t, progress.EncodedResumeInfo, "Expected no progress update")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -526,98 +265,51 @@ func TestComplete(t *testing.T) {
|
|||
name string
|
||||
enabled bool
|
||||
initialState struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}
|
||||
completeMessage string
|
||||
wantState struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}
|
||||
}{
|
||||
{
|
||||
name: "marks completion with existing progress",
|
||||
name: "marks completion with existing resume info",
|
||||
enabled: true,
|
||||
initialState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 100,
|
||||
sectionsRemaining: 100,
|
||||
resumeInfo: `{"CurrentBucket":"test-bucket","StartAfter":"some-key"}`,
|
||||
message: "In progress",
|
||||
resumeInfo: `{"current_bucket":"test-bucket","start_after":"some-key"}`,
|
||||
message: "In progress",
|
||||
},
|
||||
completeMessage: "Scan complete",
|
||||
wantState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 100, // Should preserve existing progress
|
||||
sectionsRemaining: 100, // Should preserve existing progress
|
||||
resumeInfo: "", // Should clear resume info
|
||||
message: "Scan complete",
|
||||
resumeInfo: "", // Should clear resume info
|
||||
message: "Scan complete",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "disabled tracker",
|
||||
enabled: false,
|
||||
initialState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 50,
|
||||
sectionsRemaining: 100,
|
||||
resumeInfo: "",
|
||||
message: "Should not change",
|
||||
resumeInfo: "",
|
||||
message: "Should not change",
|
||||
},
|
||||
completeMessage: "Completed",
|
||||
wantState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 50,
|
||||
sectionsRemaining: 100,
|
||||
resumeInfo: "",
|
||||
message: "Completed",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completes with special characters",
|
||||
enabled: true,
|
||||
initialState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 75,
|
||||
sectionsRemaining: 75,
|
||||
resumeInfo: `{"CurrentBucket":"bucket","StartAfter":"key"}`,
|
||||
message: "In progress",
|
||||
},
|
||||
completeMessage: "Completed scanning 特殊字符 & symbols !@#$%",
|
||||
wantState: struct {
|
||||
sectionsCompleted uint64
|
||||
sectionsRemaining uint64
|
||||
resumeInfo string
|
||||
message string
|
||||
}{
|
||||
sectionsCompleted: 75,
|
||||
sectionsRemaining: 75,
|
||||
resumeInfo: "",
|
||||
message: "Completed scanning 特殊字符 & symbols !@#$%",
|
||||
resumeInfo: "",
|
||||
message: "Completed",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -628,18 +320,15 @@ func TestComplete(t *testing.T) {
|
|||
|
||||
ctx := context.Background()
|
||||
progress := &sources.Progress{
|
||||
SectionsCompleted: int32(tt.initialState.sectionsCompleted),
|
||||
SectionsRemaining: int32(tt.initialState.sectionsRemaining),
|
||||
EncodedResumeInfo: tt.initialState.resumeInfo,
|
||||
Message: tt.initialState.message,
|
||||
}
|
||||
tracker := NewProgressTracker(ctx, tt.enabled, progress)
|
||||
|
||||
err := tracker.Complete(ctx, tt.completeMessage)
|
||||
tracker, err := NewProgressTracker(ctx, tt.enabled, progress)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = tracker.Complete(ctx, tt.completeMessage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int32(tt.wantState.sectionsCompleted), progress.SectionsCompleted)
|
||||
assert.Equal(t, int32(tt.wantState.sectionsRemaining), progress.SectionsRemaining)
|
||||
assert.Equal(t, tt.wantState.resumeInfo, progress.EncodedResumeInfo)
|
||||
assert.Equal(t, tt.wantState.message, progress.Message)
|
||||
})
|
||||
|
|
Loading…
Reference in a new issue