[THOG-681] - Handle errors sources (#783)

* Handle errors w/ github source.

* Fix loop var captured by func literal.

* Fix loop var captured by func literal.

* Set completed progress if the scan completes with no errors.

* Set progress to 100% if the scope and iteration are both 0.

* Fix commentary.

* Fix test.

* Return after the defer to os.RemoveAll.

* Fix unauth scan.

* Inline range loop.

* update tests for partial scan completion with errors. Ensure correct progress is set.

* Update progress for all sources.

* Update github test.

* Address comments.
This commit is contained in:
ahrav 2022-09-07 19:40:37 -07:00 committed by GitHub
parent c12be4d98d
commit 7ba583ca40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 199 additions and 118 deletions

View file

@ -195,6 +195,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
}
}
s.SetProgressComplete(len(s.conn.Repositories), len(s.conn.Repositories), fmt.Sprintf("Completed scanning source %s", s.name), "")
return nil
}

View file

@ -19,7 +19,7 @@ import (
"github.com/google/go-github/v42/github"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"golang.org/x/sync/semaphore"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
@ -34,6 +34,12 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
)
// Tuning constants for GitHub enumeration.
const (
// unauthGithubOrgRateLimt is the number of orgs above which an unauthenticated
// enumeration logs a warning that rate limiting may occur.
// NOTE(review): the identifier is misspelled ("Limt"); renaming it requires
// updating every reference at the same time.
unauthGithubOrgRateLimt = 30
// defaultPagination is the PerPage value shared by several GitHub list calls
// (org repos, app repos, organizations).
defaultPagination = 100
// membersAppPagination is the PerPage value used by addMembersByApp.
membersAppPagination = 500
)
type Source struct {
name string
sourceID int64
@ -48,7 +54,7 @@ type Source struct {
log *log.Entry
token string
conn *sourcespb.GitHub
jobSem *semaphore.Weighted
jobPool *errgroup.Group
resumeInfoSlice []string
resumeInfoMutex sync.Mutex
sources.Progress
@ -103,7 +109,8 @@ func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64,
s.sourceID = sourceID
s.jobID = jobID
s.verify = verify
s.jobSem = semaphore.NewWeighted(int64(concurrency))
s.jobPool = &errgroup.Group{}
s.jobPool.SetLimit(concurrency)
s.httpClient = common.SaneHttpClient()
@ -144,13 +151,13 @@ func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64,
func (s *Source) enumerateUnauthenticated(ctx context.Context) *github.Client {
apiClient := github.NewClient(s.httpClient)
if len(s.orgs) > 30 {
if len(s.orgs) > unauthGithubOrgRateLimt {
log.Warn("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
}
for _, org := range s.orgs {
errOrg := s.addReposByOrg(ctx, apiClient, org)
errUser := s.addReposByUser(ctx, apiClient, org)
errOrg := s.addRepos(ctx, apiClient, org, s.getReposByOrg)
errUser := s.addRepos(ctx, apiClient, org, s.getReposByUser)
if errOrg != nil && errUser != nil {
log.WithError(errOrg).Error("error fetching repos for org or user: ", org)
}
@ -159,18 +166,18 @@ func (s *Source) enumerateUnauthenticated(ctx context.Context) *github.Client {
}
func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token string) (*github.Client, error) {
// needed for clones
// Needed for clones.
s.token = token
// needed to list repos
// Needed to list repos.
ts := oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: token},
)
tc := oauth2.NewClient(context.TODO(), ts)
var err error
// If we're using public github, make a regular client.
// Otherwise make an enterprise client
// If we're using public Github, make a regular client.
// Otherwise, make an enterprise client.
var isGHE bool
var apiClient *github.Client
if apiEndpoint == "https://api.github.com" {
@ -191,25 +198,25 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
specificScope = true
}
user, _, err := apiClient.Users.Get(context.TODO(), "")
if err != nil {
return nil, errors.New(err)
}
if len(s.orgs) > 0 {
specificScope = true
for _, org := range s.orgs {
errOrg := s.addReposByOrg(ctx, apiClient, org)
errUser := s.addReposByUser(ctx, apiClient, org)
errOrg := s.addRepos(ctx, apiClient, org, s.getReposByOrg)
errUser := s.addRepos(ctx, apiClient, user.GetLogin(), s.getReposByUser)
if errOrg != nil && errUser != nil {
log.WithError(errOrg).Error("error fetching repos for org or user: ", org)
}
}
}
user, _, err := apiClient.Users.Get(context.TODO(), "")
if err != nil {
return nil, errors.New(err)
}
// If no scope was provided, enumerate them
// If no scope was provided, enumerate them.
if !specificScope {
if err := s.addReposByUser(ctx, apiClient, user.GetLogin()); err != nil {
if err := s.addRepos(ctx, apiClient, user.GetLogin(), s.getReposByUser); err != nil {
log.WithError(err).Error("error fetching repos by user")
}
@ -222,7 +229,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
}
for _, org := range s.orgs {
if err := s.addReposByOrg(ctx, apiClient, org); err != nil {
if err := s.addRepos(ctx, apiClient, org, s.getReposByOrg); err != nil {
log.WithError(err).Error("error fetching repos by org")
}
}
@ -251,7 +258,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
return nil, nil, errors.New(err)
}
// This client is used for most APIs
// This client is used for most APIs.
itr, err := ghinstallation.New(
s.httpClient.Transport,
appID,
@ -266,8 +273,8 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
return nil, nil, errors.New(err)
}
// This client is required to create installation tokens for cloning.. Otherwise the required JWT is not in the
// request for the token :/
// This client is required to create installation tokens for cloning.
// Otherwise, the required JWT is not in the request for the token :/
appItr, err := ghinstallation.NewAppsTransport(
s.httpClient.Transport,
appID,
@ -281,14 +288,13 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
return nil, nil, errors.New(err)
}
// If no repos were provided, enumerate them
// If no repos were provided, enumerate them.
if len(s.repos) == 0 {
err = s.addReposByApp(ctx, apiClient)
if err != nil {
if err = s.addReposByApp(ctx, apiClient); err != nil {
return nil, nil, err
}
// check if we need to find user repos
// Check if we need to find user repos.
if s.conn.ScanUsers {
err := s.addMembersByApp(ctx, installationClient, apiClient)
if err != nil {
@ -296,8 +302,10 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
}
log.Infof("Scanning repos from %v organization members.", len(s.members))
for _, member := range s.members {
s.addGistsByUser(ctx, apiClient, member)
if err := s.addReposByUser(ctx, apiClient, member); err != nil {
if err = s.addGistsByUser(ctx, apiClient, member); err != nil {
return nil, nil, err
}
if err := s.addRepos(ctx, apiClient, member, s.getReposByUser); err != nil {
log.WithError(err).Error("error fetching repos by user")
}
}
@ -315,17 +323,16 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
}
var apiClient, installationClient *github.Client
var err error
switch cred := s.conn.GetCredential().(type) {
case *sourcespb.GitHub_Unauthenticated:
apiClient = s.enumerateUnauthenticated(ctx)
case *sourcespb.GitHub_Token:
var err error
if apiClient, err = s.enumerateWithToken(ctx, apiEndpoint, cred.Token); err != nil {
return err
}
case *sourcespb.GitHub_GithubApp:
var err error
if apiClient, installationClient, err = s.enumerateWithApp(ctx, apiEndpoint, cred.GithubApp); err != nil {
return err
}
@ -339,51 +346,41 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
// We must sort the repos so we can resume later if necessary.
sort.Strings(s.repos)
return s.scan(ctx, installationClient, chunksChan)
for _, err := range s.scan(ctx, installationClient, chunksChan) {
log.WithError(err).Error("error scanning repository")
}
return nil
}
func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error {
func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) []error {
var scanned uint64
log.Debugf("Found %v total repos to scan", len(s.repos))
wg := sync.WaitGroup{}
errs := make(chan error, 1)
reportErr := func(err error) {
// save the error if there's room, otherwise log and drop it
select {
case errs <- err:
default:
log.WithError(err).Warn("dropping error")
}
}
// If there is resume information available, limit this scan to only the repos that still need scanning.
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
s.repos = reposToScan
var scanErrs []error
for i, repoURL := range s.repos {
if err := s.jobSem.Acquire(ctx, 1); err != nil {
// Acquire blocks until it can acquire the semaphore or returns an
// error if the context is finished
log.WithError(err).Debug("could not acquire semaphore")
reportErr(err)
break
}
wg.Add(1)
go func(ctx context.Context, repoURL string, i int) {
defer s.jobSem.Release(1)
defer wg.Done()
repoURL := repoURL
s.jobPool.Go(func() error {
if common.IsDone(ctx) {
return nil
}
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL)
// Ensure the repo is removed from the resume info after being scanned.
defer func(s *Source) {
defer func(s *Source, repoURL string) {
s.resumeInfoMutex.Lock()
defer s.resumeInfoMutex.Unlock()
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
}(s)
}(s, repoURL)
if !strings.HasSuffix(repoURL, ".git") {
return
scanErrs = append(scanErrs, fmt.Errorf("repo %s does not end in .git", repoURL))
return nil
}
s.log.WithField("repo", repoURL).Debugf("attempting to clone repo %d/%d", i+1, len(s.repos))
@ -394,20 +391,24 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
switch s.conn.GetCredential().(type) {
case *sourcespb.GitHub_Unauthenticated:
path, repo, err = git.CloneRepoUsingUnauthenticated(repoURL)
if err != nil {
scanErrs = append(scanErrs, fmt.Errorf("error cloning repo %s: %w", repoURL, err))
}
default:
var token string
token, err = s.Token(ctx, installationClient)
if err != nil {
reportErr(err)
return
scanErrs = append(scanErrs, fmt.Errorf("error getting token for repo %s: %w", repoURL, err))
}
path, repo, err = git.CloneRepoUsingToken(token, repoURL, "")
if err != nil {
scanErrs = append(scanErrs, fmt.Errorf("error cloning repo %s: %w", repoURL, err))
}
path, repo, err = git.CloneRepoUsingToken(token, repoURL, "clone")
}
defer os.RemoveAll(path)
if err != nil {
log.WithError(err).Errorf("unable to clone repo (%s), continuing", repoURL)
return
return nil
}
// Base and head will only exist from incoming webhooks.
scanOptions := git.NewScanOptions(
@ -415,24 +416,23 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
git.ScanOptionHeadCommit(s.conn.Head),
)
err = s.git.ScanRepo(ctx, repo, path, scanOptions, chunksChan)
if err != nil {
if err = s.git.ScanRepo(ctx, repo, path, scanOptions, chunksChan); err != nil {
log.WithError(err).Errorf("unable to scan repo, continuing")
return nil
}
atomic.AddUint64(&scanned, 1)
log.Debugf("scanned %d/%d repos", scanned, len(s.repos))
}(ctx, repoURL, i)
return nil
})
}
wg.Wait()
// This only returns first error which is what we did prior to concurrency
select {
case err := <-errs:
return err
default:
return nil
_ = s.jobPool.Wait()
if len(scanErrs) == 0 {
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed Github scan", "")
}
return scanErrs
}
// handleRateLimit returns true if a rate limit was handled
@ -472,12 +472,12 @@ func handleRateLimit(errIn error, res *github.Response) bool {
}
func (s *Source) getReposByOrg(ctx context.Context, apiClient *github.Client, org string) ([]string, error) {
log := s.log.WithField("org", org)
logger := s.log.WithField("org", org)
repos := []string{}
var repos []string
opts := &github.RepositoryListByOrgOptions{
ListOptions: github.ListOptions{
PerPage: 100,
PerPage: defaultPagination,
},
}
var numRepos, numForks int
@ -492,7 +492,7 @@ func (s *Source) getReposByOrg(ctx context.Context, apiClient *github.Client, or
if err != nil {
return nil, fmt.Errorf("could not list repos for org %s: %w", org, err)
}
log.Debugf("listed repos page %d/%d", opts.Page, res.LastPage)
logger.Debugf("listed repos page %d/%d", opts.Page, res.LastPage)
if len(someRepos) == 0 {
break
}
@ -511,16 +511,16 @@ func (s *Source) getReposByOrg(ctx context.Context, apiClient *github.Client, or
}
opts.Page = res.NextPage
}
log.Debugf("found %d repos (%d forks)", numRepos, numForks)
logger.Debugf("found %d repos (%d forks)", numRepos, numForks)
return repos, nil
}
func (s *Source) addReposByOrg(ctx context.Context, apiClient *github.Client, org string) error {
repos, err := s.getReposByOrg(ctx, apiClient, org)
func (s *Source) addRepos(ctx context.Context, client *github.Client, entity string, getRepos func(context.Context, *github.Client, string) ([]string, error)) error {
repos, err := getRepos(ctx, client, entity)
if err != nil {
return err
}
// add the repos to the set of repos
// Add the repos to the set of repos.
for _, repo := range repos {
common.AddStringSliceItem(repo, &s.repos)
}
@ -528,7 +528,7 @@ func (s *Source) addReposByOrg(ctx context.Context, apiClient *github.Client, or
}
func (s *Source) getReposByUser(ctx context.Context, apiClient *github.Client, user string) ([]string, error) {
repos := []string{}
var repos []string
opts := &github.RepositoryListOptions{
ListOptions: github.ListOptions{
PerPage: 50,
@ -559,20 +559,8 @@ func (s *Source) getReposByUser(ctx context.Context, apiClient *github.Client, u
return repos, nil
}
// addReposByUser lists the repos for user via getReposByUser and adds each of
// them to s.repos. It returns the listing error unchanged when the lookup
// fails, and nil otherwise.
func (s *Source) addReposByUser(ctx context.Context, apiClient *github.Client, user string) error {
repos, err := s.getReposByUser(ctx, apiClient, user)
if err != nil {
return err
}
// Add the repos to the set of repos.
// NOTE(review): presumably common.AddStringSliceItem skips duplicates — confirm.
for _, repo := range repos {
common.AddStringSliceItem(repo, &s.repos)
}
return nil
}
func (s *Source) getGistsByUser(ctx context.Context, apiClient *github.Client, user string) ([]string, error) {
gistURLs := []string{}
var gistURLs []string
gistOpts := &github.GistListOptions{}
for {
gists, resp, err := apiClient.Gists.List(ctx, user, gistOpts)
@ -611,7 +599,7 @@ func (s *Source) addGistsByUser(ctx context.Context, apiClient *github.Client, u
func (s *Source) addMembersByApp(ctx context.Context, installationClient *github.Client, apiClient *github.Client) error {
opts := &github.ListOptions{
PerPage: 500,
PerPage: membersAppPagination,
}
optsOrg := &github.ListMembersOptions{
PublicOnly: false,
@ -657,7 +645,7 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github
func (s *Source) addReposByApp(ctx context.Context, apiClient *github.Client) error {
// Authenticated enumeration of repos
opts := &github.ListOptions{
PerPage: 100,
PerPage: defaultPagination,
}
for {
someRepos, res, err := apiClient.Apps.ListRepos(ctx, opts)
@ -685,14 +673,14 @@ func (s *Source) addReposByApp(ctx context.Context, apiClient *github.Client) er
}
func (s *Source) addAllVisibleOrgs(ctx context.Context, apiClient *github.Client) {
s.log.Debug("enumerating all visibile organizations on GHE")
// Enumeration on this endpoint does not use pages. it uses a since ID.
s.log.Debug("enumerating all visible organizations on GHE")
// Enumeration on this endpoint does not use pages; it uses a since ID.
// The endpoint will return organizations with an ID greater than the given since ID.
// Empty org response is our cue to break the enumeration loop.
orgOpts := &github.OrganizationsListOptions{
Since: 0,
ListOptions: github.ListOptions{
PerPage: 100,
PerPage: defaultPagination,
},
}
for {
@ -731,7 +719,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context, apiClient *github.Client
func (s *Source) addOrgsByUser(ctx context.Context, apiClient *github.Client, user string) {
orgOpts := &github.ListOptions{
PerPage: 100,
PerPage: defaultPagination,
}
for {
orgs, resp, err := apiClient.Organizations.List(ctx, "", orgOpts)
@ -767,7 +755,7 @@ func (s *Source) normalizeRepos(ctx context.Context, apiClient *github.Client) {
// TODO: Add check/fix for repos that are missing scheme
normalizedRepos := map[string]struct{}{}
for _, repo := range s.repos {
// if there's a '/', assume it's a URL and try to normalize it
// If there's a '/', assume it's a URL and try to normalize it.
if strings.ContainsRune(repo, '/') {
repoNormalized, err := giturl.NormalizeGithubRepo(repo)
if err != nil {
@ -777,7 +765,7 @@ func (s *Source) normalizeRepos(ctx context.Context, apiClient *github.Client) {
normalizedRepos[repoNormalized] = struct{}{}
continue
}
// otherwise, assume it's a user and enumerate repositories and gists
// Otherwise, assume it's a user and enumerate repositories and gists.
if repos, err := s.getReposByUser(ctx, apiClient, repo); err == nil {
for _, repo := range repos {
normalizedRepos[repo] = struct{}{}
@ -790,7 +778,7 @@ func (s *Source) normalizeRepos(ctx context.Context, apiClient *github.Client) {
}
}
// replace s.repos
// Replace s.repos.
s.repos = s.repos[:0]
for key := range normalizedRepos {
s.repos = append(s.repos, key)

View file

@ -17,6 +17,7 @@ import (
"github.com/google/go-github/v42/github"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/anypb"
"gopkg.in/h2non/gock.v1"
@ -67,7 +68,7 @@ func TestAddReposByOrg(t *testing.T) {
s := initTestSource(nil)
// gock works here because github.NewClient is using the default HTTP Transport
err := s.addReposByOrg(context.TODO(), github.NewClient(nil), "super-secret-org")
err := s.addRepos(context.TODO(), github.NewClient(nil), "super-secret-org", s.getReposByOrg)
assert.Nil(t, err)
assert.Equal(t, 1, len(s.repos))
assert.Equal(t, []string{"super-secret-repo"}, s.repos)
@ -83,7 +84,7 @@ func TestAddReposByUser(t *testing.T) {
JSON([]map[string]string{{"clone_url": "super-secret-repo"}})
s := initTestSource(nil)
err := s.addReposByUser(context.TODO(), github.NewClient(nil), "super-secret-user")
err := s.addRepos(context.TODO(), github.NewClient(nil), "super-secret-user", s.getReposByUser)
assert.Nil(t, err)
assert.Equal(t, 1, len(s.repos))
assert.Equal(t, []string{"super-secret-repo"}, s.repos)
@ -438,3 +439,41 @@ func Test_setProgressCompleteWithRepo_Progress(t *testing.T) {
}
}
}
// Test_scan_SetProgressComplete verifies that scan only marks progress as 100%
// complete when it finishes without errors, and that the errors it returns
// match the expectation of each case.
func Test_scan_SetProgressComplete(t *testing.T) {
	testCases := []struct {
		name         string
		repos        []string
		wantComplete bool
		wantErr      bool
	}{
		{
			name:         "no repos",
			wantComplete: true,
		},
		{
			// "a" has no .git suffix, so scan records an error for it and
			// must not report the scan as complete.
			name:    "one repo without .git suffix",
			repos:   []string{"a"},
			wantErr: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			src := &Source{
				repos: tc.repos,
			}
			src.jobPool = &errgroup.Group{}

			// Check the returned errors instead of discarding them, so that
			// wantErr is actually verified.
			errs := src.scan(context.Background(), nil, nil)
			if gotErr := len(errs) > 0; gotErr != tc.wantErr {
				t.Errorf("scan() errors = %v, wantErr %v", errs, tc.wantErr)
			}
			if !tc.wantErr {
				assert.Equal(t, "", src.GetProgress().EncodedResumeInfo)
			}

			gotComplete := src.GetProgress().PercentComplete == 100
			if gotComplete != tc.wantComplete {
				t.Errorf("got: %v, want: %v", gotComplete, tc.wantComplete)
			}
		})
	}
}

View file

@ -332,6 +332,9 @@ func (s *Source) scanRepos(ctx context.Context, chunksChan chan *sources.Chunk)
}(ctx, repo, i)
}
wg.Wait()
if len(errs) == 0 {
s.SetProgressComplete(len(s.repos), len(s.repos), fmt.Sprintf("Completed scanning source %s", s.name), "")
}
return errs
}

View file

@ -8,6 +8,8 @@ import (
"github.com/kylelemons/godebug/pretty"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/semaphore"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
@ -247,3 +249,41 @@ func Test_setProgressCompleteWithRepo_Progress(t *testing.T) {
}
}
}
// Test_scanRepos_SetProgressComplete runs scanRepos against a minimal Source
// and checks whether progress ends up at 100% for each case. The resume info
// is only expected to be empty for cases that are not expected to error.
func Test_scanRepos_SetProgressComplete(t *testing.T) {
	type scenario struct {
		name         string
		repos        []string
		wantComplete bool
		wantErr      bool
	}

	scenarios := []scenario{
		{name: "no repos", wantComplete: true},
		{name: "one valid repo", repos: []string{"a"}, wantErr: true},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			// A single-slot semaphore is enough: at most one repo is scanned.
			source := &Source{
				repos:  sc.repos,
				jobSem: semaphore.NewWeighted(1),
			}

			_ = source.scanRepos(context.Background(), nil)

			if !sc.wantErr {
				assert.Equal(t, "", source.GetProgress().EncodedResumeInfo)
			}

			if got, want := source.GetProgress().PercentComplete == 100, sc.wantComplete; got != want {
				t.Errorf("got: %v, want: %v", got, want)
			}
		})
	}
}

View file

@ -157,8 +157,8 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
} else {
regionalClient = client
}
//Forced prefix for testing
//pf := "public"
// Forced prefix for testing
// pf := "public"
errorCount := sync.Map{}
err = regionalClient.ListObjectsV2PagesWithContext(
@ -172,8 +172,8 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
s.log.WithError(err).Errorf("could not list objects in s3 bucket: %s", bucket)
return errors.WrapPrefix(err, fmt.Sprintf("could not list objects in s3 bucket: %s", bucket), 0)
}
}
s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s", s.name), "")
return nil
}
@ -197,12 +197,12 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
defer common.Recover(ctx)
defer sem.Release(1)
defer wg.Done()
//defer log.Debugf("DONE - %s", *obj.Key)
// defer log.Debugf("DONE - %s", *obj.Key)
if (*obj.Key)[len(*obj.Key)-1:] == "/" {
return
}
//log.Debugf("Object: %s", *obj.Key)
// log.Debugf("Object: %s", *obj.Key)
path := strings.Split(*obj.Key, "/")
prefix := strings.Join(path[:len(path)-1], "/")
@ -221,13 +221,13 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
return
}
//file is 0 bytes - likely no permissions - skipping
// file is 0 bytes - likely no permissions - skipping
if *obj.Size == 0 {
return
}
//files break with spaces, must replace with +
//objKey := strings.ReplaceAll(*obj.Key, " ", "+")
// files break with spaces, must replace with +
// objKey := strings.ReplaceAll(*obj.Key, " ", "+")
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
res, err := client.GetObjectWithContext(ctx, &s3.GetObjectInput{
@ -249,7 +249,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
}
nErr = nErr.(int) + 1
errorCount.Store(prefix, nErr)
//too many consective errors on this page
// too many consecutive errors on this page
if nErr.(int) > 3 {
s.log.Warnf("Too many consecutive errors. Excluding %s", prefix)
}

View file

@ -10,6 +10,7 @@ import (
"github.com/kylelemons/godebug/pretty"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
@ -93,6 +94,8 @@ func TestSource_Chunks(t *testing.T) {
if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
t.Errorf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
}
assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
})
}
}

View file

@ -40,7 +40,7 @@ type Source interface {
Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error
// Chunks emits data over a channel that is decoded and scanned for secrets.
Chunks(ctx context.Context, chunksChan chan *Chunk) error
// Completion Percentage for Scanned Source
// GetProgress is the completion progress (percentage) for Scanned Source.
GetProgress() *Progress
}
@ -104,7 +104,7 @@ func NewConfig(opts ...func(*Config)) Config {
return *c
}
// PercentComplete is used to update job completion percentages across sources
// Progress is used to update job completion progress across sources.
type Progress struct {
mut sync.Mutex
PercentComplete int64
@ -127,10 +127,17 @@ func (p *Progress) SetProgressComplete(i, scope int, message, encodedResumeInfo
p.EncodedResumeInfo = encodedResumeInfo
p.SectionsCompleted = int32(i)
p.SectionsRemaining = int32(scope)
// If the iteration and scope are both 0, completion is 100%.
if i == 0 && scope == 0 {
p.PercentComplete = 100
return
}
p.PercentComplete = int64((float64(i) / float64(scope)) * 100)
}
// GetProgressComplete gets job completion percentage for metrics reporting
// GetProgress gets job completion percentage for metrics reporting.
func (p *Progress) GetProgress() *Progress {
p.mut.Lock()
defer p.mut.Unlock()