Allow github to resume from encoded resume info (#601)

This commit is contained in:
trufflesteeeve 2022-06-06 12:08:57 -04:00 committed by GitHub
parent 59fc54b94a
commit fd79a367f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 316 additions and 21 deletions

View file

@ -3,11 +3,11 @@ package github
import (
"context"
"fmt"
"math/rand"
"net/http"
"os"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
@ -36,21 +36,23 @@ import (
)
type Source struct {
name string
sourceID int64
jobID int64
verify bool
repos []string
orgs []string
members []string
git *git.Git
httpClient *http.Client
aCtx context.Context
name string
sourceID int64
jobID int64
verify bool
repos []string
orgs []string
members []string
git *git.Git
httpClient *http.Client
aCtx context.Context
log *log.Entry
token string
conn *sourcespb.GitHub
jobSem *semaphore.Weighted
resumeInfoSlice []string
resumeInfoMutex *sync.Mutex
sources.Progress
log *log.Entry
token string
conn *sourcespb.GitHub
jobSem *semaphore.Weighted
}
// Ensure the Source satisfies the interface at compile time
@ -323,11 +325,8 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
s.normalizeRepos(ctx, apiClient)
if _, ok := os.LookupEnv("DO_NOT_RANDOMIZE"); !ok {
// Randomize channel scan order on each scan
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(s.repos), func(i, j int) { s.repos[i], s.repos[j] = s.repos[j], s.repos[i] })
}
// We must sort the repos so we can resume later if necessary.
sort.Strings(s.repos)
return s.scan(ctx, installationClient, chunksChan)
}
@ -347,6 +346,9 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
}
}
// If there is resume information available, limit this scan to only the repos that still need scanning.
progressIndexOffset := s.filterReposToResume(s.GetProgress().EncodedResumeInfo)
for i, repoURL := range s.repos {
if err := s.jobSem.Acquire(ctx, 1); err != nil {
// Acquire blocks until it can acquire the semaphore or returns an
@ -360,7 +362,9 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
defer s.jobSem.Release(1)
defer wg.Done()
s.SetProgressComplete(i, len(s.repos), fmt.Sprintf("Repo: %s", repoURL), "")
s.setProgressCompleteWithRepo(i+progressIndexOffset, repoURL)
// Ensure the repo is removed from the resume info after being scanned.
defer s.removeRepoFromResumeInfo(repoURL)
if !strings.HasSuffix(repoURL, ".git") {
return
@ -728,3 +732,102 @@ func (s *Source) normalizeRepos(ctx context.Context, apiClient *github.Client) {
s.repos = append(s.repos, key)
}
}
// setProgressCompleteWithRepo calls the s.SetProgressComplete after safely setting up the encoded resume info string.
func (s *Source) setProgressCompleteWithRepo(index int, repoURL string) {
s.resumeInfoMutex.Lock()
defer s.resumeInfoMutex.Unlock()
// Add the repoURL to the resume info slice.
s.resumeInfoSlice = append(s.resumeInfoSlice, repoURL)
sort.Strings(s.resumeInfoSlice)
// Make the resume info string from the slice.
encodedResumeInfo := s.encodeResumeInfo()
s.SetProgressComplete(index, len(s.repos), fmt.Sprintf("Repo: %s", repoURL), encodedResumeInfo)
}
// removeRepoFromResumeInfo removes the repoURL from the resume info.
func (s *Source) removeRepoFromResumeInfo(repoURL string) {
s.resumeInfoMutex.Lock()
defer s.resumeInfoMutex.Unlock()
index := -1
for i, repo := range s.resumeInfoSlice {
if repoURL == repo {
index = i
}
}
if index == -1 {
// We should never be able to be here. But if we are, it means the resume info never had the repo added.
// So log the error and do nothing.
s.log.Errorf("repoURL (%q) not found in list of encode resume info: %q", repoURL, s.EncodedResumeInfo)
return
}
// This removes the element at the given index.
s.resumeInfoSlice = append(s.resumeInfoSlice[:index], s.resumeInfoSlice[index+1:]...)
}
func (s *Source) encodeResumeInfo() string {
return strings.Join(s.resumeInfoSlice, "\t")
}
func (s *Source) decodeResumeInfo(resumeInfo string) {
// strings.Split will, for an empty string, return []string{""},
// which is an element, where as when there is no resume info we want an empty slice.
if resumeInfo == "" {
return
}
s.resumeInfoSlice = strings.Split(resumeInfo, "\t")
}
// filterReposToResume filters the existing repos down to those that are included in the encoded resume info.
// It also returns the difference between the original length of the repos and the new length to use for progress reporting.
// It is required that both the resumeInfo repos and the existing repos in s.repos are sorted.
func (s *Source) filterReposToResume(resumeInfo string) int {
if resumeInfo == "" {
return 0
}
s.resumeInfoMutex.Lock()
defer s.resumeInfoMutex.Unlock()
s.decodeResumeInfo(resumeInfo)
// Because this scanner is multithreaded, it is possible that we have scanned a range of repositories
// with some gaps of unlisted but completed repositories in between the ones in resumeInfo.
// So we know repositories that have not finished scanning are the ones included in the resumeInfo,
// and those that come after the last repository in the resumeInfo.
// However, it is possible that a resumed scan does not include all or even any of the repos within the resumeInfo.
// In this case, we must ensure we still scan all repos that come after the last found repo in the list.
reposToScan := []string{}
lastFoundRepoIndex := -1
resumeRepoIndex := 0
for i, repoURL := range s.repos {
// If the repoURL is bigger than what we're looking for, move to the next one.
if repoURL > s.resumeInfoSlice[resumeRepoIndex] {
resumeRepoIndex++
}
// If we've found all of our repositories end the filter.
if resumeRepoIndex == len(s.resumeInfoSlice) {
break
}
// If the repoURL is the one we're looking for, add it and update the lastFoundRepoIndex.
if repoURL == s.resumeInfoSlice[resumeRepoIndex] {
lastFoundRepoIndex = i
reposToScan = append(reposToScan, repoURL)
}
}
// Append all repos after the last one we've found.
reposToScan = append(reposToScan, s.repos[lastFoundRepoIndex+1:]...)
progressOffsetCount := len(s.repos) - len(reposToScan)
s.repos = reposToScan
return progressOffsetCount
}

View file

@ -8,13 +8,17 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"net/http"
"reflect"
"sort"
"strconv"
"sync"
"testing"
"time"
"github.com/google/go-github/v42/github"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
@ -336,3 +340,191 @@ func TestEnumerateWithApp(t *testing.T) {
assert.True(t, gock.IsDone())
}
// This only tests the resume info slice portion of setProgressCompleteWithRepo.
func Test_setProgressCompleteWithRepo(t *testing.T) {
tests := []struct {
startingResumeInfoSlice []string
repoURL string
wantResumeInfoSlice []string
}{
{
startingResumeInfoSlice: []string{},
repoURL: "a",
wantResumeInfoSlice: []string{"a"},
},
{
startingResumeInfoSlice: []string{"b"},
repoURL: "a",
wantResumeInfoSlice: []string{"a", "b"},
},
}
logger := logrus.New()
logger.Out = io.Discard
s := &Source{
repos: []string{},
log: logger.WithField("no", "output"),
resumeInfoMutex: &sync.Mutex{},
}
for _, tt := range tests {
s.resumeInfoSlice = tt.startingResumeInfoSlice
s.setProgressCompleteWithRepo(0, tt.repoURL)
if !reflect.DeepEqual(s.resumeInfoSlice, tt.wantResumeInfoSlice) {
t.Errorf("s.setProgressCompleteWithRepo() got: %v, want: %v", s.resumeInfoSlice, tt.wantResumeInfoSlice)
}
}
}
func Test_removeRepoFromResumeInfo(t *testing.T) {
tests := []struct {
startingResumeInfoSlice []string
repoURL string
wantResumeInfoSlice []string
}{
{
startingResumeInfoSlice: []string{"a", "b", "c"},
repoURL: "a",
wantResumeInfoSlice: []string{"b", "c"},
},
{
startingResumeInfoSlice: []string{"a", "b", "c"},
repoURL: "b",
wantResumeInfoSlice: []string{"a", "c"},
},
{ // This is the probably can't happen case of a repo not in the list.
startingResumeInfoSlice: []string{"a", "b", "c"},
repoURL: "not in the list",
wantResumeInfoSlice: []string{"a", "b", "c"},
},
}
logger := logrus.New()
logger.Out = io.Discard
s := &Source{
repos: []string{},
log: logger.WithField("no", "output"),
resumeInfoMutex: &sync.Mutex{},
}
for _, tt := range tests {
s.resumeInfoSlice = tt.startingResumeInfoSlice
s.removeRepoFromResumeInfo(tt.repoURL)
if !reflect.DeepEqual(s.resumeInfoSlice, tt.wantResumeInfoSlice) {
t.Errorf("s.removeRepoFromResumeInfo() got: %v, want: %v", s.resumeInfoSlice, tt.wantResumeInfoSlice)
}
}
}
func Test_encodeResumeInfo(t *testing.T) {
tests := []struct {
startingResumeInfoSlice []string
wantEncodedResumeInfo string
}{
{
startingResumeInfoSlice: []string{"a", "b", "c"},
wantEncodedResumeInfo: "a\tb\tc",
},
{
startingResumeInfoSlice: []string{},
wantEncodedResumeInfo: "",
},
}
logger := logrus.New()
logger.Out = io.Discard
s := &Source{
repos: []string{},
log: logger.WithField("no", "output"),
resumeInfoMutex: &sync.Mutex{},
}
for _, tt := range tests {
s.resumeInfoSlice = tt.startingResumeInfoSlice
gotEncodedResumeInfo := s.encodeResumeInfo()
if gotEncodedResumeInfo != tt.wantEncodedResumeInfo {
t.Errorf("s.encodeResumeInfo() got: %q, want: %q", gotEncodedResumeInfo, tt.wantEncodedResumeInfo)
}
}
}
func Test_decodeResumeInfo(t *testing.T) {
tests := []struct {
resumeInfo string
wantResumeInfoSlice []string
}{
{
resumeInfo: "a\tb\tc",
wantResumeInfoSlice: []string{"a", "b", "c"},
},
{
resumeInfo: "",
wantResumeInfoSlice: nil,
},
}
for _, tt := range tests {
s := &Source{}
s.decodeResumeInfo(tt.resumeInfo)
if !reflect.DeepEqual(s.resumeInfoSlice, tt.wantResumeInfoSlice) {
t.Errorf("s.decodeResumeInfo() got: %v, want: %v", s.resumeInfoSlice, tt.wantResumeInfoSlice)
}
}
}
func Test_filterReposToResume(t *testing.T) {
startingRepos := []string{"a", "b", "c", "d", "e", "f", "g"}
tests := map[string]struct {
resumeInfo string
wantProgressOffsetCount int
wantReposToScan []string
}{
"blank resume info": {
resumeInfo: "",
wantProgressOffsetCount: 0,
wantReposToScan: startingRepos,
},
"starting repos": {
resumeInfo: "a\tb",
wantProgressOffsetCount: 0,
wantReposToScan: startingRepos,
},
"early contiguous repos": {
resumeInfo: "b\tc",
wantProgressOffsetCount: 1,
wantReposToScan: []string{"b", "c", "d", "e", "f", "g"},
},
"non-contiguous repos": {
resumeInfo: "b\te",
wantProgressOffsetCount: 3,
wantReposToScan: []string{"b", "e", "f", "g"},
},
"no repos found in the repo list": {
resumeInfo: "not\tthere",
wantProgressOffsetCount: 0,
wantReposToScan: startingRepos,
},
"only some repos in the list": {
resumeInfo: "c\tnot\tthere",
wantProgressOffsetCount: 2,
wantReposToScan: []string{"c", "d", "e", "f", "g"},
},
}
for name, tt := range tests {
s := &Source{
repos: startingRepos,
resumeInfoMutex: &sync.Mutex{},
}
gotProgressOffsetCount := s.filterReposToResume(tt.resumeInfo)
if gotProgressOffsetCount != tt.wantProgressOffsetCount {
t.Errorf("s.filterReposToResume() name: %q got: %d, want: %d", name, gotProgressOffsetCount, tt.wantProgressOffsetCount)
}
if !reflect.DeepEqual(s.repos, tt.wantReposToScan) {
t.Errorf("s.filterReposToResume() name: %q got: %v, want: %v", name, s.repos, tt.wantReposToScan)
}
}
}