mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
Use persistable cache for GCS progress tracking (#1204)
* Add in-memory caching lib, used by the GCS source. * Use cache for tracking progress for the GCS source. * fix merge issue. * fix merge issue. * fix test. * Fix static check. * Add test for NewWithData. * Use cache for tracking progress for the GCS source. * fix merge issue. * fix merge issue. * fix test. * update comment. * update comments. * Use cache for tracking progress for the GCS source. * fix merge issue. * fix merge issue. * fix test. * remove unused dep. * address comments. * Add exists method. * Use cache for tracking progress for the GCS source. * fix merge issue. * fix merge issue. * fix test. * rebase. * fix test. * Use cache for tracking progress for the GCS source. * fix merge issue. * fix merge issue. * fix test. * rebase. * rebase. * split encode resume by comma. * Use a persistable cache. * fix merge. * fix merge. * Add progress as part of the cache given it will be the persistence layer. * Add test for making sure the cache doesn't persist when the increment value is not met. * fix tests.
This commit is contained in:
parent
f107e1b497
commit
c451f9daf8
2 changed files with 189 additions and 6 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"cloud.google.com/go/storage"
|
||||
|
@ -17,6 +18,8 @@ import (
|
|||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/cache"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/cache/memory"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
|
||||
|
@ -25,6 +28,8 @@ import (
|
|||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
|
||||
)
|
||||
|
||||
const defaultCachePersistIncrement = 2500
|
||||
|
||||
// Ensure the Source satisfies the interface at compile time.
|
||||
var _ sources.Source = (*Source)(nil)
|
||||
|
||||
|
@ -57,7 +62,41 @@ type Source struct {
|
|||
log logr.Logger
|
||||
chunksCh chan *sources.Chunk
|
||||
|
||||
sources.Progress
|
||||
mu sync.Mutex
|
||||
sources.Progress // progress is not thread safe
|
||||
}
|
||||
|
||||
// persistableCache is a wrapper around cache.Cache that allows
|
||||
// for the persistence of the cache contents in the Progress of the source
|
||||
// at given increments.
|
||||
type persistableCache struct {
|
||||
persistIncrement int
|
||||
cache.Cache
|
||||
*sources.Progress
|
||||
}
|
||||
|
||||
func newPersistableCache(increment int, cache cache.Cache, p *sources.Progress) *persistableCache {
|
||||
return &persistableCache{
|
||||
persistIncrement: increment,
|
||||
Cache: cache,
|
||||
Progress: p,
|
||||
}
|
||||
}
|
||||
|
||||
// Set overrides the cache Set method of the cache to enable the persistence
|
||||
// of the cache contents the Progress of the source at given increments.
|
||||
func (c *persistableCache) Set(key, val string) {
|
||||
c.Cache.Set(key, val)
|
||||
if ok, contents := c.shouldPersist(); ok {
|
||||
c.Progress.EncodedResumeInfo = contents
|
||||
}
|
||||
}
|
||||
|
||||
func (c *persistableCache) shouldPersist() (bool, string) {
|
||||
if c.Count()%c.persistIncrement != 0 {
|
||||
return false, ""
|
||||
}
|
||||
return true, c.Contents()
|
||||
}
|
||||
|
||||
// Init returns an initialized GCS source.
|
||||
|
@ -190,6 +229,7 @@ func setGCSManagerOptions(include, exclude []string, includeFn, excludeFn func([
|
|||
}
|
||||
|
||||
// enumerate all the objects and buckets in the source.
|
||||
// This will be used to calculate progress.
|
||||
func (s *Source) enumerate(ctx context.Context) error {
|
||||
stats, err := s.gcsManager.attributes(ctx)
|
||||
if err != nil {
|
||||
|
@ -202,11 +242,14 @@ func (s *Source) enumerate(ctx context.Context) error {
|
|||
|
||||
// Chunks emits chunks of bytes over a channel.
|
||||
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
|
||||
persistableCache := s.setupCache(ctx)
|
||||
|
||||
objectCh, err := s.gcsManager.listObjects(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error listing objects: %w", err)
|
||||
}
|
||||
s.chunksCh = chunksChan
|
||||
s.Progress.Message = "starting to process objects..."
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for obj := range objectCh {
|
||||
|
@ -217,6 +260,11 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
|
|||
continue
|
||||
}
|
||||
|
||||
if persistableCache.Exists(o.name) {
|
||||
ctx.Logger().V(5).Info("skipping object, object already processed", "name", o.name)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(obj object) {
|
||||
defer wg.Done()
|
||||
|
@ -225,13 +273,46 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
|
|||
ctx.Logger().V(1).Info("error setting start progress progress", "name", o.name, "error", err)
|
||||
return
|
||||
}
|
||||
s.setProgress(ctx, o.name, persistableCache)
|
||||
}(o)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
s.completeProgress(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) setupCache(ctx context.Context) *persistableCache {
|
||||
var c cache.Cache
|
||||
if s.Progress.EncodedResumeInfo != "" {
|
||||
c = memory.NewWithData(ctx, strings.Split(s.Progress.EncodedResumeInfo, ","))
|
||||
} else {
|
||||
c = memory.New()
|
||||
}
|
||||
|
||||
// TODO (ahrav): Make this configurable via conn.
|
||||
persistCache := newPersistableCache(defaultCachePersistIncrement, c, &s.Progress)
|
||||
return persistCache
|
||||
}
|
||||
|
||||
func (s *Source) setProgress(ctx context.Context, objName string, cache cache.Cache) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
ctx.Logger().V(5).Info("setting progress for object", "object-name", objName)
|
||||
s.SectionsCompleted++
|
||||
|
||||
cache.Set(objName, objName)
|
||||
s.Progress.SectionsRemaining = int32(s.stats.numObjects)
|
||||
s.Progress.PercentComplete = int64(float64(s.SectionsCompleted) / float64(s.stats.numObjects) * 100)
|
||||
}
|
||||
|
||||
func (s *Source) completeProgress(ctx context.Context) {
|
||||
msg := fmt.Sprintf("GCS source finished processing %d objects", s.stats.numObjects)
|
||||
ctx.Logger().Info(msg)
|
||||
s.Progress.Message = msg
|
||||
}
|
||||
|
||||
func (s *Source) processObject(ctx context.Context, o object) error {
|
||||
chunkSkel := &sources.Chunk{
|
||||
SourceName: s.name,
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -203,7 +205,9 @@ func TestSourceOauth2Client(t *testing.T) {
|
|||
}
|
||||
|
||||
type mockObjectManager struct {
|
||||
wantErr bool
|
||||
// numObjects is the number of objects to return in the listObjects call.
|
||||
numObjects int
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
func (m *mockObjectManager) attributes(_ context.Context) (*attributes, error) {
|
||||
|
@ -212,7 +216,7 @@ func (m *mockObjectManager) attributes(_ context.Context) (*attributes, error) {
|
|||
}
|
||||
|
||||
return &attributes{
|
||||
numObjects: 5,
|
||||
numObjects: uint64(m.numObjects),
|
||||
numBuckets: 1,
|
||||
bucketObjects: map[string]uint64{testBucket: 5},
|
||||
}, nil
|
||||
|
@ -242,7 +246,7 @@ func (m *mockObjectManager) listObjects(context.Context) (chan io.Reader, error)
|
|||
go func() {
|
||||
defer close(ch)
|
||||
// Add 5 objects to the channel.
|
||||
for i := 0; i < 5; i++ {
|
||||
for i := 0; i < m.numObjects; i++ {
|
||||
ch <- createTestObject(i)
|
||||
}
|
||||
}()
|
||||
|
@ -292,8 +296,9 @@ func TestSourceChunks_ListObjects(t *testing.T) {
|
|||
chunksCh := make(chan *sources.Chunk, 1)
|
||||
|
||||
source := &Source{
|
||||
gcsManager: &mockObjectManager{},
|
||||
gcsManager: &mockObjectManager{numObjects: 5},
|
||||
chunksCh: chunksCh,
|
||||
Progress: sources.Progress{},
|
||||
}
|
||||
|
||||
err := source.enumerate(ctx)
|
||||
|
@ -338,7 +343,7 @@ func TestSourceChunks_ListObjects(t *testing.T) {
|
|||
|
||||
func TestSourceInit_Enumerate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
source := &Source{gcsManager: &mockObjectManager{}}
|
||||
source := &Source{gcsManager: &mockObjectManager{numObjects: 5}}
|
||||
|
||||
err := source.enumerate(ctx)
|
||||
assert.Nil(t, err)
|
||||
|
@ -359,3 +364,100 @@ func TestSourceChunks_ListObjects_Error(t *testing.T) {
|
|||
err := source.Chunks(ctx, chunksCh)
|
||||
assert.True(t, err != nil)
|
||||
}
|
||||
|
||||
func TestSourceChunks_ProgressSet(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
chunksCh := make(chan *sources.Chunk, 1)
|
||||
source := &Source{
|
||||
gcsManager: &mockObjectManager{numObjects: defaultCachePersistIncrement},
|
||||
chunksCh: chunksCh,
|
||||
Progress: sources.Progress{},
|
||||
}
|
||||
|
||||
err := source.enumerate(ctx)
|
||||
assert.Nil(t, err)
|
||||
|
||||
go func() {
|
||||
defer close(chunksCh)
|
||||
err := source.Chunks(ctx, chunksCh)
|
||||
assert.Nil(t, err)
|
||||
}()
|
||||
|
||||
want := make([]*sources.Chunk, 0, defaultCachePersistIncrement)
|
||||
for i := 0; i < defaultCachePersistIncrement; i++ {
|
||||
want = append(want, createTestSourceChunk(i))
|
||||
}
|
||||
|
||||
got := make([]*sources.Chunk, 0, defaultCachePersistIncrement)
|
||||
for ch := range chunksCh {
|
||||
got = append(got, ch)
|
||||
}
|
||||
|
||||
// Ensure we get 2500 objects back.
|
||||
assert.Equal(t, len(want), len(got))
|
||||
|
||||
// Test that the resume progress is set.
|
||||
var progress strings.Builder
|
||||
for i := range got {
|
||||
progress.WriteString(fmt.Sprintf("object%d", i))
|
||||
// Add a comma if not the last element.
|
||||
if i != len(got)-1 {
|
||||
progress.WriteString(",")
|
||||
}
|
||||
}
|
||||
|
||||
encodeResume := strings.Split(source.Progress.EncodedResumeInfo, ",")
|
||||
sort.Slice(encodeResume, func(i, j int) bool {
|
||||
numI, _ := strconv.Atoi(strings.TrimPrefix(encodeResume[i], "object"))
|
||||
numJ, _ := strconv.Atoi(strings.TrimPrefix(encodeResume[j], "object"))
|
||||
return numI < numJ
|
||||
})
|
||||
|
||||
assert.Equal(t, progress.String(), strings.Join(encodeResume, ","))
|
||||
assert.Equal(t, int32(defaultCachePersistIncrement), source.Progress.SectionsCompleted)
|
||||
assert.Equal(t, int64(100), source.Progress.PercentComplete)
|
||||
assert.Equal(t, fmt.Sprintf("GCS source finished processing %d objects", defaultCachePersistIncrement), source.Progress.Message)
|
||||
}
|
||||
|
||||
func TestSource_CachePersistence(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
wantObjCnt := 4 // ensure we have less objects than the cache increment
|
||||
mockObjManager := &mockObjectManager{numObjects: wantObjCnt}
|
||||
|
||||
chunksCh := make(chan *sources.Chunk, 1)
|
||||
source := &Source{
|
||||
gcsManager: mockObjManager,
|
||||
chunksCh: chunksCh,
|
||||
Progress: sources.Progress{},
|
||||
}
|
||||
|
||||
err := source.enumerate(ctx)
|
||||
assert.Nil(t, err)
|
||||
|
||||
go func() {
|
||||
defer close(chunksCh)
|
||||
err := source.Chunks(ctx, chunksCh)
|
||||
assert.Nil(t, err)
|
||||
}()
|
||||
|
||||
want := make([]*sources.Chunk, 0, wantObjCnt)
|
||||
for i := 0; i < wantObjCnt; i++ {
|
||||
want = append(want, createTestSourceChunk(i))
|
||||
}
|
||||
|
||||
got := make([]*sources.Chunk, 0, wantObjCnt)
|
||||
for ch := range chunksCh {
|
||||
got = append(got, ch)
|
||||
}
|
||||
|
||||
// Ensure we get 4 objects back.
|
||||
assert.Equal(t, len(want), len(got))
|
||||
|
||||
// Test that the resume progress is empty.
|
||||
// The cache should not have been persisted.
|
||||
assert.Equal(t, "", source.Progress.EncodedResumeInfo)
|
||||
assert.Equal(t, int32(wantObjCnt), source.Progress.SectionsCompleted)
|
||||
assert.Equal(t, int64(100), source.Progress.PercentComplete)
|
||||
assert.Equal(t, fmt.Sprintf("GCS source finished processing %d objects", wantObjCnt), source.Progress.Message)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue