Add optional param to Chunks (#1747)

* Add interface for targeted chunking.

* use optional args.

* update Chunks method signature.

* update tests.

* fix test.

* update QueryCriteria type.
This commit is contained in:
ahrav 2023-09-07 09:03:37 -07:00 committed by GitHub
parent f6512ac4ca
commit 2a9f34962d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 36 additions and 19 deletions

View file

@ -75,7 +75,7 @@ func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, ver
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
projects, err := s.projects(ctx)
if err != nil {
return fmt.Errorf("error getting projects: %w", err)

View file

@ -85,7 +85,7 @@ type layerInfo struct {
}
// Chunks emits data over a channel that is decoded and scanned for secrets.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
ctx = context.WithValues(ctx, "source_type", s.Type(), "source_name", s.name)
workers := new(errgroup.Group)

View file

@ -77,7 +77,7 @@ func (s *Source) WithFilter(filter *common.Filter) {
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
for i, path := range s.paths {
logger := ctx.Logger().WithValues("path", path)
if common.IsDone(ctx) {

View file

@ -248,7 +248,7 @@ func (s *Source) enumerate(ctx context.Context) error {
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
persistableCache := s.setupCache(ctx)
objectCh, err := s.gcsManager.ListObjects(ctx)

View file

@ -153,7 +153,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64,
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
if err := s.scanRepos(ctx, chunksChan); err != nil {
return err
}

View file

@ -413,7 +413,7 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
apiEndpoint := s.conn.Endpoint
if len(apiEndpoint) == 0 || endsWithGithub.MatchString(apiEndpoint) {
apiEndpoint = "https://api.github.com"

View file

@ -139,7 +139,7 @@ func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, ver
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
// Start client.
apiClient, err := s.newClient()
if err != nil {

View file

@ -232,7 +232,7 @@ func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bu
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
}

View file

@ -6,10 +6,11 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"google.golang.org/protobuf/types/known/anypb"
)
// DummySource implements Source and is used for testing a SourceManager.
@ -31,7 +32,7 @@ func (d *DummySource) GetProgress() *Progress { return nil }
// Interface to easily test different chunking methods.
type chunker interface {
Chunks(context.Context, chan *Chunk) error
Chunks(context.Context, chan *Chunk, ...ChunkingTarget) error
ChunkUnit(ctx context.Context, unit SourceUnit, reporter ChunkReporter) error
Enumerate(ctx context.Context, reporter UnitReporter) error
}
@ -42,7 +43,7 @@ type counterChunker struct {
count int
}
func (c *counterChunker) Chunks(ctx context.Context, ch chan *Chunk) error {
func (c *counterChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
for i := 0; i < c.count; i++ {
select {
case ch <- &Chunk{Data: []byte{c.chunkCounter}}:
@ -75,9 +76,9 @@ func (c *counterChunker) ChunkUnit(ctx context.Context, unit SourceUnit, reporte
// Chunk method that always returns an error.
type errorChunker struct{ error }
func (c errorChunker) Chunks(context.Context, chan *Chunk) error { return c }
func (c errorChunker) Enumerate(context.Context, UnitReporter) error { return c }
func (c errorChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return c }
func (c errorChunker) Chunks(context.Context, chan *Chunk, ...ChunkingTarget) error { return c }
func (c errorChunker) Enumerate(context.Context, UnitReporter) error { return c }
func (c errorChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return c }
// enrollDummy is a helper function to enroll a DummySource with a SourceManager.
func enrollDummy(mgr *SourceManager, chunkMethod chunker) (handle, error) {
@ -176,7 +177,7 @@ type unitChunk struct {
type unitChunker struct{ steps []unitChunk }
func (c *unitChunker) Chunks(ctx context.Context, ch chan *Chunk) error {
func (c *unitChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
for _, step := range c.steps {
if step.err != "" {
continue
@ -294,7 +295,9 @@ type callbackChunker struct {
cb func(context.Context, chan *Chunk) error
}
func (c callbackChunker) Chunks(ctx context.Context, ch chan *Chunk) error { return c.cb(ctx, ch) }
func (c callbackChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingTarget) error {
return c.cb(ctx, ch)
}
func (c callbackChunker) Enumerate(context.Context, UnitReporter) error { return nil }
func (c callbackChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return nil }

View file

@ -30,6 +30,16 @@ type Chunk struct {
Verify bool
}
// ChunkingTarget specifies criteria for a targeted chunking process.
// Instead of collecting data indiscriminately, this struct allows the caller
// to specify particular subsets of data they're interested in. This becomes
// especially useful when one needs to verify or recheck specific data points
// without processing the entire dataset.
type ChunkingTarget struct {
// QueryCriteria represents specific parameters or conditions to target the chunking process.
QueryCriteria source_metadatapb.MetaData
}
// Source defines the interface required to implement a source chunker.
type Source interface {
// Type returns the source type, used for matching against configuration and jobs.
@ -40,8 +50,12 @@ type Source interface {
JobID() int64
// Init initializes the source.
Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error
// Chunks emits data over a channel that is decoded and scanned for secrets.
Chunks(ctx context.Context, chunksChan chan *Chunk) error
// Chunks emits data over a channel which is then decoded and scanned for secrets.
// By default, data is obtained indiscriminately. However, by providing one or more
// ChunkingTarget parameters, the caller can direct the function to retrieve
// specific chunks of data. This targeted approach allows for efficient and
// intentional data processing, beneficial when verifying or rechecking specific data points.
Chunks(ctx context.Context, chunksChan chan *Chunk, targets ...ChunkingTarget) error
// GetProgress is the completion progress (percentage) for Scanned Source.
GetProgress() *Progress
}

View file

@ -182,7 +182,7 @@ func (s *Source) verifyConnectionConfig() error {
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
switch {
case s.conn.TlsCert != nilString || s.conn.TlsKey != nilString:
cert, err := tls.X509KeyPair([]byte(s.conn.TlsCert), []byte(s.conn.TlsKey))