mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 15:14:38 +00:00
Add concurrency to CircleCi source (#1029)
* Small cleanup of CircleCi source. * Add concurrency to circleci. * merge w/ cleanup branch. * Redefine loop var. * Delete github.go * revert file delete. * Add debug log for scan errors. * make collecting scanned errors thread safe. * pre-allocate errors slice.
This commit is contained in:
parent
319ae64a02
commit
1621403e11
1 changed files with 59 additions and 16 deletions
|
@ -6,8 +6,12 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-errors/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
|
@ -26,6 +30,7 @@ type Source struct {
|
|||
sourceId int64
|
||||
jobId int64
|
||||
verify bool
|
||||
jobPool *errgroup.Group
|
||||
sources.Progress
|
||||
client *http.Client
|
||||
}
|
||||
|
@ -48,11 +53,13 @@ func (s *Source) JobID() int64 {
|
|||
}
|
||||
|
||||
// Init initializes the CircleCI source in place and returns an error if initialization fails.
|
||||
func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, _ int) error {
|
||||
func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error {
|
||||
s.name = name
|
||||
s.sourceId = sourceId
|
||||
s.jobId = jobId
|
||||
s.verify = verify
|
||||
s.jobPool = &errgroup.Group{}
|
||||
s.jobPool.SetLimit(concurrency)
|
||||
s.client = common.RetryableHttpClientTimeout(3)
|
||||
|
||||
var conn sourcespb.CircleCI
|
||||
|
@ -68,34 +75,70 @@ func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, ver
|
|||
return nil
|
||||
}
|
||||
|
||||
// scanErrors collects the errors encountered while scanning so that a
// failure in one unit of work does not abort the others. It is safe for
// concurrent use: count is only modified atomically and errors is only
// touched while mu is held.
type scanErrors struct {
	// count is the number of errors recorded so far. It is updated with
	// sync/atomic, so it must remain the first field: only the first word of
	// an allocated struct is guaranteed to be 64-bit aligned on 32-bit
	// platforms.
	count uint64
	// mu guards errors.
	mu sync.Mutex
	// errors holds every recorded error in the order it was appended.
	errors []error
}

// newScanErrors returns an empty collector whose backing slice is
// pre-allocated with capacity for projects entries.
func newScanErrors(projects int) *scanErrors {
	return &scanErrors{errors: make([]error, 0, projects)}
}

// add records err and increments the error count. It may be called from
// multiple goroutines at once.
func (s *scanErrors) add(err error) {
	atomic.AddUint64(&s.count, 1)

	s.mu.Lock()
	defer s.mu.Unlock()
	s.errors = append(s.errors, err)
}

// String renders the collected errors as a bracketed, space-separated list of
// their messages, e.g. "[first second]".
//
// It makes *scanErrors satisfy fmt.Stringer. Without it, formatting the
// collector with %v prints the raw struct (counter, mutex state and pointer
// values) because fmt cannot call Error on values reached through unexported
// fields.
func (s *scanErrors) String() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return fmt.Sprintf("%v", s.errors)
}
|
||||
|
||||
// Chunks emits chunks of bytes over a channel.
|
||||
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
|
||||
projects, err := s.projects(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error getting projects: %w", err)
|
||||
}
|
||||
|
||||
for _, proj := range projects {
|
||||
builds, err := s.buildsForProject(ctx, proj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var scanned uint64
|
||||
scanErrs := newScanErrors(len(projects))
|
||||
|
||||
for _, bld := range builds {
|
||||
buildSteps, err := s.stepsForBuild(ctx, proj, bld)
|
||||
for _, proj := range projects {
|
||||
proj := proj
|
||||
s.jobPool.Go(func() error {
|
||||
builds, err := s.buildsForProject(ctx, proj)
|
||||
if err != nil {
|
||||
return err
|
||||
scanErrs.add(fmt.Errorf("error getting builds for project %s: %w", proj.RepoName, err))
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, step := range buildSteps {
|
||||
for _, action := range step.Actions {
|
||||
err = s.chunkAction(ctx, proj, bld, action, step.Name, chunksChan)
|
||||
if err != nil {
|
||||
return err
|
||||
for _, bld := range builds {
|
||||
buildSteps, err := s.stepsForBuild(ctx, proj, bld)
|
||||
if err != nil {
|
||||
scanErrs.add(fmt.Errorf("error getting steps for build %d: %w", bld.BuildNum, err))
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, step := range buildSteps {
|
||||
for _, action := range step.Actions {
|
||||
if err = s.chunkAction(ctx, proj, bld, action, step.Name, chunksChan); err != nil {
|
||||
scanErrs.add(fmt.Errorf("error chunking action %v: %w", action, err))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddUint64(&scanned, 1)
|
||||
log.Debugf("scanned %d/%d projects", scanned, len(projects))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
_ = s.jobPool.Wait()
|
||||
if scanErrs.count > 0 {
|
||||
log.Debugf("encountered %d errors while scanning; errors: %v", scanErrs.count, scanErrs)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Reference in a new issue