GitHub source logger clean up (#3269)

* GitHub source logger clean up

* applied pr comments

* applied pr comments

* applied pr comments

* applied PR review comments
This commit is contained in:
Nash 2024-09-09 15:44:56 -04:00 committed by GitHub
parent 8a4d62c670
commit 17f6c98119
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 50 additions and 54 deletions

View file

@ -17,14 +17,14 @@ type tokenConnector struct {
apiClient *github.Client
token string
isGitHubEnterprise bool
handleRateLimit func(error) bool
handleRateLimit func(context.Context, error) bool
user string
userMu sync.Mutex
}
var _ connector = (*tokenConnector)(nil)
func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(error) bool) (*tokenConnector, error) {
func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error) bool) (*tokenConnector, error) {
const httpTimeoutSeconds = 60
httpClient := common.RetryableHTTPClientTimeout(int64(httpTimeoutSeconds))
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
@ -68,7 +68,7 @@ func (c *tokenConnector) getUser(ctx context.Context) (string, error) {
)
for {
user, _, err = c.apiClient.Users.Get(ctx, "")
if c.handleRateLimit(err) {
if c.handleRateLimit(ctx, err) {
continue
}
if err != nil {

View file

@ -61,7 +61,6 @@ type Source struct {
scanOptMu sync.Mutex // protects the scanOptions
scanOptions *git.ScanOptions
log logr.Logger
conn *sourcespb.GitHub
jobPool *errgroup.Group
resumeInfoMutex sync.Mutex
@ -117,13 +116,13 @@ type filteredRepoCache struct {
include, exclude []glob.Glob
}
func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
func (s *Source) newFilteredRepoCache(ctx context.Context, c cache.Cache[string], include, exclude []string) *filteredRepoCache {
includeGlobs := make([]glob.Glob, 0, len(include))
excludeGlobs := make([]glob.Glob, 0, len(exclude))
for _, ig := range include {
g, err := glob.Compile(ig)
if err != nil {
s.log.V(1).Info("invalid include glob", "include_value", ig, "err", err)
ctx.Logger().V(1).Info("invalid include glob", "include_value", ig, "err", err)
continue
}
includeGlobs = append(includeGlobs, g)
@ -131,7 +130,7 @@ func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []
for _, eg := range exclude {
g, err := glob.Compile(eg)
if err != nil {
s.log.V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
ctx.Logger().V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
continue
}
excludeGlobs = append(excludeGlobs, g)
@ -180,8 +179,6 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
return err
}
s.log = aCtx.Logger()
s.name = name
s.sourceID = sourceID
s.jobID = jobID
@ -208,7 +205,8 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
}
s.memberCache = make(map[string]struct{})
s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](),
s.filteredRepoCache = s.newFilteredRepoCache(aCtx,
memory.New[string](),
append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...),
s.conn.GetIgnoreRepos(),
)
@ -360,7 +358,7 @@ RepoLoop:
// Normalize the URL to the Gist's pull URL.
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
repo = gist.GetGitPullURL()
if s.handleRateLimit(err) {
if s.handleRateLimit(repoCtx, err) {
continue
}
if err != nil {
@ -374,7 +372,7 @@ RepoLoop:
// Cache repository info.
for {
ghRepo, _, err := s.connector.APIClient().Repositories.Get(repoCtx, urlParts[1], urlParts[2])
if s.handleRateLimit(err) {
if s.handleRateLimit(repoCtx, err) {
continue
}
if err != nil {
@ -389,8 +387,7 @@ RepoLoop:
s.repos = append(s.repos, repo)
}
githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos)))
s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))
ctx.Logger().Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))
// We must sort the repos so we can resume later if necessary.
sort.Strings(s.repos)
return nil
@ -417,7 +414,7 @@ func (s *Source) enumerateBasicAuth(ctx context.Context) error {
func (s *Source) enumerateUnauthenticated(ctx context.Context) {
if s.orgsCache.Count() > unauthGithubOrgRateLimt {
s.log.Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
ctx.Logger().Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
}
for _, org := range s.orgsCache.Keys() {
@ -441,7 +438,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
var err error
for {
ghUser, _, err = s.connector.APIClient().Users.Get(ctx, "")
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
@ -454,10 +451,10 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
if !specificScope {
// Enumerate the user's orgs and repos if none were specified.
if err := s.getReposByUser(ctx, ghUser.GetLogin()); err != nil {
s.log.Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin())
ctx.Logger().Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin())
}
if err := s.addUserGistsToCache(ctx, ghUser.GetLogin()); err != nil {
s.log.Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin())
ctx.Logger().Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin())
}
if isGithubEnterprise {
@ -486,7 +483,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
}
if s.conn.ScanUsers && len(s.memberCache) > 0 {
s.log.Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache))
ctx.Logger().Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache))
s.addReposForMembers(ctx)
}
}
@ -507,9 +504,9 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu
if err != nil {
return err
}
s.log.Info("Scanning repos", "org_members", len(s.memberCache))
ctx.Logger().Info("Scanning repos", "org_members", len(s.memberCache))
for member := range s.memberCache {
logger := s.log.WithValues("member", member)
logger := ctx.Logger().WithValues("member", member)
if err := s.addUserGistsToCache(ctx, member); err != nil {
logger.Error(err, "error fetching gists by user")
}
@ -536,7 +533,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error {
var scannedCount uint64 = 1
s.log.V(2).Info("Found repos to scan", "count", len(s.repos))
ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos))
// If there is resume information available, limit this scan to only the repos that still need scanning.
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
@ -574,7 +571,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
if !ok {
// This should never happen.
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
s.log.Error(err, "failed to scan repository")
ctx.Logger().Error(err, "failed to scan repository")
return nil
}
repoCtx := context.WithValues(ctx, "repo", repoURL)
@ -618,7 +615,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
_ = s.jobPool.Wait()
if scanErrs.Count() > 0 {
s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
ctx.Logger().Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
}
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")
@ -666,7 +663,7 @@ var (
// Authenticated users have a rate limit of 5,000 requests per hour,
// however, certain actions are subject to a stricter "secondary" limit.
// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api
func (s *Source) handleRateLimit(errIn error) bool {
func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
if errIn == nil {
return false
}
@ -705,12 +702,12 @@ func (s *Source) handleRateLimit(errIn error) bool {
if retryAfter > 0 {
retryAfter = retryAfter + jitter
rateLimitResumeTime = now.Add(retryAfter)
s.log.V(0).Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
ctx.Logger().Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
} else {
retryAfter = (5 * time.Minute) + jitter
rateLimitResumeTime = now.Add(retryAfter)
// TODO: Use exponential backoff instead of static retry time.
s.log.V(0).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
ctx.Logger().Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
}
rateLimitMu.Unlock()
@ -725,13 +722,13 @@ func (s *Source) handleRateLimit(errIn error) bool {
}
func (s *Source) addReposForMembers(ctx context.Context) {
s.log.Info("Fetching repos from members", "members", len(s.memberCache))
ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache))
for member := range s.memberCache {
if err := s.addUserGistsToCache(ctx, member); err != nil {
s.log.Info("Unable to fetch gists by user", "user", member, "error", err)
ctx.Logger().Info("Unable to fetch gists by user", "user", member, "error", err)
}
if err := s.getReposByUser(ctx, member); err != nil {
s.log.Info("Unable to fetch repos by user", "user", member, "error", err)
ctx.Logger().Info("Unable to fetch repos by user", "user", member, "error", err)
}
}
}
@ -740,10 +737,11 @@ func (s *Source) addReposForMembers(ctx context.Context) {
// and adds them to the filteredRepoCache.
func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
gistOpts := &github.GistListOptions{}
logger := s.log.WithValues("user", user)
logger := ctx.Logger().WithValues("user", user)
for {
gists, res, err := s.connector.APIClient().Gists.List(ctx, user, gistOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
@ -788,7 +786,7 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github
}
func (s *Source) addAllVisibleOrgs(ctx context.Context) {
s.log.V(2).Info("enumerating all visible organizations on GHE")
ctx.Logger().V(2).Info("enumerating all visible organizations on GHE")
// Enumeration on this endpoint does not use pages it uses a since ID.
// The endpoint will return organizations with an ID greater than the given since ID.
// Empty org response is our cue to break the enumeration loop.
@ -800,11 +798,11 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
}
for {
orgs, _, err := s.connector.APIClient().Organizations.ListAll(ctx, orgOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
s.log.Error(err, "could not list all organizations")
ctx.Logger().Error(err, "could not list all organizations")
return
}
@ -813,7 +811,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
}
lastOrgID := *orgs[len(orgs)-1].ID
s.log.V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID))
ctx.Logger().V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID))
orgOpts.Since = lastOrgID
for _, org := range orgs {
@ -827,7 +825,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
continue
}
s.orgsCache.Set(name, name)
s.log.V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name)
ctx.Logger().V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name)
}
}
}
@ -836,10 +834,10 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
orgOpts := &github.ListOptions{
PerPage: defaultPagination,
}
logger := s.log.WithValues("user", user)
logger := ctx.Logger().WithValues("user", user)
for {
orgs, resp, err := s.connector.APIClient().Organizations.List(ctx, "", orgOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
@ -869,10 +867,10 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
},
}
logger := s.log.WithValues("org", org)
logger := ctx.Logger().WithValues("org", org)
for {
members, res, err := s.connector.APIClient().Organizations.ListMembers(ctx, org, opts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil || len(members) == 0 {
@ -994,7 +992,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar
}
for {
comments, _, err := s.connector.APIClient().Gists.ListComments(ctx, gistID, options)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
@ -1107,7 +1105,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha
for {
issues, _, err := s.connector.APIClient().Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
@ -1179,7 +1177,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch
for {
issueComments, _, err := s.connector.APIClient().Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
@ -1244,7 +1242,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c
for {
prs, _, err := s.connector.APIClient().PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
@ -1276,7 +1274,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk
for {
prComments, _, err := s.connector.APIClient().PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {

View file

@ -15,7 +15,6 @@ import (
"testing"
"time"
"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"
"github.com/google/go-github/v63/github"
"github.com/stretchr/testify/assert"
@ -369,7 +368,8 @@ func TestNormalizeRepos(t *testing.T) {
func TestHandleRateLimit(t *testing.T) {
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
assert.False(t, s.handleRateLimit(nil))
ctx := context.Background()
assert.False(t, s.handleRateLimit(ctx, nil))
// Request
reqUrl, _ := url.Parse("https://github.com/trufflesecurity/trufflehog")
@ -400,7 +400,7 @@ func TestHandleRateLimit(t *testing.T) {
Message: "Too Many Requests",
}
assert.True(t, s.handleRateLimit(err))
assert.True(t, s.handleRateLimit(ctx, err))
}
func TestEnumerateUnauthenticated(t *testing.T) {
@ -721,7 +721,6 @@ func Test_setProgressCompleteWithRepo_resumeInfo(t *testing.T) {
s := &Source{
repos: []string{},
log: logr.Discard(),
}
for _, tt := range tests {
@ -772,7 +771,6 @@ func Test_setProgressCompleteWithRepo_Progress(t *testing.T) {
for _, tt := range tests {
s := &Source{
repos: tt.repos,
log: logr.Discard(),
}
s.setProgressCompleteWithRepo(tt.index, tt.offset, "")

View file

@ -181,7 +181,7 @@ func isGitHub404Error(err error) bool {
}
func (s *Source) processRepos(ctx context.Context, target string, listRepos repoLister, listOpts repoListOptions) error {
logger := s.log.WithValues("target", target)
logger := ctx.Logger().WithValues("target", target)
opts := listOpts.getListOptions()
var (
@ -191,14 +191,14 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo
for {
someRepos, res, err := listRepos(ctx, target, listOpts)
if s.handleRateLimit(err) {
if s.handleRateLimit(ctx, err) {
continue
}
if err != nil {
return err
}
s.log.V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage)
ctx.Logger().V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage)
for _, r := range someRepos {
if r.GetFork() {
if !s.conn.IncludeForks {