Add concurrency to CircleCi source (#1029)

* Small cleanup of CircleCi source.

* Add concurrency to circleci.

* merge w/ cleanup branch.

* Redefine loop var.

* Delete github.go

* Revert file delete.

* Add debug log for scan errors.

* make collecting scanned errors thread safe.

* pre-allocate errors slice.
This commit is contained in:
ahrav 2023-01-17 12:24:49 -08:00 committed by GitHub
parent 319ae64a02
commit 1621403e11
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,8 +6,12 @@ import (
"fmt"
"io"
"net/http"
"sync"
"sync/atomic"
"github.com/go-errors/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
@ -26,6 +30,7 @@ type Source struct {
sourceId int64
jobId int64
verify bool
jobPool *errgroup.Group
sources.Progress
client *http.Client
}
@ -48,11 +53,13 @@ func (s *Source) JobID() int64 {
}
// Init returns an initialized CircleCI source.
func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, _ int) error {
func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error {
s.name = name
s.sourceId = sourceId
s.jobId = jobId
s.verify = verify
s.jobPool = &errgroup.Group{}
s.jobPool.SetLimit(concurrency)
s.client = common.RetryableHttpClientTimeout(3)
var conn sourcespb.CircleCI
@ -68,34 +75,70 @@ func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, ver
return nil
}
// scanErrors is used to collect errors encountered while scanning.
// It ensures that errors are collected in a thread-safe manner.
type scanErrors struct {
	// count mirrors len(errors); kept as a separate field so callers
	// can read it cheaply after all scanning goroutines have joined.
	count uint64
	// mu guards both count and errors.
	mu     sync.Mutex
	errors []error
}

// newScanErrors returns a scanErrors whose backing slice is pre-allocated
// with capacity for one error per project, avoiding growth under the lock.
func newScanErrors(projects int) *scanErrors {
	return &scanErrors{errors: make([]error, 0, projects)}
}

// add records err. It is safe for concurrent use by multiple goroutines.
// count is incremented under the same mutex that guards the slice, so the
// two stay consistent without a second (atomic) synchronization mechanism.
func (s *scanErrors) add(err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.count++
	s.errors = append(s.errors, err)
}
// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
projects, err := s.projects(ctx)
if err != nil {
return err
return fmt.Errorf("error getting projects: %w", err)
}
for _, proj := range projects {
builds, err := s.buildsForProject(ctx, proj)
if err != nil {
return err
}
var scanned uint64
scanErrs := newScanErrors(len(projects))
for _, bld := range builds {
buildSteps, err := s.stepsForBuild(ctx, proj, bld)
for _, proj := range projects {
proj := proj
s.jobPool.Go(func() error {
builds, err := s.buildsForProject(ctx, proj)
if err != nil {
return err
scanErrs.add(fmt.Errorf("error getting builds for project %s: %w", proj.RepoName, err))
return nil
}
for _, step := range buildSteps {
for _, action := range step.Actions {
err = s.chunkAction(ctx, proj, bld, action, step.Name, chunksChan)
if err != nil {
return err
for _, bld := range builds {
buildSteps, err := s.stepsForBuild(ctx, proj, bld)
if err != nil {
scanErrs.add(fmt.Errorf("error getting steps for build %d: %w", bld.BuildNum, err))
return nil
}
for _, step := range buildSteps {
for _, action := range step.Actions {
if err = s.chunkAction(ctx, proj, bld, action, step.Name, chunksChan); err != nil {
scanErrs.add(fmt.Errorf("error chunking action %v: %w", action, err))
return nil
}
}
}
}
}
atomic.AddUint64(&scanned, 1)
log.Debugf("scanned %d/%d projects", scanned, len(projects))
return nil
})
}
_ = s.jobPool.Wait()
if scanErrs.count > 0 {
log.Debugf("encountered %d errors while scanning; errors: %v", scanErrs.count, scanErrs)
}
return nil