Instrument GitHub source with a UnitReporter (#3284)

* Fix GitHub integration test

* Instrument GitHub source with a UnitReporter

The reporter is currently unused, but is the first step to support
scanning while enumerating.

* Update GitHub unit tests
This commit is contained in:
Miccah 2024-09-12 10:28:37 -07:00 committed by GitHub
parent 0cb872307c
commit e89190f3ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 93 additions and 47 deletions

View file

@ -75,6 +75,32 @@ type Source struct {
sources.CommonSourceUnitUnmarshaller
}
// --------------------------------------------------------------------------------
// RepoUnit and GistUnit are implementations of SourceUnit used during
// enumeration. The different types aren't strictly necessary, but are a bit
// more explicit and allow type checking/safety.
// Compile-time checks that *RepoUnit and *GistUnit satisfy sources.SourceUnit.
var (
	_ sources.SourceUnit = (*RepoUnit)(nil)
	_ sources.SourceUnit = (*GistUnit)(nil)
)
// RepoUnit is a source unit described by a name and a URL.
type RepoUnit struct {
	name string
	url  string
}

// SourceUnitID returns the unit's url as its identifier, paired with the
// kind "repo".
func (u RepoUnit) SourceUnitID() (string, sources.SourceUnitKind) {
	return u.url, "repo"
}

// Display returns the unit's name.
func (u RepoUnit) Display() string {
	return u.name
}
// GistUnit is a source unit described by a name and a URL.
type GistUnit struct {
	name string
	url  string
}

// SourceUnitID returns the unit's url as its identifier, paired with the
// kind "gist".
func (u GistUnit) SourceUnitID() (string, sources.SourceUnitKind) {
	return u.url, "gist"
}

// Display returns the unit's name.
func (u GistUnit) Display() string {
	return u.name
}
// --------------------------------------------------------------------------------
// WithCustomContentWriter turns on the source's useCustomContentWriter flag.
func (s *Source) WithCustomContentWriter() {
	s.useCustomContentWriter = true
}
@ -313,25 +339,30 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, tar
}
func (s *Source) enumerate(ctx context.Context) error {
// Create a reporter that does nothing for now.
noopReporter := sources.VisitorReporter{
VisitUnit: func(ctx context.Context, su sources.SourceUnit) error {
return nil
},
}
// I'm not wild about switching on the connector type here (as opposed to dispatching to the connector itself) but
// this felt like a compromise that allowed me to isolate connection logic without rewriting the entire source.
switch c := s.connector.(type) {
case *appConnector:
if err := s.enumerateWithApp(ctx, c.InstallationClient()); err != nil {
if err := s.enumerateWithApp(ctx, c.InstallationClient(), noopReporter); err != nil {
return err
}
case *basicAuthConnector:
if err := s.enumerateBasicAuth(ctx); err != nil {
if err := s.enumerateBasicAuth(ctx, noopReporter); err != nil {
return err
}
case *tokenConnector:
if err := s.enumerateWithToken(ctx, c.IsGithubEnterprise()); err != nil {
if err := s.enumerateWithToken(ctx, c.IsGithubEnterprise(), noopReporter); err != nil {
return err
}
case *unauthenticatedConnector:
s.enumerateUnauthenticated(ctx)
s.enumerateUnauthenticated(ctx, noopReporter)
}
s.repos = make([]string, 0, s.filteredRepoCache.Count())
RepoLoop:
@ -393,15 +424,17 @@ RepoLoop:
return nil
}
func (s *Source) enumerateBasicAuth(ctx context.Context) error {
func (s *Source) enumerateBasicAuth(ctx context.Context, reporter sources.UnitReporter) error {
for _, org := range s.orgsCache.Keys() {
orgCtx := context.WithValue(ctx, "account", org)
userType, err := s.getReposByOrgOrUser(ctx, org)
userType, err := s.getReposByOrgOrUser(ctx, org, reporter)
if err != nil {
orgCtx.Logger().Error(err, "error fetching repos for org or user")
continue
}
// TODO: This modifies s.memberCache but it doesn't look like
// we do anything with it.
if userType == organization && s.conn.ScanUsers {
if err := s.addMembersByOrg(ctx, org); err != nil {
orgCtx.Logger().Error(err, "Unable to add members by org")
@ -412,14 +445,14 @@ func (s *Source) enumerateBasicAuth(ctx context.Context) error {
return nil
}
func (s *Source) enumerateUnauthenticated(ctx context.Context) {
func (s *Source) enumerateUnauthenticated(ctx context.Context, reporter sources.UnitReporter) {
if s.orgsCache.Count() > unauthGithubOrgRateLimt {
ctx.Logger().Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
}
for _, org := range s.orgsCache.Keys() {
orgCtx := context.WithValue(ctx, "account", org)
userType, err := s.getReposByOrgOrUser(ctx, org)
userType, err := s.getReposByOrgOrUser(ctx, org, reporter)
if err != nil {
orgCtx.Logger().Error(err, "error fetching repos for org or user")
continue
@ -431,7 +464,7 @@ func (s *Source) enumerateUnauthenticated(ctx context.Context) {
}
}
func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool) error {
func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool, reporter sources.UnitReporter) error {
ctx.Logger().V(1).Info("Enumerating with token")
var ghUser *github.User
@ -450,10 +483,10 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
specificScope := len(s.repos) > 0 || s.orgsCache.Count() > 0
if !specificScope {
// Enumerate the user's orgs and repos if none were specified.
if err := s.getReposByUser(ctx, ghUser.GetLogin()); err != nil {
if err := s.getReposByUser(ctx, ghUser.GetLogin(), reporter); err != nil {
ctx.Logger().Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin())
}
if err := s.addUserGistsToCache(ctx, ghUser.GetLogin()); err != nil {
if err := s.addUserGistsToCache(ctx, ghUser.GetLogin(), reporter); err != nil {
ctx.Logger().Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin())
}
@ -469,7 +502,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
if len(s.orgsCache.Keys()) > 0 {
for _, org := range s.orgsCache.Keys() {
orgCtx := context.WithValue(ctx, "account", org)
userType, err := s.getReposByOrgOrUser(ctx, org)
userType, err := s.getReposByOrgOrUser(ctx, org, reporter)
if err != nil {
orgCtx.Logger().Error(err, "Unable to fetch repos for org or user")
continue
@ -484,17 +517,17 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
if s.conn.ScanUsers && len(s.memberCache) > 0 {
ctx.Logger().Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache))
s.addReposForMembers(ctx)
s.addReposForMembers(ctx, reporter)
}
}
return nil
}
func (s *Source) enumerateWithApp(ctx context.Context, installationClient *github.Client) error {
func (s *Source) enumerateWithApp(ctx context.Context, installationClient *github.Client, reporter sources.UnitReporter) error {
// If no repos were provided, enumerate them.
if len(s.repos) == 0 {
if err := s.getReposByApp(ctx); err != nil {
if err := s.getReposByApp(ctx, reporter); err != nil {
return err
}
@ -505,12 +538,13 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu
return err
}
ctx.Logger().Info("Scanning repos", "org_members", len(s.memberCache))
// TODO: Replace loop below with a call to s.addReposForMembers(ctx, reporter)
for member := range s.memberCache {
logger := ctx.Logger().WithValues("member", member)
if err := s.addUserGistsToCache(ctx, member); err != nil {
if err := s.addUserGistsToCache(ctx, member, reporter); err != nil {
logger.Error(err, "error fetching gists by user")
}
if err := s.getReposByUser(ctx, member); err != nil {
if err := s.getReposByUser(ctx, member, reporter); err != nil {
logger.Error(err, "error fetching repos by user")
}
}
@ -721,13 +755,13 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
return true
}
func (s *Source) addReposForMembers(ctx context.Context) {
func (s *Source) addReposForMembers(ctx context.Context, reporter sources.UnitReporter) {
ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache))
for member := range s.memberCache {
if err := s.addUserGistsToCache(ctx, member); err != nil {
if err := s.addUserGistsToCache(ctx, member, reporter); err != nil {
ctx.Logger().Info("Unable to fetch gists by user", "user", member, "error", err)
}
if err := s.getReposByUser(ctx, member); err != nil {
if err := s.getReposByUser(ctx, member, reporter); err != nil {
ctx.Logger().Info("Unable to fetch repos by user", "user", member, "error", err)
}
}
@ -735,7 +769,7 @@ func (s *Source) addReposForMembers(ctx context.Context) {
// addUserGistsToCache collects all the gist urls for a given user,
// and adds them to the filteredRepoCache.
func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter sources.UnitReporter) error {
gistOpts := &github.GistListOptions{}
logger := ctx.Logger().WithValues("user", user)
@ -751,6 +785,9 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
for _, gist := range gists {
s.filteredRepoCache.Set(gist.GetID(), gist.GetGitPullURL())
s.cacheGistInfo(gist)
if err := reporter.UnitOk(ctx, GistUnit{name: gist.GetID(), url: gist.GetGitPullURL()}); err != nil {
return err
}
}
if res == nil || res.NextPage == 0 {

View file

@ -9,7 +9,6 @@ import (
"testing"
"time"
"github.com/go-logr/logr"
"github.com/kylelemons/godebug/pretty"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
@ -58,12 +57,11 @@ func TestSource_Token(t *testing.T) {
s := Source{
conn: src,
log: logr.Discard(),
memberCache: map[string]struct{}{},
repoInfoCache: newRepoInfoCache(),
}
s.Init(ctx, "github integration test source", 0, 0, false, conn, 1)
s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](), nil, nil)
s.filteredRepoCache = s.newFilteredRepoCache(ctx, memory.New[string](), nil, nil)
err = s.enumerateWithApp(ctx, s.connector.(*appConnector).InstallationClient())
assert.NoError(t, err)

View file

@ -99,7 +99,7 @@ func TestAddReposByOrg(t *testing.T) {
Repositories: nil,
IgnoreRepos: []string{"secret/super-*-repo2"},
})
err := s.getReposByOrg(context.Background(), "super-secret-org")
err := s.getReposByOrg(context.Background(), "super-secret-org", noopReporter())
assert.Nil(t, err)
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-repo")
@ -127,7 +127,7 @@ func TestAddReposByOrg_IncludeRepos(t *testing.T) {
IncludeRepos: []string{"super-secret-org/super*"},
Organizations: []string{"super-secret-org"},
})
err := s.getReposByOrg(context.Background(), "super-secret-org")
err := s.getReposByOrg(context.Background(), "super-secret-org", noopReporter())
assert.Nil(t, err)
assert.Equal(t, 2, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-org/super-secret-repo")
@ -155,7 +155,7 @@ func TestAddReposByUser(t *testing.T) {
},
IgnoreRepos: []string{"super-secret-user/super-secret-repo2"},
})
err := s.getReposByUser(context.Background(), "super-secret-user")
err := s.getReposByUser(context.Background(), "super-secret-user", noopReporter())
assert.Nil(t, err)
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-user/super-secret-repo")
@ -173,7 +173,7 @@ func TestAddGistsByUser(t *testing.T) {
JSON([]map[string]string{{"id": "aa5a315d61ae9438b18d", "git_pull_url": "https://gist.github.com/aa5a315d61ae9438b18d.git"}})
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
err := s.addUserGistsToCache(context.Background(), "super-secret-user")
err := s.addUserGistsToCache(context.Background(), "super-secret-user", noopReporter())
assert.Nil(t, err)
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("aa5a315d61ae9438b18d")
@ -265,7 +265,7 @@ func TestAddReposByApp(t *testing.T) {
})
s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
err := s.getReposByApp(context.Background())
err := s.getReposByApp(context.Background(), noopReporter())
assert.Nil(t, err)
assert.Equal(t, 2, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("ssr1")
@ -419,7 +419,7 @@ func TestEnumerateUnauthenticated(t *testing.T) {
s.orgsCache = memory.New[string]()
s.orgsCache.Set("super-secret-org", "super-secret-org")
//s.enumerateUnauthenticated(context.Background(), apiEndpoint)
s.enumerateUnauthenticated(context.Background())
s.enumerateUnauthenticated(context.Background(), noopReporter())
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-org/super-secret-repo")
assert.True(t, ok)
@ -458,7 +458,7 @@ func TestEnumerateWithToken(t *testing.T) {
Token: "token",
},
})
err := s.enumerateWithToken(context.Background(), false)
err := s.enumerateWithToken(context.Background(), false, noopReporter())
assert.Nil(t, err)
assert.Equal(t, 2, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-user/super-secret-repo")
@ -502,7 +502,7 @@ func BenchmarkEnumerateWithToken(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = s.enumerateWithToken(context.Background(), false)
_ = s.enumerateWithToken(context.Background(), false, noopReporter())
}
}
@ -660,7 +660,7 @@ func TestEnumerateWithToken_IncludeRepos(t *testing.T) {
})
s.repos = []string{"some-special-repo"}
err := s.enumerateWithToken(context.Background(), false)
err := s.enumerateWithToken(context.Background(), false, noopReporter())
assert.Nil(t, err)
assert.Equal(t, 1, len(s.repos))
assert.Equal(t, []string{"some-special-repo"}, s.repos)
@ -693,7 +693,7 @@ func TestEnumerateWithApp(t *testing.T) {
},
},
})
err := s.enumerateWithApp(context.Background(), s.connector.(*appConnector).InstallationClient())
err := s.enumerateWithApp(context.Background(), s.connector.(*appConnector).InstallationClient(), noopReporter())
assert.Nil(t, err)
assert.Equal(t, 0, len(s.repos))
assert.False(t, gock.HasUnmatchedRequest())
@ -908,3 +908,11 @@ func Test_ScanMultipleTargets_MultipleErrors(t *testing.T) {
assert.ElementsMatch(t, got, want)
}
}
// noopReporter builds a sources.VisitorReporter whose VisitUnit callback
// ignores its arguments and always returns nil, for tests that must supply a
// reporter but do not inspect what is reported.
func noopReporter() sources.UnitReporter {
	ignoreUnit := func(context.Context, sources.SourceUnit) error {
		return nil
	}
	return sources.VisitorReporter{VisitUnit: ignoreUnit}
}

View file

@ -14,6 +14,7 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)
type repoInfoCache struct {
@ -76,8 +77,8 @@ func (s *Source) appListReposWrapper(ctx context.Context, _ string, opts repoLis
return nil, res, err
}
func (s *Source) getReposByApp(ctx context.Context) error {
return s.processRepos(ctx, "", s.appListReposWrapper, &appListOptions{
func (s *Source) getReposByApp(ctx context.Context, reporter sources.UnitReporter) error {
return s.processRepos(ctx, "", reporter, s.appListReposWrapper, &appListOptions{
ListOptions: github.ListOptions{
PerPage: defaultPagination,
},
@ -96,8 +97,8 @@ func (s *Source) userListReposWrapper(ctx context.Context, user string, opts rep
return s.connector.APIClient().Repositories.ListByUser(ctx, user, &opts.(*userListOptions).RepositoryListByUserOptions)
}
func (s *Source) getReposByUser(ctx context.Context, user string) error {
return s.processRepos(ctx, user, s.userListReposWrapper, &userListOptions{
func (s *Source) getReposByUser(ctx context.Context, user string, reporter sources.UnitReporter) error {
return s.processRepos(ctx, user, reporter, s.userListReposWrapper, &userListOptions{
RepositoryListByUserOptions: github.RepositoryListByUserOptions{
ListOptions: github.ListOptions{
PerPage: defaultPagination,
@ -119,8 +120,8 @@ func (s *Source) orgListReposWrapper(ctx context.Context, org string, opts repoL
return s.connector.APIClient().Repositories.ListByOrg(ctx, org, &opts.(*orgListOptions).RepositoryListByOrgOptions)
}
func (s *Source) getReposByOrg(ctx context.Context, org string) error {
return s.processRepos(ctx, org, s.orgListReposWrapper, &orgListOptions{
func (s *Source) getReposByOrg(ctx context.Context, org string, reporter sources.UnitReporter) error {
return s.processRepos(ctx, org, reporter, s.orgListReposWrapper, &orgListOptions{
RepositoryListByOrgOptions: github.RepositoryListByOrgOptions{
ListOptions: github.ListOptions{
PerPage: defaultPagination,
@ -145,11 +146,11 @@ const (
organization
)
func (s *Source) getReposByOrgOrUser(ctx context.Context, name string) (userType, error) {
func (s *Source) getReposByOrgOrUser(ctx context.Context, name string, reporter sources.UnitReporter) (userType, error) {
var err error
// List repositories for the organization |name|.
err = s.getReposByOrg(ctx, name)
err = s.getReposByOrg(ctx, name, reporter)
if err == nil {
return organization, nil
} else if !isGitHub404Error(err) {
@ -157,9 +158,9 @@ func (s *Source) getReposByOrgOrUser(ctx context.Context, name string) (userType
}
// List repositories for the user |name|.
err = s.getReposByUser(ctx, name)
err = s.getReposByUser(ctx, name, reporter)
if err == nil {
if err := s.addUserGistsToCache(ctx, name); err != nil {
if err := s.addUserGistsToCache(ctx, name, reporter); err != nil {
ctx.Logger().Error(err, "Unable to add user to cache")
}
return user, nil
@ -180,7 +181,7 @@ func isGitHub404Error(err error) bool {
return ghErr.Response.StatusCode == http.StatusNotFound
}
func (s *Source) processRepos(ctx context.Context, target string, listRepos repoLister, listOpts repoListOptions) error {
func (s *Source) processRepos(ctx context.Context, target string, reporter sources.UnitReporter, listRepos repoLister, listOpts repoListOptions) error {
logger := ctx.Logger().WithValues("target", target)
opts := listOpts.getListOptions()
@ -215,8 +216,10 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo
repoName, repoURL := r.GetFullName(), r.GetCloneURL()
s.totalRepoSize += r.GetSize()
s.filteredRepoCache.Set(repoName, repoURL)
s.cacheRepoInfo(r)
if err := reporter.UnitOk(ctx, RepoUnit{name: repoName, url: repoURL}); err != nil {
return err
}
logger.V(3).Info("repo attributes", "name", repoName, "kb_size", r.GetSize(), "repo_url", repoURL)
}