Fix handling of GitHub ratelimit information (#2041)

This is a follow-up to #1912, which used the headers from the response to determine rate-limiting information, instead of using the values from RateLimitError.Rate. Although that logic seemed solid, I discovered that it did not work in some circumstances. This lead to the "unexpected" path more often than intended, and periodic instances where requests would be made before the ratelimit was refreshed.
This commit is contained in:
Richard Gomez 2024-02-07 09:11:12 -05:00 committed by GitHub
parent 7b492a690a
commit b3ff12d1e9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 97 additions and 98 deletions

3
go.mod
View file

@ -43,7 +43,7 @@ require (
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/go-cmp v0.6.0
github.com/google/go-containerregistry v0.17.0
github.com/google/go-github/v42 v42.0.0
github.com/google/go-github/v57 v57.0.0
github.com/google/uuid v1.5.0
github.com/googleapis/gax-go/v2 v2.12.0
github.com/h2non/filetype v1.1.3
@ -174,7 +174,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/flatbuffers v23.1.21+incompatible // indirect
github.com/google/go-github/v57 v57.0.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
github.com/google/s2a-go v0.1.7 // indirect

2
go.sum
View file

@ -366,8 +366,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-containerregistry v0.17.0 h1:5p+zYs/R4VGHkhyvgWurWrpJ2hW4Vv9fQI+GzdcwXLk=
github.com/google/go-containerregistry v0.17.0/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY1hLbf8eeGapA+vbFDCtQ=
github.com/google/go-github/v42 v42.0.0 h1:YNT0FwjPrEysRkLIiKuEfSvBPCGKphW5aS5PxwaoLec=
github.com/google/go-github/v42 v42.0.0/go.mod h1:jgg/jvyI0YlDOM1/ps6XYh04HNQ3vKf0CVko62/EhRg=
github.com/google/go-github/v57 v57.0.0 h1:L+Y3UPTY8ALM8x+TV0lg+IEBI+upibemtBD8Q9u7zHs=
github.com/google/go-github/v57 v57.0.0/go.mod h1:s0omdnye0hvK/ecLvpsGfJMiRt85PimQh4oygmLIxHw=
github.com/google/go-querystring v0.0.0-20170111101155-53e6ce116135/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=

View file

@ -18,13 +18,14 @@ import (
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/object"
"github.com/google/go-github/v42/github"
diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
"github.com/google/go-github/v57/github"
"golang.org/x/oauth2"
"golang.org/x/sync/semaphore"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
"github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"

View file

@ -14,10 +14,12 @@ import (
"sync/atomic"
"time"
"golang.org/x/exp/rand"
"github.com/bradleyfalzon/ghinstallation/v2"
"github.com/go-logr/logr"
"github.com/gobwas/glob"
"github.com/google/go-github/v42/github"
"github.com/google/go-github/v57/github"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
@ -388,7 +390,6 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
return
}
var resp *github.Response
urlPathParts := strings.Split(u.Path, "/")
switch len(urlPathParts) {
case 2:
@ -397,8 +398,8 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
repoName := urlPathParts[1]
repoName = strings.TrimSuffix(repoName, ".git")
for {
gist, resp, err = s.apiClient.Gists.Get(ctx, repoName)
if !s.handleRateLimit(err, resp) {
gist, _, err = s.apiClient.Gists.Get(ctx, repoName)
if !s.handleRateLimit(err) {
break
}
}
@ -415,8 +416,8 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
repoName := urlPathParts[2]
repoName = strings.TrimSuffix(repoName, ".git")
for {
repo, resp, err = s.apiClient.Repositories.Get(ctx, owner, repoName)
if !s.handleRateLimit(err, resp) {
repo, _, err = s.apiClient.Repositories.Get(ctx, owner, repoName)
if !s.handleRateLimit(err) {
break
}
}
@ -584,13 +585,12 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
var (
ghUser *github.User
resp *github.Response
)
ctx.Logger().V(1).Info("Enumerating with token", "endpoint", apiEndpoint)
for {
ghUser, resp, err = s.apiClient.Users.Get(ctx, "")
if handled := s.handleRateLimit(err, resp); handled {
ghUser, _, err = s.apiClient.Users.Get(ctx, "")
if s.handleRateLimit(err) {
continue
}
if err != nil {
@ -696,7 +696,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
// Does this need to be separate from |s.httpClient|?
instHTTPClient := common.RetryableHttpClientTimeout(60)
instHTTPClient.Transport = appItr
installationClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, instHTTPClient)
installationClient, err = github.NewClient(instHTTPClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
if err != nil {
return nil, err
}
@ -713,7 +713,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
itr.BaseURL = apiEndpoint
s.httpClient.Transport = itr
s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, s.httpClient)
s.apiClient, err = github.NewClient(s.httpClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
if err != nil {
return nil, err
}
@ -753,7 +753,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
return github.NewClient(httpClient), nil
}
return github.NewEnterpriseClient(apiEndpoint, apiEndpoint, httpClient)
return github.NewClient(httpClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
}
func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error {
@ -869,53 +869,70 @@ func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, re
return duration, nil
}
// handleRateLimit returns true if a rate limit was handled
// Unauthenticated access to most github endpoints has a rate limit of 60 requests per hour.
// This will likely only be exhausted if many users/orgs are scanned without auth
func (s *Source) handleRateLimit(errIn error, res *github.Response) bool {
var (
knownWait = true
remaining = 0
retryAfter time.Duration
)
var (
rateLimitMu sync.RWMutex
rateLimitResumeTime time.Time
)
// GitHub has both primary (RateLimit) and secondary (AbuseRateLimit) errors.
var rateLimit *github.RateLimitError
var abuseLimit *github.AbuseRateLimitError
if errors.As(errIn, &rateLimit) {
// Do nothing
} else if errors.As(errIn, &abuseLimit) {
retryAfter = abuseLimit.GetRetryAfter()
} else {
// handleRateLimit returns true if a rate limit was handled
//
// Unauthenticated users have a rate limit of 60 requests per hour.
// Authenticated users have a rate limit of 5,000 requests per hour,
// however, certain actions are subject to a stricter "secondary" limit.
// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api
func (s *Source) handleRateLimit(errIn error) bool {
if errIn == nil {
return false
}
githubNumRateLimitEncountered.WithLabelValues(s.name).Inc()
// Parse retry information from response headers, unless a Retry-After value was already provided.
// https://docs.github.com/en/rest/overview/resources-in-the-rest-api#exceeding-the-rate-limit
if retryAfter <= 0 && res != nil {
var err error
remaining, err = strconv.Atoi(res.Header.Get("x-ratelimit-remaining"))
if err != nil {
knownWait = false
rateLimitMu.RLock()
resumeTime := rateLimitResumeTime
rateLimitMu.RUnlock()
var retryAfter time.Duration
if resumeTime.IsZero() || time.Now().After(resumeTime) {
rateLimitMu.Lock()
var (
now = time.Now()
// GitHub has both primary (RateLimit) and secondary (AbuseRateLimit) errors.
limitType string
rateLimit *github.RateLimitError
abuseLimit *github.AbuseRateLimitError
)
if errors.As(errIn, &rateLimit) {
limitType = "primary"
rate := rateLimit.Rate
if rate.Remaining == 0 { // TODO: Will we ever receive a |RateLimitError| when remaining > 0?
retryAfter = rate.Reset.Sub(now)
}
} else if errors.As(errIn, &abuseLimit) {
limitType = "secondary"
retryAfter = abuseLimit.GetRetryAfter()
} else {
rateLimitMu.Unlock()
return false
}
resetTime, err := strconv.Atoi(res.Header.Get("x-ratelimit-reset"))
if err != nil || resetTime == 0 {
knownWait = false
} else if resetTime > 0 {
retryAfter = time.Duration(int64(resetTime)-time.Now().Unix()) * time.Second
jitter := time.Duration(rand.Intn(10)+1) * time.Second
if retryAfter > 0 {
retryAfter = retryAfter + jitter
rateLimitResumeTime = now.Add(retryAfter)
s.log.V(0).Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
} else {
retryAfter = (5 * time.Minute) + jitter
rateLimitResumeTime = now.Add(retryAfter)
// TODO: Use exponential backoff instead of static retry time.
s.log.V(0).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
}
}
resumeTime := time.Now().Add(retryAfter).String()
if knownWait && remaining == 0 && retryAfter > 0 {
s.log.V(2).Info("rate limited", "retry_after", retryAfter.String(), "resume_time", resumeTime)
rateLimitMu.Unlock()
} else {
// TODO: Use exponential backoff instead of static retry time.
retryAfter = time.Minute * 5
s.log.V(2).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", resumeTime)
retryAfter = time.Until(resumeTime)
}
githubNumRateLimitEncountered.WithLabelValues(s.name).Inc()
time.Sleep(retryAfter)
githubSecondsSpentRateLimited.WithLabelValues(s.name).Add(retryAfter.Seconds())
return true
@ -940,10 +957,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
logger := s.log.WithValues("user", user)
for {
gists, res, err := s.apiClient.Gists.List(ctx, user, gistOpts)
if err == nil {
res.Body.Close()
}
if handled := s.handleRateLimit(err, res); handled {
if s.handleRateLimit(err) {
continue
}
if err != nil {
@ -996,11 +1010,8 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
},
}
for {
orgs, resp, err := s.apiClient.Organizations.ListAll(ctx, orgOpts)
if err == nil {
resp.Body.Close()
}
if handled := s.handleRateLimit(err, resp); handled {
orgs, _, err := s.apiClient.Organizations.ListAll(ctx, orgOpts)
if s.handleRateLimit(err) {
continue
}
if err != nil {
@ -1037,10 +1048,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
logger := s.log.WithValues("user", user)
for {
orgs, resp, err := s.apiClient.Organizations.List(ctx, "", orgOpts)
if err == nil {
resp.Body.Close()
}
if handled := s.handleRateLimit(err, resp); handled {
if handled := s.handleRateLimit(err); handled {
continue
}
if err != nil {
@ -1075,10 +1083,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
logger := s.log.WithValues("org", org)
for {
members, res, err := s.apiClient.Organizations.ListMembers(ctx, org, opts)
if err == nil {
defer res.Body.Close()
}
if handled := s.handleRateLimit(err, res); handled {
if s.handleRateLimit(err) {
continue
}
if err != nil || len(members) == 0 {
@ -1150,8 +1155,8 @@ func (s *Source) processGistComments(ctx context.Context, repoPath string, trimm
Page: initialPage,
}
for {
comments, resp, err := s.apiClient.Gists.ListComments(ctx, gistID, options)
if s.handleRateLimit(err, resp) {
comments, _, err := s.apiClient.Gists.ListComments(ctx, gistID, options)
if s.handleRateLimit(err) {
break
}
if err != nil {
@ -1254,8 +1259,8 @@ func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan ch
}
for {
issues, resp, err := s.apiClient.Issues.ListByRepo(ctx, info.owner, info.repo, bodyTextsOpts)
if s.handleRateLimit(err, resp) {
issues, _, err := s.apiClient.Issues.ListByRepo(ctx, info.owner, info.repo, bodyTextsOpts)
if s.handleRateLimit(err) {
break
}
@ -1287,8 +1292,8 @@ func (s *Source) processIssueComments(ctx context.Context, info repoInfo, chunks
}
for {
issueComments, resp, err := s.apiClient.Issues.ListComments(ctx, info.owner, info.repo, allComments, issueOpts)
if s.handleRateLimit(err, resp) {
issueComments, _, err := s.apiClient.Issues.ListComments(ctx, info.owner, info.repo, allComments, issueOpts)
if s.handleRateLimit(err) {
break
}
@ -1321,8 +1326,8 @@ func (s *Source) processPRs(ctx context.Context, info repoInfo, chunksChan chan
}
for {
prs, resp, err := s.apiClient.PullRequests.List(ctx, info.owner, info.repo, prOpts)
if s.handleRateLimit(err, resp) {
prs, _, err := s.apiClient.PullRequests.List(ctx, info.owner, info.repo, prOpts)
if s.handleRateLimit(err) {
break
}
@ -1354,8 +1359,8 @@ func (s *Source) processPRComments(ctx context.Context, info repoInfo, chunksCha
}
for {
prComments, resp, err := s.apiClient.PullRequests.ListComments(ctx, info.owner, info.repo, allComments, prOpts)
if s.handleRateLimit(err, resp) {
prComments, _, err := s.apiClient.PullRequests.ListComments(ctx, info.owner, info.repo, allComments, prOpts)
if s.handleRateLimit(err) {
break
}

View file

@ -17,7 +17,7 @@ import (
"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"
"github.com/google/go-github/v42/github"
"github.com/google/go-github/v57/github"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/anypb"
@ -330,13 +330,13 @@ func TestNormalizeRepos(t *testing.T) {
func TestHandleRateLimit(t *testing.T) {
s := initTestSource(nil)
assert.False(t, s.handleRateLimit(nil, nil))
assert.False(t, s.handleRateLimit(nil))
err := &github.RateLimitError{}
res := &github.Response{Response: &http.Response{Header: make(http.Header)}}
res.Header.Set("x-ratelimit-remaining", "0")
res.Header.Set("x-ratelimit-reset", strconv.FormatInt(time.Now().Unix()+1, 10))
assert.True(t, s.handleRateLimit(err, res))
assert.True(t, s.handleRateLimit(err))
}
func TestEnumerateUnauthenticated(t *testing.T) {

View file

@ -7,7 +7,7 @@ import (
"strings"
gogit "github.com/go-git/go-git/v5"
"github.com/google/go-github/v42/github"
"github.com/google/go-github/v57/github"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
@ -98,12 +98,11 @@ func (s *Source) userAndToken(ctx context.Context, installationClient *github.Cl
case *sourcespb.GitHub_Token:
var (
ghUser *github.User
resp *github.Response
err error
)
for {
ghUser, resp, err = s.apiClient.Users.Get(ctx, "")
if handled := s.handleRateLimit(err, resp); handled {
ghUser, _, err = s.apiClient.Users.Get(ctx, "")
if s.handleRateLimit(err) {
continue
}
if err != nil {
@ -150,7 +149,7 @@ func (s *Source) appListReposWrapper(ctx context.Context, _ string, opts repoLis
}
type userListOptions struct {
github.RepositoryListOptions
github.RepositoryListByUserOptions
}
func (u *userListOptions) getListOptions() *github.ListOptions {
@ -159,7 +158,7 @@ func (u *userListOptions) getListOptions() *github.ListOptions {
func (s *Source) getReposByUser(ctx context.Context, user string) error {
return s.processRepos(ctx, user, s.userListReposWrapper, &userListOptions{
RepositoryListOptions: github.RepositoryListOptions{
RepositoryListByUserOptions: github.RepositoryListByUserOptions{
ListOptions: github.ListOptions{
PerPage: defaultPagination,
},
@ -168,7 +167,7 @@ func (s *Source) getReposByUser(ctx context.Context, user string) error {
}
func (s *Source) userListReposWrapper(ctx context.Context, user string, opts repoListOptions) ([]*github.Repository, *github.Response, error) {
return s.apiClient.Repositories.List(ctx, user, &opts.(*userListOptions).RepositoryListOptions)
return s.apiClient.Repositories.ListByUser(ctx, user, &opts.(*userListOptions).RepositoryListByUserOptions)
}
type orgListOptions struct {
@ -204,10 +203,7 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo
for {
someRepos, res, err := listRepos(ctx, target, listOpts)
if err == nil {
res.Body.Close()
}
if handled := s.handleRateLimit(err, res); handled {
if s.handleRateLimit(err) {
continue
}
if err != nil {
@ -287,8 +283,8 @@ type commitQuery struct {
// getDiffForFileInCommit retrieves the diff for a specified file in a commit.
// If the file or its diff is not found, it returns an error.
func (s *Source) getDiffForFileInCommit(ctx context.Context, query commitQuery) (string, error) {
commit, resp, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil)
if handled := s.handleRateLimit(err, resp); handled {
commit, _, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil)
if s.handleRateLimit(err) {
return "", fmt.Errorf("error fetching commit %s due to rate limit: %w", query.sha, err)
}
if err != nil {