mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
[feat] - concurently scan the filesystem source (#2364)
* concurently scan the filesystem source Co-authored-by: Miccah Castorina <m.castorina93@gmail.com> * fix test * update test * remove return * use error not info * address comment --------- Co-authored-by: Miccah Castorina <m.castorina93@gmail.com>
This commit is contained in:
parent
27b30e65ed
commit
a22874f9f0
2 changed files with 36 additions and 15 deletions
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/go-errors/errors"
|
||||
"github.com/go-logr/logr"
|
||||
diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
|
@ -26,13 +27,14 @@ import (
|
|||
const SourceType = sourcespb.SourceType_SOURCE_TYPE_FILESYSTEM
|
||||
|
||||
type Source struct {
|
||||
name string
|
||||
sourceId sources.SourceID
|
||||
jobId sources.JobID
|
||||
verify bool
|
||||
paths []string
|
||||
log logr.Logger
|
||||
filter *common.Filter
|
||||
name string
|
||||
sourceId sources.SourceID
|
||||
jobId sources.JobID
|
||||
concurrency int
|
||||
verify bool
|
||||
paths []string
|
||||
log logr.Logger
|
||||
filter *common.Filter
|
||||
sources.Progress
|
||||
sources.CommonSourceUnitUnmarshaller
|
||||
}
|
||||
|
@ -57,9 +59,10 @@ func (s *Source) JobID() sources.JobID {
|
|||
}
|
||||
|
||||
// Init returns an initialized Filesystem source.
|
||||
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, _ int) error {
|
||||
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
|
||||
s.log = aCtx.Logger()
|
||||
|
||||
s.concurrency = concurrency
|
||||
s.name = name
|
||||
s.sourceId = sourceId
|
||||
s.jobId = jobId
|
||||
|
@ -102,16 +105,22 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
|
|||
err = s.scanFile(ctx, cleanPath, chunksChan)
|
||||
}
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Info("error scanning filesystem", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sources.Chunk) error {
|
||||
workerPool := new(errgroup.Group)
|
||||
workerPool.SetLimit(s.concurrency)
|
||||
defer func() { _ = workerPool.Wait() }()
|
||||
|
||||
return fs.WalkDir(os.DirFS(path), ".", func(relativePath string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
ctx.Logger().Error(err, "error walking directory")
|
||||
return nil
|
||||
}
|
||||
fullPath := filepath.Join(path, relativePath)
|
||||
|
@ -126,9 +135,13 @@ func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sour
|
|||
return nil
|
||||
}
|
||||
|
||||
if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
|
||||
ctx.Logger().Info("error scanning file", "path", fullPath, "error", err)
|
||||
}
|
||||
workerPool.Go(func() error {
|
||||
if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
|
||||
ctx.Logger().Error(err, "error scanning file", "path", fullPath, "error", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
|
@ -53,6 +53,7 @@ func TestSource_Scan(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := Source{}
|
||||
|
@ -71,16 +72,23 @@ func TestSource_Scan(t *testing.T) {
|
|||
// TODO: this is kind of bad, if it errors right away we don't see it as a test failure.
|
||||
// Debugging this usually requires setting a breakpoint on L78 and running test w/ debug.
|
||||
go func() {
|
||||
defer close(chunksCh)
|
||||
err = s.Chunks(ctx, chunksCh)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
}()
|
||||
gotChunk := <-chunksCh
|
||||
if diff := pretty.Compare(gotChunk.SourceMetadata, tt.wantSourceMetadata); diff != "" {
|
||||
t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
|
||||
var counter int
|
||||
for chunk := range chunksCh {
|
||||
if chunk.SourceMetadata.GetFilesystem().GetFile() == "filesystem.go" {
|
||||
counter++
|
||||
if diff := pretty.Compare(chunk.SourceMetadata, tt.wantSourceMetadata); diff != "" {
|
||||
t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, counter)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue