Refactor GitHub source (#2379)

* refactor(github): cleanup logic

* fix(github): lookup wikis per-repo

* refactor(github): change scanErrs.String output

---------

Co-authored-by: Bill Rich <bill.rich@gmail.com>
Richard Gomez 2024-03-21 17:07:39 -04:00 committed by GitHub
parent eb0d4ae5d2
commit 80e8a67c2d
5 changed files with 395 additions and 427 deletions


@ -2,7 +2,7 @@ package sources
import (
"errors"
"fmt"
"strings"
"sync"
)
@ -35,7 +35,17 @@ func (s *ScanErrors) Count() uint64 {
func (s *ScanErrors) String() string {
s.mu.RLock()
defer s.mu.RUnlock()
return fmt.Sprintf("%v", s.errors)
var sb strings.Builder
sb.WriteString("[")
for i, err := range s.errors {
sb.WriteString(`"` + err.Error() + `"`)
if i < len(s.errors)-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
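For reference, the new String method renders the accumulated errors as a bracketed, comma-separated list of quoted messages. A minimal standalone sketch of the same formatting (hypothetical formatErrors helper and error values, not part of the diff):

// Hypothetical demonstration of the new output format; the helper mirrors ScanErrors.String.
package main

import (
	"errors"
	"fmt"
	"strings"
)

func formatErrors(errs []error) string {
	var sb strings.Builder
	sb.WriteString("[")
	for i, err := range errs {
		sb.WriteString(`"` + err.Error() + `"`)
		if i < len(errs)-1 {
			sb.WriteString(", ")
		}
	}
	sb.WriteString("]")
	return sb.String()
}

func main() {
	errs := []error{errors.New("clone failed"), errors.New("rate limited")}
	fmt.Println(formatErrors(errs)) // prints: ["clone failed", "rate limited"]
}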
func (s *ScanErrors) Errors() error {


@ -48,6 +48,7 @@ const (
type Source struct {
name string
// Protects the user and token.
userMu sync.Mutex
githubUser string
@ -56,14 +57,12 @@ type Source struct {
sourceID sources.SourceID
jobID sources.JobID
verify bool
repos []string
orgsCache cache.Cache
memberCache map[string]struct{}
repos []string
filteredRepoCache *filteredRepoCache
// repos that _probably_ have wikis (see the comment on hasWiki).
reposWithWikis map[string]struct{}
memberCache map[string]struct{}
repoSizes repoSize
totalRepoSize int // total size of all repos in kb
repoInfoCache repoInfoCache
totalRepoSize int // total size of all repos in kb
useCustomContentWriter bool
git *git.Git
@ -79,12 +78,10 @@ type Source struct {
resumeInfoSlice []string
apiClient *github.Client
mu sync.Mutex // protects the visibility maps
publicMap map[string]source_metadatapb.Visibility
includePRComments bool
includeIssueComments bool
includeGistComments bool
sources.Progress
sources.CommonSourceUnitUnmarshaller
}
@ -123,27 +120,6 @@ func (s *Source) JobID() sources.JobID {
return s.jobID
}
type repoSize struct {
mu sync.RWMutex
repoSizes map[string]int // size in kb of each repo
}
func (r *repoSize) addRepo(repo string, size int) {
r.mu.Lock()
defer r.mu.Unlock()
r.repoSizes[repo] = size
}
func (r *repoSize) getRepo(repo string) int {
r.mu.RLock()
defer r.mu.RUnlock()
return r.repoSizes[repo]
}
func newRepoSize() repoSize {
return repoSize{repoSizes: make(map[string]int)}
}
// filteredRepoCache is a wrapper around cache.Cache that filters out repos
// based on include and exclude globs.
type filteredRepoCache struct {
@ -209,6 +185,11 @@ func (c *filteredRepoCache) includeRepo(s string) bool {
// Init returns an initialized GitHub source.
func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, sourceID sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
err := git.CmdCheck()
if err != nil {
return err
}
s.log = aCtx.Logger()
s.name = name
@ -222,20 +203,22 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
s.apiClient = github.NewClient(s.httpClient)
var conn sourcespb.GitHub
err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
err = anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
if err != nil {
return fmt.Errorf("error unmarshalling connection: %w", err)
}
s.conn = &conn
s.orgsCache = memory.New()
for _, org := range s.conn.Organizations {
s.orgsCache.Set(org, org)
}
s.memberCache = make(map[string]struct{})
s.filteredRepoCache = s.newFilteredRepoCache(memory.New(),
append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...),
s.conn.GetIgnoreRepos(),
)
s.reposWithWikis = make(map[string]struct{})
s.memberCache = make(map[string]struct{})
s.repoSizes = newRepoSize()
s.repos = s.conn.Repositories
for _, repo := range s.repos {
r, err := s.normalizeRepo(repo)
@ -245,28 +228,17 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
}
s.filteredRepoCache.Set(repo, r)
}
s.repoInfoCache = newRepoInfoCache()
s.includeIssueComments = s.conn.IncludeIssueComments
s.includePRComments = s.conn.IncludePullRequestComments
s.includeGistComments = s.conn.IncludeGistComments
s.orgsCache = memory.New()
for _, org := range s.conn.Organizations {
s.orgsCache.Set(org, org)
}
// Head or base should only be used with incoming webhooks
if (len(s.conn.Head) > 0 || len(s.conn.Base) > 0) && len(s.repos) != 1 {
return fmt.Errorf("cannot specify head or base with multiple repositories")
}
err = git.CmdCheck()
if err != nil {
return err
}
s.publicMap = map[string]source_metadatapb.Visibility{}
cfg := &git.Config{
SourceName: s.name,
JobID: s.jobID,
@ -358,82 +330,22 @@ func checkGitHubConnection(ctx context.Context, client *github.Client) error {
return err
}
func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility source_metadatapb.Visibility) {
func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metadatapb.Visibility {
// It isn't possible to get the visibility of a wiki.
// We must use the visibility of the corresponding repository.
if strings.HasSuffix(repoURL, ".wiki.git") {
repoURL = strings.TrimSuffix(repoURL, ".wiki.git") + ".git"
}
s.mu.Lock()
visibility, ok := s.publicMap[repoURL]
s.mu.Unlock()
if ok {
return visibility
repoInfo, ok := s.repoInfoCache.get(repoURL)
if !ok {
// This should never happen.
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
ctx.Logger().Error(err, "failed to get repository visibility")
return source_metadatapb.Visibility_unknown
}
visibility = source_metadatapb.Visibility_public
defer func() {
s.mu.Lock()
s.publicMap[repoURL] = visibility
s.mu.Unlock()
}()
logger := s.log.WithValues("repo", repoURL)
if _, unauthenticated := s.conn.GetCredential().(*sourcespb.GitHub_Unauthenticated); unauthenticated {
logger.V(3).Info("assuming unauthenticated scan has public visibility")
return source_metadatapb.Visibility_public
}
logger.V(2).Info("Checking public status")
u, err := url.Parse(repoURL)
if err != nil {
logger.Error(err, "Could not parse repository URL.")
return
}
urlPathParts := strings.Split(u.Path, "/")
switch len(urlPathParts) {
case 2:
// Check if repoURL is a gist.
var gist *github.Gist
repoName := urlPathParts[1]
repoName = strings.TrimSuffix(repoName, ".git")
for {
gist, _, err = s.apiClient.Gists.Get(ctx, repoName)
if !s.handleRateLimit(err) {
break
}
}
if err != nil || gist == nil {
logger.Error(err, "Could not get Github repository")
return
}
if !(*gist.Public) {
visibility = source_metadatapb.Visibility_private
}
case 3:
var repo *github.Repository
owner := urlPathParts[1]
repoName := urlPathParts[2]
repoName = strings.TrimSuffix(repoName, ".git")
for {
repo, _, err = s.apiClient.Repositories.Get(ctx, owner, repoName)
if !s.handleRateLimit(err) {
break
}
}
if err != nil || repo == nil {
logger.Error(err, "Could not get Github repository")
return
}
if *repo.Private {
visibility = source_metadatapb.Visibility_private
}
default:
logger.Error(fmt.Errorf("unexpected number of parts"), "RepoURL should split into 2 or 3 parts",
"got", len(urlPathParts),
)
}
return
return repoInfo.visibility
}
const cloudEndpoint = "https://api.github.com"
@ -498,6 +410,19 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli
ctx.Logger().Error(fmt.Errorf("type assertion failed"), "unexpected value in cache", "repo", repo)
continue
}
_, urlParts, err := getRepoURLParts(r)
if err != nil {
ctx.Logger().Error(err, "failed to parse repository URL")
continue
}
ghRepo, _, err := s.apiClient.Repositories.Get(ctx, urlParts[1], urlParts[2])
if err != nil {
ctx.Logger().Error(err, "failed to fetch repository")
continue
}
s.cacheRepoInfo(ghRepo)
s.repos = append(s.repos, r)
}
githubReposEnumerated.WithLabelValues(s.name).Set(float64(len(s.repos)))
@ -660,7 +585,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
}
if s.conn.ScanUsers {
s.log.Info("Adding repos", "members", len(s.memberCache), "orgs", s.orgsCache.Count())
s.log.Info("Adding repos", "orgs", s.orgsCache.Count(), "members", len(s.memberCache))
s.addReposForMembers(ctx)
return nil
}
@ -757,7 +682,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
}
func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error {
var scannedCount uint64
var scannedCount uint64 = 1
s.log.V(2).Info("Found repos to scan", "count", len(s.repos))
@ -793,36 +718,41 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
}
// Scan the repository
repoInfo, ok := s.repoInfoCache.get(repoURL)
if !ok {
// This should never happen.
err := fmt.Errorf("no repoInfo for URL: %s", repoURL)
s.log.Error(err, "failed to scan repository")
return nil
}
repoCtx := context.WithValues(ctx, "repo", repoURL)
duration, err := s.cloneAndScanRepo(repoCtx, installationClient, repoURL, chunksChan)
duration, err := s.cloneAndScanRepo(repoCtx, installationClient, repoURL, repoInfo, chunksChan)
if err != nil {
scanErrs.Add(err)
return nil
}
// Scan the wiki, if enabled, and the repo has one.
if s.conn.IncludeWikis {
if _, ok := s.reposWithWikis[repoURL]; ok {
wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git"
wikiCtx := context.WithValue(ctx, "repo", wikiURL)
if s.conn.IncludeWikis && repoInfo.hasWiki && s.wikiIsReachable(ctx, repoURL) {
wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git"
wikiCtx := context.WithValue(ctx, "repo", wikiURL)
_, err := s.cloneAndScanRepo(wikiCtx, installationClient, wikiURL, chunksChan)
if err != nil {
scanErrs.Add(err)
// Don't return, it still might be possible to scan comments.
}
_, err := s.cloneAndScanRepo(wikiCtx, installationClient, wikiURL, repoInfo, chunksChan)
if err != nil {
scanErrs.Add(fmt.Errorf("error scanning wiki: %s", wikiURL))
// Don't return, it still might be possible to scan comments.
}
}
// Scan comments, if enabled.
if s.includeGistComments || s.includeIssueComments || s.includePRComments {
if err = s.scanComments(ctx, repoURL, chunksChan); err != nil {
if err = s.scanComments(repoCtx, repoURL, repoInfo, chunksChan); err != nil {
scanErrs.Add(fmt.Errorf("error scanning comments in repo %s: %w", repoURL, err))
return nil
}
}
ctx.Logger().V(2).Info(fmt.Sprintf("scanned %d/%d repos", scannedCount, len(s.repos)), "duration_seconds", duration)
repoCtx.Logger().V(2).Info(fmt.Sprintf("scanned %d/%d repos", scannedCount, len(s.repos)), "duration_seconds", duration)
githubReposScanned.WithLabelValues(s.name).Inc()
atomic.AddUint64(&scannedCount, 1)
return nil
@ -831,14 +761,14 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
_ = s.jobPool.Wait()
if scanErrs.Count() > 0 {
s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs)
s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs.String())
}
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed Github scan", "")
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")
return nil
}
func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, repoURL string, chunksChan chan *sources.Chunk) (time.Duration, error) {
func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, repoURL string, repoInfo repoInfo, chunksChan chan *sources.Chunk) (time.Duration, error) {
var duration time.Duration
ctx.Logger().V(2).Info("attempting to clone repo")
@ -853,9 +783,8 @@ func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, re
// Repo size is not collected for wikis.
var logger logr.Logger
if !strings.HasSuffix(repoURL, ".wiki.git") {
repoSize := s.repoSizes.getRepo(repoURL)
logger = ctx.Logger().WithValues("repo_size_kb", repoSize)
if !strings.HasSuffix(repoURL, ".wiki.git") && repoInfo.size > 0 {
logger = ctx.Logger().WithValues("repo_size_kb", repoInfo.size)
} else {
logger = ctx.Logger()
}
@ -963,9 +892,21 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
if err != nil {
return fmt.Errorf("could not list gists for user %s: %w", user, err)
}
for _, gist := range gists {
s.filteredRepoCache.Set(gist.GetID(), gist.GetGitPullURL())
info := repoInfo{
owner: gist.GetOwner().GetLogin(),
}
if gist.GetPublic() {
info.visibility = source_metadatapb.Visibility_public
} else {
info.visibility = source_metadatapb.Visibility_private
}
s.repoInfoCache.put(gist.GetGitPullURL(), info)
}
if res == nil || res.NextPage == 0 {
break
}
@ -1018,9 +959,11 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
s.log.Error(err, "could not list all organizations")
return
}
if len(orgs) == 0 {
break
}
lastOrgID := *orgs[len(orgs)-1].ID
s.log.V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID))
orgOpts.Since = lastOrgID
@ -1055,9 +998,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
logger.Error(err, "Could not list organizations")
return
}
if resp == nil {
break
}
logger.V(2).Info("Listed orgs", "page", orgOpts.Page, "last_page", resp.LastPage)
for _, org := range orgs {
if org.Login == nil {
@ -1089,9 +1030,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
if err != nil || len(members) == 0 {
return fmt.Errorf("could not list organization members: account may not have access to list organization members %w", err)
}
if res == nil {
break
}
logger.V(2).Info("Listed members", "page", opts.Page, "last_page", res.LastPage)
for _, m := range members {
usr := m.Login
@ -1126,30 +1065,60 @@ func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL stri
s.SetProgressComplete(index+offset, len(s.repos)+offset, fmt.Sprintf("Repo: %s", repoURL), encodedResumeInfo)
}
const initialPage = 1 // page to start listing from
func (s *Source) scanComments(ctx context.Context, repoPath string, chunksChan chan *sources.Chunk) error {
// Support ssh and https URLs
repoURL, err := git.GitURLParse(repoPath)
func (s *Source) scanComments(ctx context.Context, repoPath string, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
urlString, urlParts, err := getRepoURLParts(repoPath)
if err != nil {
return err
}
trimmedURL := removeURLAndSplit(repoURL.String())
if repoURL.Host == "gist.github.com" && s.includeGistComments {
return s.processGistComments(ctx, repoPath, trimmedURL, repoURL, chunksChan)
if s.includeGistComments && urlParts[0] == "gist.github.com" {
return s.processGistComments(ctx, urlString, urlParts, repoInfo, chunksChan)
} else if s.includeIssueComments || s.includePRComments {
return s.processRepoComments(ctx, repoInfo, chunksChan)
}
return s.processRepoComments(ctx, repoPath, trimmedURL, repoURL, chunksChan)
return nil
}
func (s *Source) processGistComments(ctx context.Context, repoPath string, trimmedURL []string, repoURL *url.URL, chunksChan chan *sources.Chunk) error {
ctx.Logger().V(2).Info("scanning github gist comments", "repository", repoPath)
// GitHub Gist URL.
gistID, err := extractGistID(trimmedURL)
// getRepoURLParts removes extraneous information from the given |repoURL| and splits it into segments.
// This is typically 3 segments: host, owner, and name/ID; however, Gists have some edge cases.
//
// Examples:
// - "https://github.com/trufflesecurity/trufflehog" => ["github.com", "trufflesecurity", "trufflehog"]
// - "https://gist.github.com/nat/5fdbb7f945d121f197fb074578e53948" => ["gist.github.com", "nat", "5fdbb7f945d121f197fb074578e53948"]
// - "https://gist.github.com/ff0e5e8dc8ec22f7a25ddfc3492d3451.git" => ["gist.github.com", "ff0e5e8dc8ec22f7a25ddfc3492d3451"]
func getRepoURLParts(repoURL string) (string, []string, error) {
// Support ssh and https URLs.
url, err := git.GitURLParse(repoURL)
if err != nil {
return err
return "", []string{}, err
}
// Remove the user information.
// e.g., `git@github.com` -> `github.com`
if url.User != nil {
url.User = nil
}
urlString := url.String()
trimmedURL := strings.TrimPrefix(urlString, url.Scheme+"://")
trimmedURL = strings.TrimSuffix(trimmedURL, ".git")
splitURL := strings.Split(trimmedURL, "/")
if len(splitURL) < 2 || len(splitURL) > 3 {
return "", []string{}, fmt.Errorf("invalid repository or gist URL (%s): length of URL segments should be 2 or 3", urlString)
}
return urlString, splitURL, nil
}
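A quick usage sketch of getRepoURLParts (hypothetical same-package call sites; expected values follow the examples in the doc comment above):

// Hypothetical sketch; not part of the diff.
func demoGetRepoURLParts() error {
	urlString, parts, err := getRepoURLParts("https://github.com/trufflesecurity/trufflehog.git")
	if err != nil {
		return err
	}
	fmt.Println(urlString) // https://github.com/trufflesecurity/trufflehog.git
	fmt.Println(parts)     // [github.com trufflesecurity trufflehog]

	// Gist clone URLs have no owner segment, so only two parts come back.
	_, gistParts, err := getRepoURLParts("https://gist.github.com/ff0e5e8dc8ec22f7a25ddfc3492d3451.git")
	if err != nil {
		return err
	}
	fmt.Println(gistParts) // [gist.github.com ff0e5e8dc8ec22f7a25ddfc3492d3451]
	return nil
}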
const initialPage = 1 // page to start listing from
func (s *Source) processGistComments(ctx context.Context, gistURL string, urlParts []string, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
ctx.Logger().V(2).Info("Scanning GitHub Gist comments")
// GitHub Gist URL.
gistID := extractGistID(urlParts)
options := &github.ListOptions{
PerPage: defaultPagination,
Page: initialPage,
@ -1157,13 +1126,13 @@ func (s *Source) processGistComments(ctx context.Context, repoPath string, trimm
for {
comments, _, err := s.apiClient.Gists.ListComments(ctx, gistID, options)
if s.handleRateLimit(err) {
break
continue
}
if err != nil {
return err
}
if err = s.chunkGistComments(ctx, repoURL.String(), comments, chunksChan); err != nil {
if err = s.chunkGistComments(ctx, gistURL, repoInfo, comments, chunksChan); err != nil {
return err
}
@ -1175,17 +1144,47 @@ func (s *Source) processGistComments(ctx context.Context, repoPath string, trimm
return nil
}
func extractGistID(url []string) (string, error) {
if len(url) < 2 || len(url) > 3 {
return "", fmt.Errorf("failed to parse Gist URL: length of trimmedURL should be 2 or 3")
func extractGistID(url []string) string {
return url[len(url)-1]
}
func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo repoInfo, comments []*github.GistComment, chunksChan chan *sources.Chunk) error {
for _, comment := range comments {
// Create chunk and send it to the channel.
chunk := &sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
JobID: s.JobID(),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_Github{
Github: &source_metadatapb.Github{
Link: sanitizer.UTF8(comment.GetURL()),
Username: sanitizer.UTF8(comment.GetUser().GetLogin()),
Email: sanitizer.UTF8(comment.GetUser().GetEmail()),
Repository: sanitizer.UTF8(gistURL),
Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()),
Visibility: gistInfo.visibility,
},
},
},
Data: []byte(sanitizer.UTF8(comment.GetBody())),
Verify: s.verify,
}
select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
}
}
return url[len(url)-1], nil
return nil
}
// Note: these can't be consts because the address is needed when using them with the GitHub library.
var (
// sortType defines the criteria for sorting comments.
// By default comments are sorted by their creation date.
// By default, comments are sorted by their creation date.
sortType = "created"
// directionType defines the direction of sorting.
// "desc" means comments will be sorted in descending order, showing the latest comments first.
@ -1197,34 +1196,9 @@ var (
state = "all"
)
type repoInfo struct {
owner string
repo string
repoPath string
visibility source_metadatapb.Visibility
}
func (s *Source) processRepoComments(ctx context.Context, repoPath string, trimmedURL []string, repoURL *url.URL, chunksChan chan *sources.Chunk) error {
// Normal repository URL (https://github.com/<owner>/<repo>).
if len(trimmedURL) < 3 {
return fmt.Errorf("url missing owner and/or repo: '%s'", repoURL.String())
}
owner := trimmedURL[1]
repo := trimmedURL[2]
if !(s.includeIssueComments || s.includePRComments) {
return nil
}
repoInfo := repoInfo{
owner: owner,
repo: repo,
repoPath: repoPath,
visibility: s.visibilityOf(ctx, repoPath),
}
func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
if s.includeIssueComments {
ctx.Logger().V(2).Info("scanning github issues", "repository", repoInfo.repoPath)
ctx.Logger().V(2).Info("Scanning issues")
if err := s.processIssues(ctx, repoInfo, chunksChan); err != nil {
return err
}
@ -1234,7 +1208,7 @@ func (s *Source) processRepoComments(ctx context.Context, repoPath string, trimm
}
if s.includePRComments {
ctx.Logger().V(2).Info("scanning github pull requests", "repository", repoInfo.repoPath)
ctx.Logger().V(2).Info("Scanning pull requests")
if err := s.processPRs(ctx, repoInfo, chunksChan); err != nil {
return err
}
@ -1247,7 +1221,7 @@ func (s *Source) processRepoComments(ctx context.Context, repoPath string, trimm
}
func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan chan *sources.Chunk) error {
func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
bodyTextsOpts := &github.IssueListByRepoOptions{
Sort: sortType,
Direction: directionType,
@ -1259,16 +1233,16 @@ func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan ch
}
for {
issues, _, err := s.apiClient.Issues.ListByRepo(ctx, info.owner, info.repo, bodyTextsOpts)
issues, _, err := s.apiClient.Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts)
if s.handleRateLimit(err) {
break
continue
}
if err != nil {
return err
}
if err = s.chunkIssues(ctx, info, issues, chunksChan); err != nil {
if err = s.chunkIssues(ctx, repoInfo, issues, chunksChan); err != nil {
return err
}
@ -1281,106 +1255,6 @@ func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan ch
return nil
}
func (s *Source) processIssueComments(ctx context.Context, info repoInfo, chunksChan chan *sources.Chunk) error {
issueOpts := &github.IssueListCommentsOptions{
Sort: &sortType,
Direction: &directionType,
ListOptions: github.ListOptions{
PerPage: defaultPagination,
Page: initialPage,
},
}
for {
issueComments, _, err := s.apiClient.Issues.ListComments(ctx, info.owner, info.repo, allComments, issueOpts)
if s.handleRateLimit(err) {
break
}
if err != nil {
return err
}
if err = s.chunkIssueComments(ctx, info, issueComments, chunksChan); err != nil {
return err
}
issueOpts.ListOptions.Page++
if len(issueComments) < defaultPagination {
break
}
}
return nil
}
func (s *Source) processPRs(ctx context.Context, info repoInfo, chunksChan chan *sources.Chunk) error {
prOpts := &github.PullRequestListOptions{
Sort: sortType,
Direction: directionType,
State: state,
ListOptions: github.ListOptions{
PerPage: defaultPagination,
Page: initialPage,
},
}
for {
prs, _, err := s.apiClient.PullRequests.List(ctx, info.owner, info.repo, prOpts)
if s.handleRateLimit(err) {
break
}
if err != nil {
return err
}
if err = s.chunkPullRequests(ctx, info, prs, chunksChan); err != nil {
return err
}
prOpts.ListOptions.Page++
if len(prs) < defaultPagination {
break
}
}
return nil
}
func (s *Source) processPRComments(ctx context.Context, info repoInfo, chunksChan chan *sources.Chunk) error {
prOpts := &github.PullRequestListCommentsOptions{
Sort: sortType,
Direction: directionType,
ListOptions: github.ListOptions{
PerPage: defaultPagination,
Page: initialPage,
},
}
for {
prComments, _, err := s.apiClient.PullRequests.ListComments(ctx, info.owner, info.repo, allComments, prOpts)
if s.handleRateLimit(err) {
break
}
if err != nil {
return err
}
if err = s.chunkPullRequestComments(ctx, info, prComments, chunksChan); err != nil {
return err
}
prOpts.ListOptions.Page++
if len(prComments) < defaultPagination {
break
}
}
return nil
}
func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*github.Issue, chunksChan chan *sources.Chunk) error {
for _, issue := range issues {
@ -1401,7 +1275,7 @@ func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*g
Link: sanitizer.UTF8(issue.GetHTMLURL()),
Username: sanitizer.UTF8(issue.GetUser().GetLogin()),
Email: sanitizer.UTF8(issue.GetUser().GetEmail()),
Repository: sanitizer.UTF8(repoInfo.repo),
Repository: sanitizer.UTF8(repoInfo.fullName),
Timestamp: sanitizer.UTF8(issue.GetCreatedAt().String()),
Visibility: repoInfo.visibility,
},
@ -1420,6 +1294,37 @@ func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*g
return nil
}
func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
issueOpts := &github.IssueListCommentsOptions{
Sort: &sortType,
Direction: &directionType,
ListOptions: github.ListOptions{
PerPage: defaultPagination,
Page: initialPage,
},
}
for {
issueComments, _, err := s.apiClient.Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts)
if s.handleRateLimit(err) {
continue
}
if err != nil {
return err
}
if err = s.chunkIssueComments(ctx, repoInfo, issueComments, chunksChan); err != nil {
return err
}
issueOpts.ListOptions.Page++
if len(issueComments) < defaultPagination {
break
}
}
return nil
}
func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comments []*github.IssueComment, chunksChan chan *sources.Chunk) error {
for _, comment := range comments {
// Create chunk and send it to the channel.
@ -1434,7 +1339,7 @@ func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comm
Link: sanitizer.UTF8(comment.GetHTMLURL()),
Username: sanitizer.UTF8(comment.GetUser().GetLogin()),
Email: sanitizer.UTF8(comment.GetUser().GetEmail()),
Repository: sanitizer.UTF8(repoInfo.repo),
Repository: sanitizer.UTF8(repoInfo.fullName),
Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()),
Visibility: repoInfo.visibility,
},
@ -1453,6 +1358,104 @@ func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comm
return nil
}
func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
prOpts := &github.PullRequestListOptions{
Sort: sortType,
Direction: directionType,
State: state,
ListOptions: github.ListOptions{
PerPage: defaultPagination,
Page: initialPage,
},
}
for {
prs, _, err := s.apiClient.PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts)
if s.handleRateLimit(err) {
continue
}
if err != nil {
return err
}
if err = s.chunkPullRequests(ctx, repoInfo, prs, chunksChan); err != nil {
return err
}
prOpts.ListOptions.Page++
if len(prs) < defaultPagination {
break
}
}
return nil
}
func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, chunksChan chan *sources.Chunk) error {
prOpts := &github.PullRequestListCommentsOptions{
Sort: sortType,
Direction: directionType,
ListOptions: github.ListOptions{
PerPage: defaultPagination,
Page: initialPage,
},
}
for {
prComments, _, err := s.apiClient.PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts)
if s.handleRateLimit(err) {
continue
}
if err != nil {
return err
}
if err = s.chunkPullRequestComments(ctx, repoInfo, prComments, chunksChan); err != nil {
return err
}
prOpts.ListOptions.Page++
if len(prComments) < defaultPagination {
break
}
}
return nil
}
func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, chunksChan chan *sources.Chunk) error {
for _, pr := range prs {
// Create chunk and send it to the channel.
chunk := &sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
JobID: s.JobID(),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_Github{
Github: &source_metadatapb.Github{
Link: sanitizer.UTF8(pr.GetHTMLURL()),
Username: sanitizer.UTF8(pr.GetUser().GetLogin()),
Email: sanitizer.UTF8(pr.GetUser().GetEmail()),
Repository: sanitizer.UTF8(repoInfo.fullName),
Timestamp: sanitizer.UTF8(pr.GetCreatedAt().String()),
Visibility: repoInfo.visibility,
},
},
},
Data: []byte(sanitizer.UTF8(pr.GetTitle() + "\n" + pr.GetBody())),
Verify: s.verify,
}
select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
}
}
return nil
}
func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo, comments []*github.PullRequestComment, chunksChan chan *sources.Chunk) error {
for _, comment := range comments {
// Create chunk and send it to the channel.
@ -1467,7 +1470,7 @@ func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo
Link: sanitizer.UTF8(comment.GetHTMLURL()),
Username: sanitizer.UTF8(comment.GetUser().GetLogin()),
Email: sanitizer.UTF8(comment.GetUser().GetEmail()),
Repository: sanitizer.UTF8(repoInfo.repo),
Repository: sanitizer.UTF8(repoInfo.fullName),
Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()),
Visibility: repoInfo.visibility,
},
@ -1486,73 +1489,6 @@ func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo
return nil
}
func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, chunksChan chan *sources.Chunk) error {
for _, pr := range prs {
// Create chunk and send it to the channel.
chunk := &sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
JobID: s.JobID(),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_Github{
Github: &source_metadatapb.Github{
Link: sanitizer.UTF8(pr.GetHTMLURL()),
Username: sanitizer.UTF8(pr.GetUser().GetLogin()),
Email: sanitizer.UTF8(pr.GetUser().GetEmail()),
Repository: sanitizer.UTF8(repoInfo.repo),
Timestamp: sanitizer.UTF8(pr.GetCreatedAt().String()),
Visibility: repoInfo.visibility,
},
},
},
Data: []byte(sanitizer.UTF8(pr.GetTitle() + "\n" + pr.GetBody())),
Verify: s.verify,
}
select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
}
}
return nil
}
func (s *Source) chunkGistComments(ctx context.Context, gistUrl string, comments []*github.GistComment, chunksChan chan *sources.Chunk) error {
for _, comment := range comments {
// Create chunk and send it to the channel.
chunk := &sources.Chunk{
SourceName: s.name,
SourceID: s.SourceID(),
SourceType: s.Type(),
JobID: s.JobID(),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_Github{
Github: &source_metadatapb.Github{
Link: sanitizer.UTF8(comment.GetURL()),
Username: sanitizer.UTF8(comment.GetUser().GetLogin()),
Email: sanitizer.UTF8(comment.GetUser().GetEmail()),
Repository: sanitizer.UTF8(gistUrl),
Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()),
// TODO: Fetching this requires making an additional API call. We may want to include this in the future.
// Visibility: s.visibilityOf(ctx, repoPath),
},
},
},
Data: []byte(sanitizer.UTF8(comment.GetBody())),
Verify: s.verify,
}
select {
case <-ctx.Done():
return ctx.Err()
case chunksChan <- chunk:
}
}
return nil
}
func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, chunksChan chan *sources.Chunk) error {
for _, tgt := range targets {
if err := s.scanTarget(ctx, tgt, chunksChan); err != nil {
@ -1607,11 +1543,3 @@ func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget,
return common.CancellableWrite(ctx, chunksChan, chunk)
}
func removeURLAndSplit(url string) []string {
trimmedURL := strings.TrimPrefix(url, "https://")
trimmedURL = strings.TrimSuffix(trimmedURL, ".git")
splitURL := strings.Split(trimmedURL, "/")
return splitURL
}


@ -53,11 +53,11 @@ func TestSource_Token(t *testing.T) {
}
s := Source{
conn: conn,
httpClient: common.SaneHttpClient(),
log: logr.Discard(),
memberCache: map[string]struct{}{},
repoSizes: newRepoSize(),
conn: conn,
httpClient: common.SaneHttpClient(),
log: logr.Discard(),
memberCache: map[string]struct{}{},
repoInfoCache: newRepoInfoCache(),
}
s.filteredRepoCache = s.newFilteredRepoCache(memory.New(), nil, nil)


@ -8,10 +8,8 @@ import (
"encoding/pem"
"fmt"
"net/http"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"time"
@ -27,7 +25,6 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)
func createTestSource(src *sourcespb.GitHub) (*Source, *anypb.Any) {
@ -713,34 +710,23 @@ func Test_scan_SetProgressComplete(t *testing.T) {
}
}
func TestProcessRepoComments(t *testing.T) {
tests := []struct {
name string
trimmedURL []string
wantErr bool
}{
{
name: "URL with missing owner and/or repo",
trimmedURL: []string{"https://github.com/"},
wantErr: true,
},
{
name: "URL with complete owner and repo",
trimmedURL: []string{"https://github.com/", "owner", "repo"},
wantErr: false,
},
// TODO: Add more test cases to cover other scenarios.
func TestGetRepoURLParts(t *testing.T) {
tests := []string{
"https://github.com/trufflesecurity/trufflehog.git",
"git+https://github.com/trufflesecurity/trufflehog.git",
//"git@github.com:trufflesecurity/trufflehog.git",
"ssh://github.com/trufflesecurity/trufflehog.git",
"ssh://git@github.com/trufflesecurity/trufflehog.git",
"git+ssh://git@github.com/trufflesecurity/trufflehog.git",
"git://github.com/trufflesecurity/trufflehog.git",
}
expected := []string{"github.com", "trufflesecurity", "trufflehog"}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Source{}
repoURL, _ := url.Parse(strings.Join(tt.trimmedURL, "/"))
chunksChan := make(chan *sources.Chunk)
err := s.processRepoComments(context.Background(), "repoPath", tt.trimmedURL, repoURL, chunksChan)
assert.Equal(t, tt.wantErr, err != nil)
})
_, parts, err := getRepoURLParts(tt)
if err != nil {
t.Fatalf("failed: %v", err)
}
assert.Equal(t, expected, parts)
}
}
@ -748,17 +734,15 @@ func TestGetGistID(t *testing.T) {
tests := []struct {
trimmedURL []string
expected string
err bool
}{
{[]string{"https://gist.github.com", "12345"}, "12345", false},
{[]string{"https://gist.github.com", "owner", "12345"}, "12345", false},
{[]string{"https://gist.github.com"}, "", true},
{[]string{"https://gist.github.com", "owner", "12345", "extra"}, "", true},
{[]string{"https://gist.github.com", "12345"}, "12345"},
{[]string{"https://gist.github.com", "owner", "12345"}, "12345"},
{[]string{"https://gist.github.com"}, ""},
{[]string{"https://gist.github.com", "owner", "12345", "extra"}, ""},
}
for _, tt := range tests {
got, err := extractGistID(tt.trimmedURL)
assert.Equal(t, tt.err, err != nil)
got := extractGistID(tt.trimmedURL)
assert.Equal(t, tt.expected, got)
}
}


@ -2,19 +2,56 @@ package github
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
gogit "github.com/go-git/go-git/v5"
"github.com/google/go-github/v57/github"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
)
type repoInfoCache struct {
mu sync.RWMutex
cache map[string]repoInfo
}
func newRepoInfoCache() repoInfoCache {
return repoInfoCache{
cache: make(map[string]repoInfo),
}
}
func (r *repoInfoCache) put(repoURL string, info repoInfo) {
r.mu.Lock()
defer r.mu.Unlock()
r.cache[repoURL] = info
}
func (r *repoInfoCache) get(repoURL string) (repoInfo, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
info, ok := r.cache[repoURL]
return info, ok
}
type repoInfo struct {
owner string
name string
fullName string
hasWiki bool // the repo is _likely_ to have a wiki (see the comment on wikiIsReachable func).
size int
visibility source_metadatapb.Visibility
}
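For illustration, a hypothetical same-package sketch of the cache in use, keyed by clone URL as elsewhere in this change (the repository values are made up):

// Hypothetical sketch; the repository values are illustrative only.
func demoRepoInfoCache() {
	cache := newRepoInfoCache()
	cache.put("https://github.com/trufflesecurity/trufflehog.git", repoInfo{
		owner:      "trufflesecurity",
		name:       "trufflehog",
		fullName:   "trufflesecurity/trufflehog",
		hasWiki:    true,
		size:       51200, // KB, as reported by the GitHub API
		visibility: source_metadatapb.Visibility_public,
	})

	if info, ok := cache.get("https://github.com/trufflesecurity/trufflehog.git"); ok {
		fmt.Println(info.fullName, info.hasWiki) // trufflesecurity/trufflehog true
	}
}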
func (s *Source) cloneRepo(
ctx context.Context,
repoURL string,
@ -209,9 +246,6 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo
if err != nil {
return err
}
if res == nil {
break
}
s.log.V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage)
for _, r := range someRepos {
@ -228,12 +262,10 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo
}
repoName, repoURL := r.GetFullName(), r.GetCloneURL()
s.repoSizes.addRepo(repoURL, r.GetSize())
s.totalRepoSize += r.GetSize()
s.filteredRepoCache.Set(repoName, repoURL)
if s.conn.GetIncludeWikis() && s.hasWiki(ctx, r, repoURL) {
s.reposWithWikis[repoURL] = struct{}{}
}
s.cacheRepoInfo(r)
logger.V(3).Info("repo attributes", "name", repoName, "kb_size", r.GetSize(), "repo_url", repoURL)
}
@ -249,13 +281,26 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo
return nil
}
// hasWiki returns true if the "has_wiki" property is true AND https://github.com/$org/$repo/wiki is not redirected.
// Unfortunately, this isn't 100% accurate. Some repositories meet both criteria yet don't have a cloneable wiki.
func (s *Source) hasWiki(ctx context.Context, repo *github.Repository, repoURL string) bool {
if !repo.GetHasWiki() {
return false
func (s *Source) cacheRepoInfo(r *github.Repository) {
info := repoInfo{
owner: r.GetOwner().GetLogin(),
name: r.GetName(),
fullName: r.GetFullName(),
hasWiki: r.GetHasWiki(),
size: r.GetSize(),
}
if r.GetPrivate() {
info.visibility = source_metadatapb.Visibility_private
} else {
info.visibility = source_metadatapb.Visibility_public
}
s.repoInfoCache.put(r.GetCloneURL(), info)
}
// wikiIsReachable returns true if https://github.com/$org/$repo/wiki is not redirected.
// Unfortunately, this isn't 100% accurate. Some repositories have `has_wiki: true` and don't redirect their wiki page,
// but still don't have a cloneable wiki.
func (s *Source) wikiIsReachable(ctx context.Context, repoURL string) bool {
wikiURL := strings.TrimSuffix(repoURL, ".git") + "/wiki"
req, err := http.NewRequestWithContext(ctx, http.MethodHead, wikiURL, nil)
if err != nil {
@ -266,6 +311,7 @@ func (s *Source) hasWiki(ctx context.Context, repo *github.Repository, repoURL s
if err != nil {
return false
}
_, _ = io.Copy(io.Discard, res.Body)
_ = res.Body.Close()
// If the wiki is disabled, or is enabled but has no content, the request should be redirected.
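The rest of the check falls outside this hunk. One common way to finish it, hedged as an assumption rather than the committed code, is to compare the final URL after any redirects against the requested wiki URL:

// Hypothetical continuation: if the client followed a redirect, the final URL differs
// from wikiURL and the wiki is treated as not cloneable.
return wikiURL == res.Request.URL.String()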