mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 15:14:38 +00:00
[THOG-681] - Handle errors sources (#783)
* Handle errors w/ github source.
* Fix loop var captured by func literal.
* Fix loop var captured by func literal.
* Set completed progress if the scan completes with no errors.
* Set progress to 100% if the scope and iteration are both 0.
* Fix commentary.
* Fix test.
* Return after the defer to os.RemoveAll.
* Fix unauth scan.
* Inline range loop.
* Update tests for partial scan completion with errors. Ensure correct progress is set.
* Update progress for all sources.
* Update github test.
* Address comments.
This commit is contained in:
parent
c12be4d98d
commit
7ba583ca40
8 changed files with 199 additions and 118 deletions
|
@ -195,6 +195,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
|
|||
}
|
||||
|
||||
}
|
||||
s.SetProgressComplete(len(s.conn.Repositories), len(s.conn.Repositories), fmt.Sprintf("Completed scanning source %s", s.name), "")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
"github.com/google/go-github/v42/github"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
|
@ -34,6 +34,12 @@ import (
|
|||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
|
||||
)
|
||||
|
||||
const (
|
||||
unauthGithubOrgRateLimt = 30
|
||||
defaultPagination = 100
|
||||
membersAppPagination = 500
|
||||
)
|
||||
|
||||
type Source struct {
|
||||
name string
|
||||
sourceID int64
|
||||
|
@ -48,7 +54,7 @@ type Source struct {
|
|||
log *log.Entry
|
||||
token string
|
||||
conn *sourcespb.GitHub
|
||||
jobSem *semaphore.Weighted
|
||||
jobPool *errgroup.Group
|
||||
resumeInfoSlice []string
|
||||
resumeInfoMutex sync.Mutex
|
||||
sources.Progress
|
||||
|
@ -103,7 +109,8 @@ func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64,
|
|||
s.sourceID = sourceID
|
||||
s.jobID = jobID
|
||||
s.verify = verify
|
||||
s.jobSem = semaphore.NewWeighted(int64(concurrency))
|
||||
s.jobPool = &errgroup.Group{}
|
||||
s.jobPool.SetLimit(concurrency)
|
||||
|
||||
s.httpClient = common.SaneHttpClient()
|
||||
|
||||
|
@ -144,13 +151,13 @@ func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64,
|
|||
|
||||
func (s *Source) enumerateUnauthenticated(ctx context.Context) *github.Client {
|
||||
apiClient := github.NewClient(s.httpClient)
|
||||
if len(s.orgs) > 30 {
|
||||
if len(s.orgs) > unauthGithubOrgRateLimt {
|
||||
log.Warn("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
|
||||
}
|
||||
|
||||
for _, org := range s.orgs {
|
||||
errOrg := s.addReposByOrg(ctx, apiClient, org)
|
||||
errUser := s.addReposByUser(ctx, apiClient, org)
|
||||
errOrg := s.addRepos(ctx, apiClient, org, s.getReposByOrg)
|
||||
errUser := s.addRepos(ctx, apiClient, org, s.getReposByUser)
|
||||
if errOrg != nil && errUser != nil {
|
||||
log.WithError(errOrg).Error("error fetching repos for org or user: ", org)
|
||||
}
|
||||
|
@ -159,18 +166,18 @@ func (s *Source) enumerateUnauthenticated(ctx context.Context) *github.Client {
|
|||
}
|
||||
|
||||
func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token string) (*github.Client, error) {
|
||||
// needed for clones
|
||||
// Needed for clones.
|
||||
s.token = token
|
||||
|
||||
// needed to list repos
|
||||
// Needed to list repos.
|
||||
ts := oauth2.StaticTokenSource(
|
||||
&oauth2.Token{AccessToken: token},
|
||||
)
|
||||
tc := oauth2.NewClient(context.TODO(), ts)
|
||||
|
||||
var err error
|
||||
// If we're using public github, make a regular client.
|
||||
// Otherwise make an enterprise client
|
||||
// If we're using public Github, make a regular client.
|
||||
// Otherwise, make an enterprise client.
|
||||
var isGHE bool
|
||||
var apiClient *github.Client
|
||||
if apiEndpoint == "https://api.github.com" {
|
||||
|
@ -191,25 +198,25 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
|
|||
specificScope = true
|
||||
}
|
||||
|
||||
user, _, err := apiClient.Users.Get(context.TODO(), "")
|
||||
if err != nil {
|
||||
return nil, errors.New(err)
|
||||
}
|
||||
|
||||
if len(s.orgs) > 0 {
|
||||
specificScope = true
|
||||
for _, org := range s.orgs {
|
||||
errOrg := s.addReposByOrg(ctx, apiClient, org)
|
||||
errUser := s.addReposByUser(ctx, apiClient, org)
|
||||
errOrg := s.addRepos(ctx, apiClient, org, s.getReposByOrg)
|
||||
errUser := s.addRepos(ctx, apiClient, user.GetLogin(), s.getReposByUser)
|
||||
if errOrg != nil && errUser != nil {
|
||||
log.WithError(errOrg).Error("error fetching repos for org or user: ", org)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
user, _, err := apiClient.Users.Get(context.TODO(), "")
|
||||
if err != nil {
|
||||
return nil, errors.New(err)
|
||||
}
|
||||
|
||||
// If no scope was provided, enumerate them
|
||||
// If no scope was provided, enumerate them.
|
||||
if !specificScope {
|
||||
if err := s.addReposByUser(ctx, apiClient, user.GetLogin()); err != nil {
|
||||
if err := s.addRepos(ctx, apiClient, user.GetLogin(), s.getReposByUser); err != nil {
|
||||
log.WithError(err).Error("error fetching repos by user")
|
||||
}
|
||||
|
||||
|
@ -222,7 +229,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
|
|||
}
|
||||
|
||||
for _, org := range s.orgs {
|
||||
if err := s.addReposByOrg(ctx, apiClient, org); err != nil {
|
||||
if err := s.addRepos(ctx, apiClient, org, s.getReposByOrg); err != nil {
|
||||
log.WithError(err).Error("error fetching repos by org")
|
||||
}
|
||||
}
|
||||
|
@ -251,7 +258,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
|
|||
return nil, nil, errors.New(err)
|
||||
}
|
||||
|
||||
// This client is used for most APIs
|
||||
// This client is used for most APIs.
|
||||
itr, err := ghinstallation.New(
|
||||
s.httpClient.Transport,
|
||||
appID,
|
||||
|
@ -266,8 +273,8 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
|
|||
return nil, nil, errors.New(err)
|
||||
}
|
||||
|
||||
// This client is required to create installation tokens for cloning.. Otherwise the required JWT is not in the
|
||||
// request for the token :/
|
||||
// This client is required to create installation tokens for cloning.
|
||||
// Otherwise, the required JWT is not in the request for the token :/
|
||||
appItr, err := ghinstallation.NewAppsTransport(
|
||||
s.httpClient.Transport,
|
||||
appID,
|
||||
|
@ -281,14 +288,13 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
|
|||
return nil, nil, errors.New(err)
|
||||
}
|
||||
|
||||
// If no repos were provided, enumerate them
|
||||
// If no repos were provided, enumerate them.
|
||||
if len(s.repos) == 0 {
|
||||
err = s.addReposByApp(ctx, apiClient)
|
||||
if err != nil {
|
||||
if err = s.addReposByApp(ctx, apiClient); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// check if we need to find user repos
|
||||
// Check if we need to find user repos.
|
||||
if s.conn.ScanUsers {
|
||||
err := s.addMembersByApp(ctx, installationClient, apiClient)
|
||||
if err != nil {
|
||||
|
@ -296,8 +302,10 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
|
|||
}
|
||||
log.Infof("Scanning repos from %v organization members.", len(s.members))
|
||||
for _, member := range s.members {
|
||||
s.addGistsByUser(ctx, apiClient, member)
|
||||
if err := s.addReposByUser(ctx, apiClient, member); err != nil {
|
||||
if err = s.addGistsByUser(ctx, apiClient, member); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := s.addRepos(ctx, apiClient, member, s.getReposByUser); err != nil {
|
||||
log.WithError(err).Error("error fetching repos by user")
|
||||
}
|
||||
}
|
||||
|
@ -315,17 +323,16 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
|
|||
}
|
||||
|
||||
var apiClient, installationClient *github.Client
|
||||
var err error
|
||||
|
||||
switch cred := s.conn.GetCredential().(type) {
|
||||
case *sourcespb.GitHub_Unauthenticated:
|
||||
apiClient = s.enumerateUnauthenticated(ctx)
|
||||
case *sourcespb.GitHub_Token:
|
||||
var err error
|
||||
if apiClient, err = s.enumerateWithToken(ctx, apiEndpoint, cred.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
case *sourcespb.GitHub_GithubApp:
|
||||
var err error
|
||||
if apiClient, installationClient, err = s.enumerateWithApp(ctx, apiEndpoint, cred.GithubApp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -339,51 +346,41 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
|
|||
// We must sort the repos so we can resume later if necessary.
|
||||
sort.Strings(s.repos)
|
||||
|
||||
return s.scan(ctx, installationClient, chunksChan)
|
||||
for _, err := range s.scan(ctx, installationClient, chunksChan) {
|
||||
log.WithError(err).Error("error scanning repository")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error {
|
||||
func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) []error {
|
||||
var scanned uint64
|
||||
|
||||
log.Debugf("Found %v total repos to scan", len(s.repos))
|
||||
wg := sync.WaitGroup{}
|
||||
errs := make(chan error, 1)
|
||||
reportErr := func(err error) {
|
||||
// save the error if there's room, otherwise log and drop it
|
||||
select {
|
||||
case errs <- err:
|
||||
default:
|
||||
log.WithError(err).Warn("dropping error")
|
||||
}
|
||||
}
|
||||
|
||||
// If there is resume information available, limit this scan to only the repos that still need scanning.
|
||||
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
|
||||
s.repos = reposToScan
|
||||
|
||||
var scanErrs []error
|
||||
for i, repoURL := range s.repos {
|
||||
if err := s.jobSem.Acquire(ctx, 1); err != nil {
|
||||
// Acquire blocks until it can acquire the semaphore or returns an
|
||||
// error if the context is finished
|
||||
log.WithError(err).Debug("could not acquire semaphore")
|
||||
reportErr(err)
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(ctx context.Context, repoURL string, i int) {
|
||||
defer s.jobSem.Release(1)
|
||||
defer wg.Done()
|
||||
repoURL := repoURL
|
||||
s.jobPool.Go(func() error {
|
||||
if common.IsDone(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL)
|
||||
// Ensure the repo is removed from the resume info after being scanned.
|
||||
defer func(s *Source) {
|
||||
defer func(s *Source, repoURL string) {
|
||||
s.resumeInfoMutex.Lock()
|
||||
defer s.resumeInfoMutex.Unlock()
|
||||
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
|
||||
}(s)
|
||||
}(s, repoURL)
|
||||
|
||||
if !strings.HasSuffix(repoURL, ".git") {
|
||||
return
|
||||
scanErrs = append(scanErrs, fmt.Errorf("repo %s does not end in .git", repoURL))
|
||||
return nil
|
||||
}
|
||||
|
||||
s.log.WithField("repo", repoURL).Debugf("attempting to clone repo %d/%d", i+1, len(s.repos))
|
||||
|
@ -394,20 +391,24 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
|
|||
switch s.conn.GetCredential().(type) {
|
||||
case *sourcespb.GitHub_Unauthenticated:
|
||||
path, repo, err = git.CloneRepoUsingUnauthenticated(repoURL)
|
||||
if err != nil {
|
||||
scanErrs = append(scanErrs, fmt.Errorf("error cloning repo %s: %w", repoURL, err))
|
||||
}
|
||||
default:
|
||||
var token string
|
||||
token, err = s.Token(ctx, installationClient)
|
||||
if err != nil {
|
||||
reportErr(err)
|
||||
return
|
||||
scanErrs = append(scanErrs, fmt.Errorf("error getting token for repo %s: %w", repoURL, err))
|
||||
}
|
||||
path, repo, err = git.CloneRepoUsingToken(token, repoURL, "")
|
||||
if err != nil {
|
||||
scanErrs = append(scanErrs, fmt.Errorf("error cloning repo %s: %w", repoURL, err))
|
||||
}
|
||||
path, repo, err = git.CloneRepoUsingToken(token, repoURL, "clone")
|
||||
}
|
||||
|
||||
defer os.RemoveAll(path)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("unable to clone repo (%s), continuing", repoURL)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
// Base and head will only exist from incoming webhooks.
|
||||
scanOptions := git.NewScanOptions(
|
||||
|
@ -415,24 +416,23 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
|
|||
git.ScanOptionHeadCommit(s.conn.Head),
|
||||
)
|
||||
|
||||
err = s.git.ScanRepo(ctx, repo, path, scanOptions, chunksChan)
|
||||
if err != nil {
|
||||
if err = s.git.ScanRepo(ctx, repo, path, scanOptions, chunksChan); err != nil {
|
||||
log.WithError(err).Errorf("unable to scan repo, continuing")
|
||||
return nil
|
||||
}
|
||||
atomic.AddUint64(&scanned, 1)
|
||||
log.Debugf("scanned %d/%d repos", scanned, len(s.repos))
|
||||
}(ctx, repoURL, i)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// This only returns first error which is what we did prior to concurrency
|
||||
select {
|
||||
case err := <-errs:
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
_ = s.jobPool.Wait()
|
||||
if len(scanErrs) == 0 {
|
||||
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed Github scan", "")
|
||||
}
|
||||
|
||||
return scanErrs
|
||||
}
|
||||
|
||||
// handleRateLimit returns true if a rate limit was handled
|
||||
|
@ -472,12 +472,12 @@ func handleRateLimit(errIn error, res *github.Response) bool {
|
|||
}
|
||||
|
||||
func (s *Source) getReposByOrg(ctx context.Context, apiClient *github.Client, org string) ([]string, error) {
|
||||
log := s.log.WithField("org", org)
|
||||
logger := s.log.WithField("org", org)
|
||||
|
||||
repos := []string{}
|
||||
var repos []string
|
||||
opts := &github.RepositoryListByOrgOptions{
|
||||
ListOptions: github.ListOptions{
|
||||
PerPage: 100,
|
||||
PerPage: defaultPagination,
|
||||
},
|
||||
}
|
||||
var numRepos, numForks int
|
||||
|
@ -492,7 +492,7 @@ func (s *Source) getReposByOrg(ctx context.Context, apiClient *github.Client, or
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("could not list repos for org %s: %w", org, err)
|
||||
}
|
||||
log.Debugf("listed repos page %d/%d", opts.Page, res.LastPage)
|
||||
logger.Debugf("listed repos page %d/%d", opts.Page, res.LastPage)
|
||||
if len(someRepos) == 0 {
|
||||
break
|
||||
}
|
||||
|
@ -511,16 +511,16 @@ func (s *Source) getReposByOrg(ctx context.Context, apiClient *github.Client, or
|
|||
}
|
||||
opts.Page = res.NextPage
|
||||
}
|
||||
log.Debugf("found %d repos (%d forks)", numRepos, numForks)
|
||||
logger.Debugf("found %d repos (%d forks)", numRepos, numForks)
|
||||
return repos, nil
|
||||
}
|
||||
|
||||
func (s *Source) addReposByOrg(ctx context.Context, apiClient *github.Client, org string) error {
|
||||
repos, err := s.getReposByOrg(ctx, apiClient, org)
|
||||
func (s *Source) addRepos(ctx context.Context, client *github.Client, entity string, getRepos func(context.Context, *github.Client, string) ([]string, error)) error {
|
||||
repos, err := getRepos(ctx, client, entity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// add the repos to the set of repos
|
||||
// Add the repos to the set of repos.
|
||||
for _, repo := range repos {
|
||||
common.AddStringSliceItem(repo, &s.repos)
|
||||
}
|
||||
|
@ -528,7 +528,7 @@ func (s *Source) addReposByOrg(ctx context.Context, apiClient *github.Client, or
|
|||
}
|
||||
|
||||
func (s *Source) getReposByUser(ctx context.Context, apiClient *github.Client, user string) ([]string, error) {
|
||||
repos := []string{}
|
||||
var repos []string
|
||||
opts := &github.RepositoryListOptions{
|
||||
ListOptions: github.ListOptions{
|
||||
PerPage: 50,
|
||||
|
@ -559,20 +559,8 @@ func (s *Source) getReposByUser(ctx context.Context, apiClient *github.Client, u
|
|||
return repos, nil
|
||||
}
|
||||
|
||||
func (s *Source) addReposByUser(ctx context.Context, apiClient *github.Client, user string) error {
|
||||
repos, err := s.getReposByUser(ctx, apiClient, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// add the repos to the set of repos
|
||||
for _, repo := range repos {
|
||||
common.AddStringSliceItem(repo, &s.repos)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) getGistsByUser(ctx context.Context, apiClient *github.Client, user string) ([]string, error) {
|
||||
gistURLs := []string{}
|
||||
var gistURLs []string
|
||||
gistOpts := &github.GistListOptions{}
|
||||
for {
|
||||
gists, resp, err := apiClient.Gists.List(ctx, user, gistOpts)
|
||||
|
@ -611,7 +599,7 @@ func (s *Source) addGistsByUser(ctx context.Context, apiClient *github.Client, u
|
|||
|
||||
func (s *Source) addMembersByApp(ctx context.Context, installationClient *github.Client, apiClient *github.Client) error {
|
||||
opts := &github.ListOptions{
|
||||
PerPage: 500,
|
||||
PerPage: membersAppPagination,
|
||||
}
|
||||
optsOrg := &github.ListMembersOptions{
|
||||
PublicOnly: false,
|
||||
|
@ -657,7 +645,7 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github
|
|||
func (s *Source) addReposByApp(ctx context.Context, apiClient *github.Client) error {
|
||||
// Authenticated enumeration of repos
|
||||
opts := &github.ListOptions{
|
||||
PerPage: 100,
|
||||
PerPage: defaultPagination,
|
||||
}
|
||||
for {
|
||||
someRepos, res, err := apiClient.Apps.ListRepos(ctx, opts)
|
||||
|
@ -685,14 +673,14 @@ func (s *Source) addReposByApp(ctx context.Context, apiClient *github.Client) er
|
|||
}
|
||||
|
||||
func (s *Source) addAllVisibleOrgs(ctx context.Context, apiClient *github.Client) {
|
||||
s.log.Debug("enumerating all visibile organizations on GHE")
|
||||
// Enumeration on this endpoint does not use pages. it uses a since ID.
|
||||
s.log.Debug("enumerating all visible organizations on GHE")
|
||||
// Enumeration on this endpoint does not use pages; it uses a since ID.
|
||||
// The endpoint will return organizations with an ID greater than the given since ID.
|
||||
// Empty org response is our cue to break the enumeration loop.
|
||||
orgOpts := &github.OrganizationsListOptions{
|
||||
Since: 0,
|
||||
ListOptions: github.ListOptions{
|
||||
PerPage: 100,
|
||||
PerPage: defaultPagination,
|
||||
},
|
||||
}
|
||||
for {
|
||||
|
@ -731,7 +719,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context, apiClient *github.Client
|
|||
|
||||
func (s *Source) addOrgsByUser(ctx context.Context, apiClient *github.Client, user string) {
|
||||
orgOpts := &github.ListOptions{
|
||||
PerPage: 100,
|
||||
PerPage: defaultPagination,
|
||||
}
|
||||
for {
|
||||
orgs, resp, err := apiClient.Organizations.List(ctx, "", orgOpts)
|
||||
|
@ -767,7 +755,7 @@ func (s *Source) normalizeRepos(ctx context.Context, apiClient *github.Client) {
|
|||
// TODO: Add check/fix for repos that are missing scheme
|
||||
normalizedRepos := map[string]struct{}{}
|
||||
for _, repo := range s.repos {
|
||||
// if there's a '/', assume it's a URL and try to normalize it
|
||||
// If there's a '/', assume it's a URL and try to normalize it.
|
||||
if strings.ContainsRune(repo, '/') {
|
||||
repoNormalized, err := giturl.NormalizeGithubRepo(repo)
|
||||
if err != nil {
|
||||
|
@ -777,7 +765,7 @@ func (s *Source) normalizeRepos(ctx context.Context, apiClient *github.Client) {
|
|||
normalizedRepos[repoNormalized] = struct{}{}
|
||||
continue
|
||||
}
|
||||
// otherwise, assume it's a user and enumerate repositories and gists
|
||||
// Otherwise, assume it's a user and enumerate repositories and gists.
|
||||
if repos, err := s.getReposByUser(ctx, apiClient, repo); err == nil {
|
||||
for _, repo := range repos {
|
||||
normalizedRepos[repo] = struct{}{}
|
||||
|
@ -790,7 +778,7 @@ func (s *Source) normalizeRepos(ctx context.Context, apiClient *github.Client) {
|
|||
}
|
||||
}
|
||||
|
||||
// replace s.repos
|
||||
// Replace s.repos.
|
||||
s.repos = s.repos[:0]
|
||||
for key := range normalizedRepos {
|
||||
s.repos = append(s.repos, key)
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/google/go-github/v42/github"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"gopkg.in/h2non/gock.v1"
|
||||
|
||||
|
@ -67,7 +68,7 @@ func TestAddReposByOrg(t *testing.T) {
|
|||
|
||||
s := initTestSource(nil)
|
||||
// gock works here because github.NewClient is using the default HTTP Transport
|
||||
err := s.addReposByOrg(context.TODO(), github.NewClient(nil), "super-secret-org")
|
||||
err := s.addRepos(context.TODO(), github.NewClient(nil), "super-secret-org", s.getReposByOrg)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(s.repos))
|
||||
assert.Equal(t, []string{"super-secret-repo"}, s.repos)
|
||||
|
@ -83,7 +84,7 @@ func TestAddReposByUser(t *testing.T) {
|
|||
JSON([]map[string]string{{"clone_url": "super-secret-repo"}})
|
||||
|
||||
s := initTestSource(nil)
|
||||
err := s.addReposByUser(context.TODO(), github.NewClient(nil), "super-secret-user")
|
||||
err := s.addRepos(context.TODO(), github.NewClient(nil), "super-secret-user", s.getReposByUser)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(s.repos))
|
||||
assert.Equal(t, []string{"super-secret-repo"}, s.repos)
|
||||
|
@ -438,3 +439,41 @@ func Test_setProgressCompleteWithRepo_Progress(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_scan_SetProgressComplete(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
repos []string
|
||||
wantComplete bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no repos",
|
||||
wantComplete: true,
|
||||
},
|
||||
{
|
||||
name: "one valid repo",
|
||||
repos: []string{"a"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
src := &Source{
|
||||
repos: tc.repos,
|
||||
}
|
||||
src.jobPool = &errgroup.Group{}
|
||||
|
||||
_ = src.scan(context.Background(), nil, nil)
|
||||
if !tc.wantErr {
|
||||
assert.Equal(t, "", src.GetProgress().EncodedResumeInfo)
|
||||
}
|
||||
|
||||
gotComplete := src.GetProgress().PercentComplete == 100
|
||||
if gotComplete != tc.wantComplete {
|
||||
t.Errorf("got: %v, want: %v", gotComplete, tc.wantComplete)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -332,6 +332,9 @@ func (s *Source) scanRepos(ctx context.Context, chunksChan chan *sources.Chunk)
|
|||
}(ctx, repo, i)
|
||||
}
|
||||
wg.Wait()
|
||||
if len(errs) == 0 {
|
||||
s.SetProgressComplete(len(s.repos), len(s.repos), fmt.Sprintf("Completed scanning source %s", s.name), "")
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
|
@ -247,3 +249,41 @@ func Test_setProgressCompleteWithRepo_Progress(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_scanRepos_SetProgressComplete(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
repos []string
|
||||
wantComplete bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no repos",
|
||||
wantComplete: true,
|
||||
},
|
||||
{
|
||||
name: "one valid repo",
|
||||
repos: []string{"a"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
src := &Source{
|
||||
repos: tc.repos,
|
||||
}
|
||||
src.jobSem = semaphore.NewWeighted(1)
|
||||
|
||||
_ = src.scanRepos(context.Background(), nil)
|
||||
if !tc.wantErr {
|
||||
assert.Equal(t, "", src.GetProgress().EncodedResumeInfo)
|
||||
}
|
||||
|
||||
gotComplete := src.GetProgress().PercentComplete == 100
|
||||
if gotComplete != tc.wantComplete {
|
||||
t.Errorf("got: %v, want: %v", gotComplete, tc.wantComplete)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -157,8 +157,8 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
|
|||
} else {
|
||||
regionalClient = client
|
||||
}
|
||||
//Forced prefix for testing
|
||||
//pf := "public"
|
||||
// Forced prefix for testing
|
||||
// pf := "public"
|
||||
errorCount := sync.Map{}
|
||||
|
||||
err = regionalClient.ListObjectsV2PagesWithContext(
|
||||
|
@ -172,8 +172,8 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
|
|||
s.log.WithError(err).Errorf("could not list objects in s3 bucket: %s", bucket)
|
||||
return errors.WrapPrefix(err, fmt.Sprintf("could not list objects in s3 bucket: %s", bucket), 0)
|
||||
}
|
||||
|
||||
}
|
||||
s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s", s.name), "")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -197,12 +197,12 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
|
|||
defer common.Recover(ctx)
|
||||
defer sem.Release(1)
|
||||
defer wg.Done()
|
||||
//defer log.Debugf("DONE - %s", *obj.Key)
|
||||
// defer log.Debugf("DONE - %s", *obj.Key)
|
||||
|
||||
if (*obj.Key)[len(*obj.Key)-1:] == "/" {
|
||||
return
|
||||
}
|
||||
//log.Debugf("Object: %s", *obj.Key)
|
||||
// log.Debugf("Object: %s", *obj.Key)
|
||||
|
||||
path := strings.Split(*obj.Key, "/")
|
||||
prefix := strings.Join(path[:len(path)-1], "/")
|
||||
|
@ -221,13 +221,13 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
|
|||
return
|
||||
}
|
||||
|
||||
//file is 0 bytes - likely no permissions - skipping
|
||||
// file is 0 bytes - likely no permissions - skipping
|
||||
if *obj.Size == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
//files break with spaces, must replace with +
|
||||
//objKey := strings.ReplaceAll(*obj.Key, " ", "+")
|
||||
// files break with spaces, must replace with +
|
||||
// objKey := strings.ReplaceAll(*obj.Key, " ", "+")
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
res, err := client.GetObjectWithContext(ctx, &s3.GetObjectInput{
|
||||
|
@ -249,7 +249,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
|
|||
}
|
||||
nErr = nErr.(int) + 1
|
||||
errorCount.Store(prefix, nErr)
|
||||
//too many consective errors on this page
|
||||
// too many consecutive errors on this page
|
||||
if nErr.(int) > 3 {
|
||||
s.log.Warnf("Too many consecutive errors. Excluding %s", prefix)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
|
@ -93,6 +94,8 @@ func TestSource_Chunks(t *testing.T) {
|
|||
if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
|
||||
t.Errorf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
|
||||
}
|
||||
assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
|
||||
assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ type Source interface {
|
|||
Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error
|
||||
// Chunks emits data over a channel that is decoded and scanned for secrets.
|
||||
Chunks(ctx context.Context, chunksChan chan *Chunk) error
|
||||
// Completion Percentage for Scanned Source
|
||||
// GetProgress is the completion progress (percentage) for Scanned Source.
|
||||
GetProgress() *Progress
|
||||
}
|
||||
|
||||
|
@ -104,7 +104,7 @@ func NewConfig(opts ...func(*Config)) Config {
|
|||
return *c
|
||||
}
|
||||
|
||||
// PercentComplete is used to update job completion percentages across sources
|
||||
// Progress is used to update job completion progress across sources.
|
||||
type Progress struct {
|
||||
mut sync.Mutex
|
||||
PercentComplete int64
|
||||
|
@ -127,10 +127,17 @@ func (p *Progress) SetProgressComplete(i, scope int, message, encodedResumeInfo
|
|||
p.EncodedResumeInfo = encodedResumeInfo
|
||||
p.SectionsCompleted = int32(i)
|
||||
p.SectionsRemaining = int32(scope)
|
||||
|
||||
// If the iteration and scope are both 0, completion is 100%.
|
||||
if i == 0 && scope == 0 {
|
||||
p.PercentComplete = 100
|
||||
return
|
||||
}
|
||||
|
||||
p.PercentComplete = int64((float64(i) / float64(scope)) * 100)
|
||||
}
|
||||
|
||||
// GetProgressComplete gets job completion percentage for metrics reporting
|
||||
// GetProgress gets job completion percentage for metrics reporting.
|
||||
func (p *Progress) GetProgress() *Progress {
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
|
|
Loading…
Reference in a new issue