diff --git a/pkg/sources/github/connector_token.go b/pkg/sources/github/connector_token.go index 6fedac31f..f45ad9ca1 100644 --- a/pkg/sources/github/connector_token.go +++ b/pkg/sources/github/connector_token.go @@ -17,14 +17,14 @@ type tokenConnector struct { apiClient *github.Client token string isGitHubEnterprise bool - handleRateLimit func(error) bool + handleRateLimit func(context.Context, error) bool user string userMu sync.Mutex } var _ connector = (*tokenConnector)(nil) -func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(error) bool) (*tokenConnector, error) { +func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error) bool) (*tokenConnector, error) { const httpTimeoutSeconds = 60 httpClient := common.RetryableHTTPClientTimeout(int64(httpTimeoutSeconds)) tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) @@ -68,7 +68,7 @@ func (c *tokenConnector) getUser(ctx context.Context) (string, error) { ) for { user, _, err = c.apiClient.Users.Get(ctx, "") - if c.handleRateLimit(err) { + if c.handleRateLimit(ctx, err) { continue } if err != nil { diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index e201bb5bc..7ff93a97b 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -61,7 +61,6 @@ type Source struct { scanOptMu sync.Mutex // protects the scanOptions scanOptions *git.ScanOptions - log logr.Logger conn *sourcespb.GitHub jobPool *errgroup.Group resumeInfoMutex sync.Mutex @@ -117,13 +116,13 @@ type filteredRepoCache struct { include, exclude []glob.Glob } -func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache { +func (s *Source) newFilteredRepoCache(ctx context.Context, c cache.Cache[string], include, exclude []string) *filteredRepoCache { includeGlobs := make([]glob.Glob, 0, len(include)) excludeGlobs := make([]glob.Glob, 0, len(exclude)) for _, ig := range include { g, err := glob.Compile(ig) if err != nil { - s.log.V(1).Info("invalid include glob", "include_value", ig, "err", err) + ctx.Logger().V(1).Info("invalid include glob", "include_value", ig, "err", err) continue } includeGlobs = append(includeGlobs, g) @@ -131,7 +130,7 @@ func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude [] for _, eg := range exclude { g, err := glob.Compile(eg) if err != nil { - s.log.V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err) + ctx.Logger().V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err) continue } excludeGlobs = append(excludeGlobs, g) @@ -180,8 +179,6 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so return err } - s.log = aCtx.Logger() - s.name = name s.sourceID = sourceID s.jobID = jobID @@ -208,7 +205,8 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so } s.memberCache = make(map[string]struct{}) - s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](), + s.filteredRepoCache = s.newFilteredRepoCache(aCtx, + memory.New[string](), append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...), s.conn.GetIgnoreRepos(), ) @@ -360,7 +358,7 @@ RepoLoop: // Normalize the URL to the Gist's pull URL. // See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937 repo = gist.GetGitPullURL() - if s.handleRateLimit(err) { + if s.handleRateLimit(repoCtx, err) { continue } if err != nil { @@ -374,7 +372,7 @@ RepoLoop: // Cache repository info. for { ghRepo, _, err := s.connector.APIClient().Repositories.Get(repoCtx, urlParts[1], urlParts[2]) - if s.handleRateLimit(err) { + if s.handleRateLimit(repoCtx, err) { continue } if err != nil { @@ -389,8 +387,7 @@ RepoLoop: s.repos = append(s.repos, repo) } githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos))) - s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache)) - + ctx.Logger().Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache)) // We must sort the repos so we can resume later if necessary. sort.Strings(s.repos) return nil @@ -417,7 +414,7 @@ func (s *Source) enumerateBasicAuth(ctx context.Context) error { func (s *Source) enumerateUnauthenticated(ctx context.Context) { if s.orgsCache.Count() > unauthGithubOrgRateLimt { - s.log.Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.") + ctx.Logger().Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.") } for _, org := range s.orgsCache.Keys() { @@ -441,7 +438,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool var err error for { ghUser, _, err = s.connector.APIClient().Users.Get(ctx, "") - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { @@ -454,10 +451,10 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool if !specificScope { // Enumerate the user's orgs and repos if none were specified. if err := s.getReposByUser(ctx, ghUser.GetLogin()); err != nil { - s.log.Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin()) + ctx.Logger().Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin()) } if err := s.addUserGistsToCache(ctx, ghUser.GetLogin()); err != nil { - s.log.Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin()) + ctx.Logger().Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin()) } if isGithubEnterprise { @@ -486,7 +483,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool } if s.conn.ScanUsers && len(s.memberCache) > 0 { - s.log.Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache)) + ctx.Logger().Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache)) s.addReposForMembers(ctx) } } @@ -507,9 +504,9 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu if err != nil { return err } - s.log.Info("Scanning repos", "org_members", len(s.memberCache)) + ctx.Logger().Info("Scanning repos", "org_members", len(s.memberCache)) for member := range s.memberCache { - logger := s.log.WithValues("member", member) + logger := ctx.Logger().WithValues("member", member) if err := s.addUserGistsToCache(ctx, member); err != nil { logger.Error(err, "error fetching gists by user") } @@ -536,7 +533,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error { var scannedCount uint64 = 1 - s.log.V(2).Info("Found repos to scan", "count", len(s.repos)) + ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos)) // If there is resume information available, limit this scan to only the repos that still need scanning. reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo) @@ -574,7 +571,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error if !ok { // This should never happen. err := fmt.Errorf("no repoInfo for URL: %s", repoURL) - s.log.Error(err, "failed to scan repository") + ctx.Logger().Error(err, "failed to scan repository") return nil } repoCtx := context.WithValues(ctx, "repo", repoURL) @@ -618,7 +615,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error _ = s.jobPool.Wait() if scanErrs.Count() > 0 { - s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String()) + ctx.Logger().Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String()) } s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "") @@ -666,7 +663,7 @@ var ( // Authenticated users have a rate limit of 5,000 requests per hour, // however, certain actions are subject to a stricter "secondary" limit. // https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api -func (s *Source) handleRateLimit(errIn error) bool { +func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool { if errIn == nil { return false } @@ -705,12 +702,12 @@ func (s *Source) handleRateLimit(errIn error) bool { if retryAfter > 0 { retryAfter = retryAfter + jitter rateLimitResumeTime = now.Add(retryAfter) - s.log.V(0).Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339)) + ctx.Logger().Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339)) } else { retryAfter = (5 * time.Minute) + jitter rateLimitResumeTime = now.Add(retryAfter) // TODO: Use exponential backoff instead of static retry time. - s.log.V(0).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339)) + ctx.Logger().Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339)) } rateLimitMu.Unlock() @@ -725,13 +722,13 @@ func (s *Source) handleRateLimit(errIn error) bool { } func (s *Source) addReposForMembers(ctx context.Context) { - s.log.Info("Fetching repos from members", "members", len(s.memberCache)) + ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache)) for member := range s.memberCache { if err := s.addUserGistsToCache(ctx, member); err != nil { - s.log.Info("Unable to fetch gists by user", "user", member, "error", err) + ctx.Logger().Info("Unable to fetch gists by user", "user", member, "error", err) } if err := s.getReposByUser(ctx, member); err != nil { - s.log.Info("Unable to fetch repos by user", "user", member, "error", err) + ctx.Logger().Info("Unable to fetch repos by user", "user", member, "error", err) } } } @@ -740,10 +737,11 @@ func (s *Source) addReposForMembers(ctx context.Context) { // and adds them to the filteredRepoCache. func (s *Source) addUserGistsToCache(ctx context.Context, user string) error { gistOpts := &github.GistListOptions{} - logger := s.log.WithValues("user", user) + logger := ctx.Logger().WithValues("user", user) + for { gists, res, err := s.connector.APIClient().Gists.List(ctx, user, gistOpts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { @@ -788,7 +786,7 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github } func (s *Source) addAllVisibleOrgs(ctx context.Context) { - s.log.V(2).Info("enumerating all visible organizations on GHE") + ctx.Logger().V(2).Info("enumerating all visible organizations on GHE") // Enumeration on this endpoint does not use pages it uses a since ID. // The endpoint will return organizations with an ID greater than the given since ID. // Empty org response is our cue to break the enumeration loop. @@ -800,11 +798,11 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) { } for { orgs, _, err := s.connector.APIClient().Organizations.ListAll(ctx, orgOpts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { - s.log.Error(err, "could not list all organizations") + ctx.Logger().Error(err, "could not list all organizations") return } @@ -813,7 +811,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) { } lastOrgID := *orgs[len(orgs)-1].ID - s.log.V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID)) + ctx.Logger().V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID)) orgOpts.Since = lastOrgID for _, org := range orgs { @@ -827,7 +825,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) { continue } s.orgsCache.Set(name, name) - s.log.V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name) + ctx.Logger().V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name) } } } @@ -836,10 +834,10 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) { orgOpts := &github.ListOptions{ PerPage: defaultPagination, } - logger := s.log.WithValues("user", user) + logger := ctx.Logger().WithValues("user", user) for { orgs, resp, err := s.connector.APIClient().Organizations.List(ctx, "", orgOpts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { @@ -869,10 +867,10 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error { }, } - logger := s.log.WithValues("org", org) + logger := ctx.Logger().WithValues("org", org) for { members, res, err := s.connector.APIClient().Organizations.ListMembers(ctx, org, opts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil || len(members) == 0 { @@ -994,7 +992,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar } for { comments, _, err := s.connector.APIClient().Gists.ListComments(ctx, gistID, options) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { @@ -1107,7 +1105,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha for { issues, _, err := s.connector.APIClient().Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } @@ -1179,7 +1177,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch for { issueComments, _, err := s.connector.APIClient().Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { @@ -1244,7 +1242,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c for { prs, _, err := s.connector.APIClient().PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { @@ -1276,7 +1274,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk for { prComments, _, err := s.connector.APIClient().PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { diff --git a/pkg/sources/github/github_test.go b/pkg/sources/github/github_test.go index 729f5bf0a..c35c9a088 100644 --- a/pkg/sources/github/github_test.go +++ b/pkg/sources/github/github_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" "github.com/google/go-github/v63/github" "github.com/stretchr/testify/assert" @@ -369,7 +368,8 @@ func TestNormalizeRepos(t *testing.T) { func TestHandleRateLimit(t *testing.T) { s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}}) - assert.False(t, s.handleRateLimit(nil)) + ctx := context.Background() + assert.False(t, s.handleRateLimit(ctx, nil)) // Request reqUrl, _ := url.Parse("https://github.com/trufflesecurity/trufflehog") @@ -400,7 +400,7 @@ func TestHandleRateLimit(t *testing.T) { Message: "Too Many Requests", } - assert.True(t, s.handleRateLimit(err)) + assert.True(t, s.handleRateLimit(ctx, err)) } func TestEnumerateUnauthenticated(t *testing.T) { @@ -721,7 +721,6 @@ func Test_setProgressCompleteWithRepo_resumeInfo(t *testing.T) { s := &Source{ repos: []string{}, - log: logr.Discard(), } for _, tt := range tests { @@ -772,7 +771,6 @@ func Test_setProgressCompleteWithRepo_Progress(t *testing.T) { for _, tt := range tests { s := &Source{ repos: tt.repos, - log: logr.Discard(), } s.setProgressCompleteWithRepo(tt.index, tt.offset, "") diff --git a/pkg/sources/github/repo.go b/pkg/sources/github/repo.go index 79be8653f..f584bc8d2 100644 --- a/pkg/sources/github/repo.go +++ b/pkg/sources/github/repo.go @@ -181,7 +181,7 @@ func isGitHub404Error(err error) bool { } func (s *Source) processRepos(ctx context.Context, target string, listRepos repoLister, listOpts repoListOptions) error { - logger := s.log.WithValues("target", target) + logger := ctx.Logger().WithValues("target", target) opts := listOpts.getListOptions() var ( @@ -191,14 +191,14 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo for { someRepos, res, err := listRepos(ctx, target, listOpts) - if s.handleRateLimit(err) { + if s.handleRateLimit(ctx, err) { continue } if err != nil { return err } - s.log.V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage) + ctx.Logger().V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage) for _, r := range someRepos { if r.GetFork() { if !s.conn.IncludeForks {