[feat] - concurently scan the filesystem source (#2364)

* concurently scan the filesystem source

Co-authored-by: Miccah Castorina <m.castorina93@gmail.com>

* fix test

* update test

* remove return

* use error not info

* address comment

---------

Co-authored-by: Miccah Castorina <m.castorina93@gmail.com>
This commit is contained in:
ahrav 2024-02-03 10:49:14 -08:00 committed by GitHub
parent 27b30e65ed
commit a22874f9f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 15 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/go-errors/errors"
"github.com/go-logr/logr"
diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
@ -26,13 +27,14 @@ import (
const SourceType = sourcespb.SourceType_SOURCE_TYPE_FILESYSTEM
type Source struct {
name string
sourceId sources.SourceID
jobId sources.JobID
verify bool
paths []string
log logr.Logger
filter *common.Filter
name string
sourceId sources.SourceID
jobId sources.JobID
concurrency int
verify bool
paths []string
log logr.Logger
filter *common.Filter
sources.Progress
sources.CommonSourceUnitUnmarshaller
}
@ -57,9 +59,10 @@ func (s *Source) JobID() sources.JobID {
}
// Init returns an initialized Filesystem source.
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, _ int) error {
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
s.log = aCtx.Logger()
s.concurrency = concurrency
s.name = name
s.sourceId = sourceId
s.jobId = jobId
@ -102,16 +105,22 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
err = s.scanFile(ctx, cleanPath, chunksChan)
}
if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
logger.Info("error scanning filesystem", "error", err)
}
}
return nil
}
func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sources.Chunk) error {
workerPool := new(errgroup.Group)
workerPool.SetLimit(s.concurrency)
defer func() { _ = workerPool.Wait() }()
return fs.WalkDir(os.DirFS(path), ".", func(relativePath string, d fs.DirEntry, err error) error {
if err != nil {
ctx.Logger().Error(err, "error walking directory")
return nil
}
fullPath := filepath.Join(path, relativePath)
@ -126,9 +135,13 @@ func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sour
return nil
}
if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
ctx.Logger().Info("error scanning file", "path", fullPath, "error", err)
}
workerPool.Go(func() error {
if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
ctx.Logger().Error(err, "error scanning file", "path", fullPath, "error", err)
}
return nil
})
return nil
})
}

View file

@ -53,6 +53,7 @@ func TestSource_Scan(t *testing.T) {
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := Source{}
@ -71,16 +72,23 @@ func TestSource_Scan(t *testing.T) {
// TODO: this is kind of bad, if it errors right away we don't see it as a test failure.
// Debugging this usually requires setting a breakpoint on L78 and running test w/ debug.
go func() {
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
if (err != nil) != tt.wantErr {
t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
return
}
}()
gotChunk := <-chunksCh
if diff := pretty.Compare(gotChunk.SourceMetadata, tt.wantSourceMetadata); diff != "" {
t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
var counter int
for chunk := range chunksCh {
if chunk.SourceMetadata.GetFilesystem().GetFile() == "filesystem.go" {
counter++
if diff := pretty.Compare(chunk.SourceMetadata, tt.wantSourceMetadata); diff != "" {
t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
}
}
}
assert.Equal(t, 1, counter)
})
}
}