mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
GitHub source logger clean up (#3269)
* GitHub source logger clean up * applied pr comments * applied pr comments * applied pr comments * applied PR review comments
This commit is contained in:
parent
8a4d62c670
commit
17f6c98119
4 changed files with 50 additions and 54 deletions
|
@ -17,14 +17,14 @@ type tokenConnector struct {
|
|||
apiClient *github.Client
|
||||
token string
|
||||
isGitHubEnterprise bool
|
||||
handleRateLimit func(error) bool
|
||||
handleRateLimit func(context.Context, error) bool
|
||||
user string
|
||||
userMu sync.Mutex
|
||||
}
|
||||
|
||||
var _ connector = (*tokenConnector)(nil)
|
||||
|
||||
func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(error) bool) (*tokenConnector, error) {
|
||||
func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error) bool) (*tokenConnector, error) {
|
||||
const httpTimeoutSeconds = 60
|
||||
httpClient := common.RetryableHTTPClientTimeout(int64(httpTimeoutSeconds))
|
||||
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
|
||||
|
@ -68,7 +68,7 @@ func (c *tokenConnector) getUser(ctx context.Context) (string, error) {
|
|||
)
|
||||
for {
|
||||
user, _, err = c.apiClient.Users.Get(ctx, "")
|
||||
if c.handleRateLimit(err) {
|
||||
if c.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
|
|
@ -61,7 +61,6 @@ type Source struct {
|
|||
scanOptMu sync.Mutex // protects the scanOptions
|
||||
scanOptions *git.ScanOptions
|
||||
|
||||
log logr.Logger
|
||||
conn *sourcespb.GitHub
|
||||
jobPool *errgroup.Group
|
||||
resumeInfoMutex sync.Mutex
|
||||
|
@ -117,13 +116,13 @@ type filteredRepoCache struct {
|
|||
include, exclude []glob.Glob
|
||||
}
|
||||
|
||||
func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []string) *filteredRepoCache {
|
||||
func (s *Source) newFilteredRepoCache(ctx context.Context, c cache.Cache[string], include, exclude []string) *filteredRepoCache {
|
||||
includeGlobs := make([]glob.Glob, 0, len(include))
|
||||
excludeGlobs := make([]glob.Glob, 0, len(exclude))
|
||||
for _, ig := range include {
|
||||
g, err := glob.Compile(ig)
|
||||
if err != nil {
|
||||
s.log.V(1).Info("invalid include glob", "include_value", ig, "err", err)
|
||||
ctx.Logger().V(1).Info("invalid include glob", "include_value", ig, "err", err)
|
||||
continue
|
||||
}
|
||||
includeGlobs = append(includeGlobs, g)
|
||||
|
@ -131,7 +130,7 @@ func (s *Source) newFilteredRepoCache(c cache.Cache[string], include, exclude []
|
|||
for _, eg := range exclude {
|
||||
g, err := glob.Compile(eg)
|
||||
if err != nil {
|
||||
s.log.V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
|
||||
ctx.Logger().V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
|
||||
continue
|
||||
}
|
||||
excludeGlobs = append(excludeGlobs, g)
|
||||
|
@ -180,8 +179,6 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
|
|||
return err
|
||||
}
|
||||
|
||||
s.log = aCtx.Logger()
|
||||
|
||||
s.name = name
|
||||
s.sourceID = sourceID
|
||||
s.jobID = jobID
|
||||
|
@ -208,7 +205,8 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
|
|||
}
|
||||
s.memberCache = make(map[string]struct{})
|
||||
|
||||
s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](),
|
||||
s.filteredRepoCache = s.newFilteredRepoCache(aCtx,
|
||||
memory.New[string](),
|
||||
append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...),
|
||||
s.conn.GetIgnoreRepos(),
|
||||
)
|
||||
|
@ -360,7 +358,7 @@ RepoLoop:
|
|||
// Normalize the URL to the Gist's pull URL.
|
||||
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
|
||||
repo = gist.GetGitPullURL()
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(repoCtx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -374,7 +372,7 @@ RepoLoop:
|
|||
// Cache repository info.
|
||||
for {
|
||||
ghRepo, _, err := s.connector.APIClient().Repositories.Get(repoCtx, urlParts[1], urlParts[2])
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(repoCtx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -389,8 +387,7 @@ RepoLoop:
|
|||
s.repos = append(s.repos, repo)
|
||||
}
|
||||
githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos)))
|
||||
s.log.Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))
|
||||
|
||||
ctx.Logger().Info("Completed enumeration", "num_repos", len(s.repos), "num_orgs", s.orgsCache.Count(), "num_members", len(s.memberCache))
|
||||
// We must sort the repos so we can resume later if necessary.
|
||||
sort.Strings(s.repos)
|
||||
return nil
|
||||
|
@ -417,7 +414,7 @@ func (s *Source) enumerateBasicAuth(ctx context.Context) error {
|
|||
|
||||
func (s *Source) enumerateUnauthenticated(ctx context.Context) {
|
||||
if s.orgsCache.Count() > unauthGithubOrgRateLimt {
|
||||
s.log.Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
|
||||
ctx.Logger().Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
|
||||
}
|
||||
|
||||
for _, org := range s.orgsCache.Keys() {
|
||||
|
@ -441,7 +438,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
|
|||
var err error
|
||||
for {
|
||||
ghUser, _, err = s.connector.APIClient().Users.Get(ctx, "")
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -454,10 +451,10 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
|
|||
if !specificScope {
|
||||
// Enumerate the user's orgs and repos if none were specified.
|
||||
if err := s.getReposByUser(ctx, ghUser.GetLogin()); err != nil {
|
||||
s.log.Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin())
|
||||
ctx.Logger().Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin())
|
||||
}
|
||||
if err := s.addUserGistsToCache(ctx, ghUser.GetLogin()); err != nil {
|
||||
s.log.Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin())
|
||||
ctx.Logger().Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin())
|
||||
}
|
||||
|
||||
if isGithubEnterprise {
|
||||
|
@ -486,7 +483,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
|
|||
}
|
||||
|
||||
if s.conn.ScanUsers && len(s.memberCache) > 0 {
|
||||
s.log.Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache))
|
||||
ctx.Logger().Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache))
|
||||
s.addReposForMembers(ctx)
|
||||
}
|
||||
}
|
||||
|
@ -507,9 +504,9 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.log.Info("Scanning repos", "org_members", len(s.memberCache))
|
||||
ctx.Logger().Info("Scanning repos", "org_members", len(s.memberCache))
|
||||
for member := range s.memberCache {
|
||||
logger := s.log.WithValues("member", member)
|
||||
logger := ctx.Logger().WithValues("member", member)
|
||||
if err := s.addUserGistsToCache(ctx, member); err != nil {
|
||||
logger.Error(err, "error fetching gists by user")
|
||||
}
|
||||
|
@ -536,7 +533,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
|
|||
func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error {
|
||||
var scannedCount uint64 = 1
|
||||
|
||||
s.log.V(2).Info("Found repos to scan", "count", len(s.repos))
|
||||
ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos))
|
||||
|
||||
// If there is resume information available, limit this scan to only the repos that still need scanning.
|
||||
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
|
||||
|
@ -574,7 +571,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
|
|||
if !ok {
|
||||
// This should never happen.
|
||||
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
|
||||
s.log.Error(err, "failed to scan repository")
|
||||
ctx.Logger().Error(err, "failed to scan repository")
|
||||
return nil
|
||||
}
|
||||
repoCtx := context.WithValues(ctx, "repo", repoURL)
|
||||
|
@ -618,7 +615,7 @@ func (s *Source) scan(ctx context.Context, chunksChan chan *sources.Chunk) error
|
|||
|
||||
_ = s.jobPool.Wait()
|
||||
if scanErrs.Count() > 0 {
|
||||
s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
|
||||
ctx.Logger().Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
|
||||
}
|
||||
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")
|
||||
|
||||
|
@ -666,7 +663,7 @@ var (
|
|||
// Authenticated users have a rate limit of 5,000 requests per hour,
|
||||
// however, certain actions are subject to a stricter "secondary" limit.
|
||||
// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api
|
||||
func (s *Source) handleRateLimit(errIn error) bool {
|
||||
func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
|
||||
if errIn == nil {
|
||||
return false
|
||||
}
|
||||
|
@ -705,12 +702,12 @@ func (s *Source) handleRateLimit(errIn error) bool {
|
|||
if retryAfter > 0 {
|
||||
retryAfter = retryAfter + jitter
|
||||
rateLimitResumeTime = now.Add(retryAfter)
|
||||
s.log.V(0).Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
|
||||
ctx.Logger().Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
|
||||
} else {
|
||||
retryAfter = (5 * time.Minute) + jitter
|
||||
rateLimitResumeTime = now.Add(retryAfter)
|
||||
// TODO: Use exponential backoff instead of static retry time.
|
||||
s.log.V(0).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
|
||||
ctx.Logger().Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
rateLimitMu.Unlock()
|
||||
|
@ -725,13 +722,13 @@ func (s *Source) handleRateLimit(errIn error) bool {
|
|||
}
|
||||
|
||||
func (s *Source) addReposForMembers(ctx context.Context) {
|
||||
s.log.Info("Fetching repos from members", "members", len(s.memberCache))
|
||||
ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache))
|
||||
for member := range s.memberCache {
|
||||
if err := s.addUserGistsToCache(ctx, member); err != nil {
|
||||
s.log.Info("Unable to fetch gists by user", "user", member, "error", err)
|
||||
ctx.Logger().Info("Unable to fetch gists by user", "user", member, "error", err)
|
||||
}
|
||||
if err := s.getReposByUser(ctx, member); err != nil {
|
||||
s.log.Info("Unable to fetch repos by user", "user", member, "error", err)
|
||||
ctx.Logger().Info("Unable to fetch repos by user", "user", member, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -740,10 +737,11 @@ func (s *Source) addReposForMembers(ctx context.Context) {
|
|||
// and adds them to the filteredRepoCache.
|
||||
func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
|
||||
gistOpts := &github.GistListOptions{}
|
||||
logger := s.log.WithValues("user", user)
|
||||
logger := ctx.Logger().WithValues("user", user)
|
||||
|
||||
for {
|
||||
gists, res, err := s.connector.APIClient().Gists.List(ctx, user, gistOpts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -788,7 +786,7 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github
|
|||
}
|
||||
|
||||
func (s *Source) addAllVisibleOrgs(ctx context.Context) {
|
||||
s.log.V(2).Info("enumerating all visible organizations on GHE")
|
||||
ctx.Logger().V(2).Info("enumerating all visible organizations on GHE")
|
||||
// Enumeration on this endpoint does not use pages; it uses a since ID.
|
||||
// The endpoint will return organizations with an ID greater than the given since ID.
|
||||
// Empty org response is our cue to break the enumeration loop.
|
||||
|
@ -800,11 +798,11 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
|
|||
}
|
||||
for {
|
||||
orgs, _, err := s.connector.APIClient().Organizations.ListAll(ctx, orgOpts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
s.log.Error(err, "could not list all organizations")
|
||||
ctx.Logger().Error(err, "could not list all organizations")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -813,7 +811,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
|
|||
}
|
||||
|
||||
lastOrgID := *orgs[len(orgs)-1].ID
|
||||
s.log.V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID))
|
||||
ctx.Logger().V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID))
|
||||
orgOpts.Since = lastOrgID
|
||||
|
||||
for _, org := range orgs {
|
||||
|
@ -827,7 +825,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
|
|||
continue
|
||||
}
|
||||
s.orgsCache.Set(name, name)
|
||||
s.log.V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name)
|
||||
ctx.Logger().V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -836,10 +834,10 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
|
|||
orgOpts := &github.ListOptions{
|
||||
PerPage: defaultPagination,
|
||||
}
|
||||
logger := s.log.WithValues("user", user)
|
||||
logger := ctx.Logger().WithValues("user", user)
|
||||
for {
|
||||
orgs, resp, err := s.connector.APIClient().Organizations.List(ctx, "", orgOpts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -869,10 +867,10 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
|
|||
},
|
||||
}
|
||||
|
||||
logger := s.log.WithValues("org", org)
|
||||
logger := ctx.Logger().WithValues("org", org)
|
||||
for {
|
||||
members, res, err := s.connector.APIClient().Organizations.ListMembers(ctx, org, opts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil || len(members) == 0 {
|
||||
|
@ -994,7 +992,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar
|
|||
}
|
||||
for {
|
||||
comments, _, err := s.connector.APIClient().Gists.ListComments(ctx, gistID, options)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -1107,7 +1105,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksCha
|
|||
|
||||
for {
|
||||
issues, _, err := s.connector.APIClient().Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -1179,7 +1177,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, ch
|
|||
|
||||
for {
|
||||
issueComments, _, err := s.connector.APIClient().Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -1244,7 +1242,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan c
|
|||
|
||||
for {
|
||||
prs, _, err := s.connector.APIClient().PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -1276,7 +1274,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunk
|
|||
|
||||
for {
|
||||
prComments, _, err := s.connector.APIClient().PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-github/v63/github"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -369,7 +368,8 @@ func TestNormalizeRepos(t *testing.T) {
|
|||
|
||||
func TestHandleRateLimit(t *testing.T) {
|
||||
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
|
||||
assert.False(t, s.handleRateLimit(nil))
|
||||
ctx := context.Background()
|
||||
assert.False(t, s.handleRateLimit(ctx, nil))
|
||||
|
||||
// Request
|
||||
reqUrl, _ := url.Parse("https://github.com/trufflesecurity/trufflehog")
|
||||
|
@ -400,7 +400,7 @@ func TestHandleRateLimit(t *testing.T) {
|
|||
Message: "Too Many Requests",
|
||||
}
|
||||
|
||||
assert.True(t, s.handleRateLimit(err))
|
||||
assert.True(t, s.handleRateLimit(ctx, err))
|
||||
}
|
||||
|
||||
func TestEnumerateUnauthenticated(t *testing.T) {
|
||||
|
@ -721,7 +721,6 @@ func Test_setProgressCompleteWithRepo_resumeInfo(t *testing.T) {
|
|||
|
||||
s := &Source{
|
||||
repos: []string{},
|
||||
log: logr.Discard(),
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -772,7 +771,6 @@ func Test_setProgressCompleteWithRepo_Progress(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
s := &Source{
|
||||
repos: tt.repos,
|
||||
log: logr.Discard(),
|
||||
}
|
||||
|
||||
s.setProgressCompleteWithRepo(tt.index, tt.offset, "")
|
||||
|
|
|
@ -181,7 +181,7 @@ func isGitHub404Error(err error) bool {
|
|||
}
|
||||
|
||||
func (s *Source) processRepos(ctx context.Context, target string, listRepos repoLister, listOpts repoListOptions) error {
|
||||
logger := s.log.WithValues("target", target)
|
||||
logger := ctx.Logger().WithValues("target", target)
|
||||
opts := listOpts.getListOptions()
|
||||
|
||||
var (
|
||||
|
@ -191,14 +191,14 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo
|
|||
|
||||
for {
|
||||
someRepos, res, err := listRepos(ctx, target, listOpts)
|
||||
if s.handleRateLimit(err) {
|
||||
if s.handleRateLimit(ctx, err) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage)
|
||||
ctx.Logger().V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage)
|
||||
for _, r := range someRepos {
|
||||
if r.GetFork() {
|
||||
if !s.conn.IncludeForks {
|
||||
|
|
Loading…
Reference in a new issue