Use persistable cache for GCS progress tracking (#1204)

* Add in-memory caching lib, used by the GCS source.

* Use cache for tracking progress for the GCS source.

* fix merge issue.

* fix merge issue.

* fix test.

* Fix static check.

* Add test for NewWithData.

* Use cache for tracking progress for the GCS source.

* fix merge issue.

* fix merge issue.

* fix test.

* update comment.

* update comments.

* Use cache for tracking progress for the GCS source.

* fix merge issue.

* fix merge issue.

* fix test.

* remove unused dep.

* address comments.

* Add exists method.

* Use cache for tracking progress for the GCS source.

* fix merge issue.

* fix merge issue.

* fix test.

* rebase.

* fix test.

* Use cache for tracking progress for the GCS source.

* fix merge issue.

* fix merge issue.

* fix test.

* rebase.

* rebase.

* split encode resume by comma.

* Use a persistable cache.

* fix merge.

* fix merge.

* Add progress as part of the cache given it will be the persistence layer.

* Add test for making sure the cache doesn't persist when the increment value is not met.

* fix tests.
ahrav 2023-04-10 07:55:00 -07:00 committed by GitHub
parent f107e1b497
commit c451f9daf8
2 changed files with 189 additions and 6 deletions
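
The diff below calls into the pkg/cache interface and its in-memory implementation without showing their definitions. What follows is a minimal sketch of the surface the GCS source relies on, inferred only from the call sites in this commit (memory.New, memory.NewWithData, Set, Exists, Count, Contents); the real package may expose more methods and different signatures (for example, the real constructors also take a context argument).

package cache

import "strings"

// Cache is the subset of the cache interface the GCS source relies on,
// inferred from the call sites in this commit; the real interface in
// pkg/cache may be larger.
type Cache interface {
	Set(key, value string)
	Exists(key string) bool
	Count() int
	// Contents returns every key joined by commas, which is the string
	// stored in Progress.EncodedResumeInfo.
	Contents() string
}

// memoryCache is a hypothetical map-backed implementation standing in for
// pkg/cache/memory.
type memoryCache struct{ data map[string]string }

// Ensure memoryCache satisfies the interface at compile time.
var _ Cache = (*memoryCache)(nil)

func New() *memoryCache { return &memoryCache{data: make(map[string]string)} }

// NewWithData seeds the cache from previously persisted keys, e.g.
// strings.Split(Progress.EncodedResumeInfo, ",").
func NewWithData(keys []string) *memoryCache {
	c := New()
	for _, k := range keys {
		c.data[k] = k
	}
	return c
}

func (c *memoryCache) Set(key, value string)  { c.data[key] = value }
func (c *memoryCache) Exists(key string) bool { _, ok := c.data[key]; return ok }
func (c *memoryCache) Count() int             { return len(c.data) }

func (c *memoryCache) Contents() string {
	keys := make([]string, 0, len(c.data))
	for k := range c.data {
		keys = append(keys, k)
	}
	return strings.Join(keys, ",")
}

With a shape like this, the persistableCache introduced below only needs to wrap Set and consult Count and Contents to decide when to copy the cache into Progress.EncodedResumeInfo.
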


@@ -6,6 +6,7 @@ import (
"net/http"
"os"
"strconv"
"strings"
"sync"
"cloud.google.com/go/storage"
@@ -17,6 +18,8 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/cache"
"github.com/trufflesecurity/trufflehog/v3/pkg/cache/memory"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
@@ -25,6 +28,8 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)
const defaultCachePersistIncrement = 2500
// Ensure the Source satisfies the interface at compile time.
var _ sources.Source = (*Source)(nil)
@@ -57,7 +62,41 @@ type Source struct {
log logr.Logger
chunksCh chan *sources.Chunk
sources.Progress
mu sync.Mutex
sources.Progress // Progress is not thread-safe.
}
// persistableCache is a wrapper around cache.Cache that allows
// for the persistence of the cache contents in the Progress of the source
// at given increments.
type persistableCache struct {
persistIncrement int
cache.Cache
*sources.Progress
}
func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress) *persistableCache {
return &persistableCache{
persistIncrement: increment,
Cache: cache,
Progress: p,
}
}
// Set overrides the cache's Set method to enable persisting the cache contents
// in the Progress of the source at the configured increments.
func (c *persistableCache) Set(key, val string) {
c.Cache.Set(key, val)
if ok, contents := c.shouldPersist(); ok {
c.Progress.EncodedResumeInfo = contents
}
}
func (c *persistableCache) shouldPersist() (bool, string) {
if c.Count()%c.persistIncrement != 0 {
return false, ""
}
return true, c.Contents()
}
// Init returns an initialized GCS source.
@@ -190,6 +229,7 @@ func setGCSManagerOptions(include, exclude []string, includeFn, excludeFn func([
}
// enumerate all the objects and buckets in the source.
// This will be used to calculate progress.
func (s *Source) enumerate(ctx context.Context) error {
stats, err := s.gcsManager.attributes(ctx)
if err != nil {
@@ -202,11 +242,14 @@ func (s *Source) enumerate(ctx context.Context) error {
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
persistableCache := s.setupCache(ctx)
objectCh, err := s.gcsManager.listObjects(ctx)
if err != nil {
return fmt.Errorf("error listing objects: %w", err)
}
s.chunksCh = chunksChan
s.Progress.Message = "starting to process objects..."
var wg sync.WaitGroup
for obj := range objectCh {
@@ -217,6 +260,11 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
continue
}
if persistableCache.Exists(o.name) {
ctx.Logger().V(5).Info("skipping object, object already processed", "name", o.name)
continue
}
wg.Add(1)
go func(obj object) {
defer wg.Done()
@@ -225,13 +273,46 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
ctx.Logger().V(1).Info("error setting start progress progress", "name", o.name, "error", err)
return
}
s.setProgress(ctx, o.name, persistableCache)
}(o)
}
wg.Wait()
s.completeProgress(ctx)
return nil
}
func (s *Source) setupCache(ctx context.Context) *persistableCache {
var c cache.Cache
if s.Progress.EncodedResumeInfo != "" {
c = memory.NewWithData(ctx, strings.Split(s.Progress.EncodedResumeInfo, ","))
} else {
c = memory.New()
}
// TODO (ahrav): Make this configurable via conn.
persistCache := newPersistableCache(defaultCachePersistIncrement, c, &s.Progress)
return persistCache
}
func (s *Source) setProgress(ctx context.Context, objName string, cache cache.Cache) {
s.mu.Lock()
defer s.mu.Unlock()
ctx.Logger().V(5).Info("setting progress for object", "object-name", objName)
s.SectionsCompleted++
cache.Set(objName, objName)
s.Progress.SectionsRemaining = int32(s.stats.numObjects)
s.Progress.PercentComplete = int64(float64(s.SectionsCompleted) / float64(s.stats.numObjects) * 100)
}
func (s *Source) completeProgress(ctx context.Context) {
msg := fmt.Sprintf("GCS source finished processing %d objects", s.stats.numObjects)
ctx.Logger().Info(msg)
s.Progress.Message = msg
}
func (s *Source) processObject(ctx context.Context, o object) error {
chunkSkel := &sources.Chunk{
SourceName: s.name,


@@ -5,6 +5,8 @@ import (
"io"
"net/http"
"sort"
"strconv"
"strings"
"testing"
"time"
@@ -203,7 +205,9 @@ func TestSourceOauth2Client(t *testing.T) {
}
type mockObjectManager struct {
wantErr bool
// numObjects is the number of objects to return in the listObjects call.
numObjects int
wantErr bool
}
func (m *mockObjectManager) attributes(_ context.Context) (*attributes, error) {
@@ -212,7 +216,7 @@ func (m *mockObjectManager) attributes(_ context.Context) (*attributes, error) {
}
return &attributes{
numObjects: 5,
numObjects: uint64(m.numObjects),
numBuckets: 1,
bucketObjects: map[string]uint64{testBucket: 5},
}, nil
@@ -242,7 +246,7 @@ func (m *mockObjectManager) listObjects(context.Context) (chan io.Reader, error)
go func() {
defer close(ch)
// Add the configured number of objects to the channel.
for i := 0; i < 5; i++ {
for i := 0; i < m.numObjects; i++ {
ch <- createTestObject(i)
}
}()
@@ -292,8 +296,9 @@ func TestSourceChunks_ListObjects(t *testing.T) {
chunksCh := make(chan *sources.Chunk, 1)
source := &Source{
gcsManager: &mockObjectManager{},
gcsManager: &mockObjectManager{numObjects: 5},
chunksCh: chunksCh,
Progress: sources.Progress{},
}
err := source.enumerate(ctx)
@@ -338,7 +343,7 @@ func TestSourceChunks_ListObjects(t *testing.T) {
func TestSourceInit_Enumerate(t *testing.T) {
ctx := context.Background()
source := &Source{gcsManager: &mockObjectManager{}}
source := &Source{gcsManager: &mockObjectManager{numObjects: 5}}
err := source.enumerate(ctx)
assert.Nil(t, err)
@@ -359,3 +364,100 @@ func TestSourceChunks_ListObjects_Error(t *testing.T) {
err := source.Chunks(ctx, chunksCh)
assert.True(t, err != nil)
}
func TestSourceChunks_ProgressSet(t *testing.T) {
ctx := context.Background()
chunksCh := make(chan *sources.Chunk, 1)
source := &Source{
gcsManager: &mockObjectManager{numObjects: defaultCachePersistIncrement},
chunksCh: chunksCh,
Progress: sources.Progress{},
}
err := source.enumerate(ctx)
assert.Nil(t, err)
go func() {
defer close(chunksCh)
err := source.Chunks(ctx, chunksCh)
assert.Nil(t, err)
}()
want := make([]*sources.Chunk, 0, defaultCachePersistIncrement)
for i := 0; i < defaultCachePersistIncrement; i++ {
want = append(want, createTestSourceChunk(i))
}
got := make([]*sources.Chunk, 0, defaultCachePersistIncrement)
for ch := range chunksCh {
got = append(got, ch)
}
// Ensure we get 2500 objects back.
assert.Equal(t, len(want), len(got))
// Test that the resume progress is set.
var progress strings.Builder
for i := range got {
progress.WriteString(fmt.Sprintf("object%d", i))
// Add a comma if not the last element.
if i != len(got)-1 {
progress.WriteString(",")
}
}
encodeResume := strings.Split(source.Progress.EncodedResumeInfo, ",")
sort.Slice(encodeResume, func(i, j int) bool {
numI, _ := strconv.Atoi(strings.TrimPrefix(encodeResume[i], "object"))
numJ, _ := strconv.Atoi(strings.TrimPrefix(encodeResume[j], "object"))
return numI < numJ
})
assert.Equal(t, progress.String(), strings.Join(encodeResume, ","))
assert.Equal(t, int32(defaultCachePersistIncrement), source.Progress.SectionsCompleted)
assert.Equal(t, int64(100), source.Progress.PercentComplete)
assert.Equal(t, fmt.Sprintf("GCS source finished processing %d objects", defaultCachePersistIncrement), source.Progress.Message)
}
func TestSource_CachePersistence(t *testing.T) {
ctx := context.Background()
wantObjCnt := 4 // ensure we have fewer objects than the cache persist increment
mockObjManager := &mockObjectManager{numObjects: wantObjCnt}
chunksCh := make(chan *sources.Chunk, 1)
source := &Source{
gcsManager: mockObjManager,
chunksCh: chunksCh,
Progress: sources.Progress{},
}
err := source.enumerate(ctx)
assert.Nil(t, err)
go func() {
defer close(chunksCh)
err := source.Chunks(ctx, chunksCh)
assert.Nil(t, err)
}()
want := make([]*sources.Chunk, 0, wantObjCnt)
for i := 0; i < wantObjCnt; i++ {
want = append(want, createTestSourceChunk(i))
}
got := make([]*sources.Chunk, 0, wantObjCnt)
for ch := range chunksCh {
got = append(got, ch)
}
// Ensure we get 4 objects back.
assert.Equal(t, len(want), len(got))
// Test that the resume progress is empty.
// The cache should not have been persisted.
assert.Equal(t, "", source.Progress.EncodedResumeInfo)
assert.Equal(t, int32(wantObjCnt), source.Progress.SectionsCompleted)
assert.Equal(t, int64(100), source.Progress.PercentComplete)
assert.Equal(t, fmt.Sprintf("GCS source finished processing %d objects", wantObjCnt), source.Progress.Message)
}
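
To tie the pieces together, here is a standalone sketch (not part of the commit) of the resume round-trip the tests above exercise: processed object names are joined by commas into Progress.EncodedResumeInfo, and a restarted scan splits that string to rebuild the cache and skip objects it has already handled. The object names and counts are purely illustrative.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Hypothetical resume state persisted by an earlier, interrupted run:
	// the cache Contents() joined with commas, as stored in
	// Progress.EncodedResumeInfo.
	encodedResumeInfo := "object0,object1,object2"

	// On restart, setupCache splits the encoded string to rebuild the set
	// of already-processed objects (memory.NewWithData in the diff above).
	processed := make(map[string]struct{})
	for _, name := range strings.Split(encodedResumeInfo, ",") {
		processed[name] = struct{}{}
	}

	// Objects returned by listObjects; the first three are skipped via the
	// Exists check in Chunks, so only the remaining ones are chunked.
	for _, name := range []string{"object0", "object1", "object2", "object3", "object4"} {
		if _, ok := processed[name]; ok {
			fmt.Printf("skipping %s: already processed\n", name)
			continue
		}
		fmt.Printf("processing %s\n", name)
	}
}
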