diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go
index d427eb791..ea74c0aa7 100644
--- a/pkg/sources/github/github.go
+++ b/pkg/sources/github/github.go
@@ -419,12 +419,19 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
 }
 
 // Chunks emits chunks of bytes over a channel.
-func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
+func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
 	apiEndpoint := s.conn.Endpoint
 	if len(apiEndpoint) == 0 || endsWithGithub.MatchString(apiEndpoint) {
 		apiEndpoint = "https://api.github.com"
 	}
 
+	// If targets are provided, we're only scanning the data in those targets.
+	// Otherwise, we're scanning all data.
+	// This allows us to only scan the commit where a vulnerability was found.
+	if len(targets) > 0 {
+		return s.scanTargets(ctx, targets, chunksChan)
+	}
+
 	// Reset consumption and rate limit metrics on each run.
 	githubNumRateLimitEncountered.WithLabelValues(s.name).Set(0)
 	githubSecondsSpentRateLimited.WithLabelValues(s.name).Set(0)
@@ -1484,6 +1491,70 @@ func (s *Source) chunkGistComments(ctx context.Context, gistUrl string, comments
 	return nil
 }
 
+// scanTargets scans each of the provided targets in order. A failure on one
+// target is logged and does not stop the remaining targets from being scanned,
+// so scanTargets always returns nil.
+func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, chunksChan chan *sources.Chunk) error {
+	for _, tgt := range targets {
+		if err := s.scanTarget(ctx, tgt, chunksChan); err != nil {
+			ctx.Logger().Error(err, "error scanning target")
+		}
+	}
+
+	return nil
+}
+
+// scanTarget scans the single file-in-commit described by the target's GitHub
+// metadata. It derives the owner and repo from the metadata link, fetches the
+// diff for the file at the commit, and writes it to chunksChan as one chunk.
+// It returns an error if the metadata is not GitHub metadata, the link does not
+// yield an owner and repo, the diff cannot be fetched, or the chunk write fails.
+func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, chunksChan chan *sources.Chunk) error {
+	metaType, ok := target.QueryCriteria.GetData().(*source_metadatapb.MetaData_Github)
+	if !ok {
+		return fmt.Errorf("unable to cast metadata type for targetted scan")
+	}
+	meta := metaType.Github
+
+	u, err := url.Parse(meta.GetLink())
+	if err != nil {
+		return fmt.Errorf("unable to parse GitHub URL: %w", err)
+	}
+
+	// The owner is the second segment and the repo is the third segment of the path.
+	// Ex: https://github.com/owner/repo/.....
+	segments := strings.Split(u.Path, "/")
+	// An empty owner or repo segment (e.g. "https://github.com/owner/") cannot
+	// identify a repository, so reject it here instead of querying the API with it.
+	if len(segments) < 3 || segments[1] == "" || segments[2] == "" {
+		return fmt.Errorf("invalid GitHub URL")
+	}
+
+	qry := commitQuery{
+		repo:     segments[2],
+		owner:    segments[1],
+		sha:      meta.GetCommit(),
+		filename: meta.GetFile(),
+	}
+	res, err := s.getDiffForFileInCommit(ctx, qry)
+	if err != nil {
+		return err
+	}
+	chunk := &sources.Chunk{
+		SourceType: s.Type(),
+		SourceName: s.name,
+		SourceID:   s.SourceID(),
+		JobID:      s.JobID(),
+		Data:       []byte(res),
+		SourceMetadata: &source_metadatapb.MetaData{
+			Data: &source_metadatapb.MetaData_Github{Github: meta},
+		},
+		Verify: s.verify,
+	}
+
+	return common.CancellableWrite(ctx, chunksChan, chunk)
+}
+
 func removeURLAndSplit(url string) []string {
 	trimmedURL := strings.TrimPrefix(url, "https://")
 	trimmedURL = strings.TrimSuffix(trimmedURL, ".git")
diff --git a/pkg/sources/github/github_integration_test.go b/pkg/sources/github/github_integration_test.go
index e30f47d8a..d76b5b3f9 100644
--- a/pkg/sources/github/github_integration_test.go
+++ b/pkg/sources/github/github_integration_test.go
@@ -838,3 +838,130 @@ func githubCommentCheckFunc(gotChunk, wantChunk *sources.Chunk, i int, t *testin
 // 		})
 // 	}
 // }
+
+// TestSource_Chunks_TargetedScan runs Chunks with a single ChunkingTarget per
+// case and checks the number of chunks emitted. It reads GITHUB_TOKEN from the
+// test secret, so it needs access to that secret to run.
+func TestSource_Chunks_TargetedScan(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*3000)
+	defer cancel()
+
+	secret, err := common.GetTestSecret(ctx)
+	if err != nil {
+		t.Fatal(fmt.Errorf("failed to access secret: %v", err))
+	}
+
+	githubToken := secret.MustGetField("GITHUB_TOKEN")
+
+	type init struct {
+		name          string
+		verify        bool
+		connection    *sourcespb.GitHub
+		queryCriteria *source_metadatapb.MetaData
+	}
+	tests := []struct {
+		name       string
+		init       init
+		wantChunks int
+	}{
+		{
+			name: "targeted scan, one file in small commit",
+			init: init{
+				name:       "test source",
+				connection: &sourcespb.GitHub{Credential: &sourcespb.GitHub_Token{Token: githubToken}},
+				queryCriteria: &source_metadatapb.MetaData{
+					Data: &source_metadatapb.MetaData_Github{
+						Github: &source_metadatapb.Github{
+							Repository: "test_keys",
+							Link:       "https://github.com/trufflesecurity/test_keys/blob/fbc14303ffbf8fb1c2c1914e8dda7d0121633aca/keys#L4",
+							Commit:     "fbc14303ffbf8fb1c2c1914e8dda7d0121633aca",
+							File:       "keys",
+						},
+					},
+				},
+			},
+			wantChunks: 1,
+		},
+		{
+			name: "targeted scan, one file in med commit",
+			init: init{
+				name:       "test source",
+				connection: &sourcespb.GitHub{Credential: &sourcespb.GitHub_Token{Token: githubToken}},
+				queryCriteria: &source_metadatapb.MetaData{
+					Data: &source_metadatapb.MetaData_Github{
+						Github: &source_metadatapb.Github{
+							Repository: "https://github.com/trufflesecurity/trufflehog.git",
+							Link:       "https://github.com/trufflesecurity/trufflehog/blob/33eed42e17fda8b1a66feaeafcd57efccff26c11/pkg/sources/s3/s3_test.go#L78",
+							Commit:     "33eed42e17fda8b1a66feaeafcd57efccff26c11",
+							File:       "pkg/sources/s3/s3_test.go",
+						},
+					},
+				},
+			},
+			wantChunks: 1,
+		},
+		{
+			name: "no file in commit",
+			init: init{
+				name:       "test source",
+				connection: &sourcespb.GitHub{Credential: &sourcespb.GitHub_Token{Token: githubToken}},
+				queryCriteria: &source_metadatapb.MetaData{
+					Data: &source_metadatapb.MetaData_Github{
+						Github: &source_metadatapb.Github{
+							Repository: "test_keys",
+							Link:       "https://github.com/trufflesecurity/test_keys/blob/fbc14303ffbf8fb1c2c1914e8dda7d0121633aca/keys#L4",
+							Commit:     "fbc14303ffbf8fb1c2c1914e8dda7d0121633aca",
+							File:       "not-the-file",
+						},
+					},
+				},
+			},
+			wantChunks: 0,
+		},
+		{
+			name: "invalid query criteria, malformed link",
+			init: init{
+				name:       "test source",
+				connection: &sourcespb.GitHub{Credential: &sourcespb.GitHub_Token{Token: githubToken}},
+				queryCriteria: &source_metadatapb.MetaData{
+					Data: &source_metadatapb.MetaData_Github{
+						Github: &source_metadatapb.Github{
+							Repository: "test_keys",
+							Link:       "malformed-link",
+							Commit:     "fbc14303ffbf8fb1c2c1914e8dda7d0121633aca",
+							File:       "not-the-file",
+						},
+					},
+				},
+			},
+			wantChunks: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			s := Source{}
+
+			conn, err := anypb.New(tt.init.connection)
+			assert.Nil(t, err)
+
+			err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 8)
+			assert.Nil(t, err)
+
+			chunksCh := make(chan *sources.Chunk, 1)
+			go func() {
+				defer close(chunksCh)
+				// Keep the error local to this goroutine rather than writing to the
+				// err variable owned by the enclosing subtest function.
+				chunkErr := s.Chunks(ctx, chunksCh, sources.ChunkingTarget{QueryCriteria: tt.init.queryCriteria})
+				assert.Nil(t, chunkErr)
+			}()
+
+			i := 0
+			for range chunksCh {
+				i++
+			}
+			assert.Equal(t, tt.wantChunks, i)
+		})
+	}
+}
diff --git a/pkg/sources/github/repo.go b/pkg/sources/github/repo.go
index 76694d409..901ab0f09 100644
--- a/pkg/sources/github/repo.go
+++ b/pkg/sources/github/repo.go
@@ -239,6 +239,57 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo
 	return nil
 }
 
+// commitQuery represents the details required to fetch a commit.
+type commitQuery struct {
+	repo     string
+	owner    string
+	sha      string
+	filename string
+}
+
+// getDiffForFileInCommit retrieves the diff for a specified file in a commit.
+// The commit is fetched with a single GetCommit call and the patch of every
+// listed file whose name equals query.filename is returned, each followed by a
+// newline. It returns an error if the commit cannot be fetched, lists no files,
+// the matching file has no patch, or no listed file matches.
+func (s *Source) getDiffForFileInCommit(ctx context.Context, query commitQuery) (string, error) {
+	commit, resp, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil)
+	if handled := s.handleRateLimit(err, resp); handled {
+		return "", fmt.Errorf("error fetching commit %s due to rate limit: %w", query.sha, err)
+	}
+	if err != nil {
+		return "", fmt.Errorf("error fetching commit %s: %w", query.sha, err)
+	}
+
+	if len(commit.Files) == 0 {
+		return "", fmt.Errorf("commit %s does not contain any files", query.sha)
+	}
+
+	var res strings.Builder
+	// Only return the diff if the file is in the commit.
+	for _, file := range commit.Files {
+		// Filename is dereferenced below, so skip entries that do not carry one
+		// rather than panicking on a nil pointer.
+		if file.Filename == nil || *file.Filename != query.filename {
+			continue
+		}
+
+		if file.Patch == nil {
+			return "", fmt.Errorf("commit %s file %s does not have a diff", query.sha, query.filename)
+		}
+
+		// strings.Builder.WriteString always returns a nil error, so it is not checked.
+		res.WriteString(*file.Patch)
+		res.WriteString("\n")
+	}
+
+	if res.Len() == 0 {
+		return "", fmt.Errorf("commit %s does not contain patch for file %s", query.sha, query.filename)
+	}
+
+	return res.String(), nil
+}
+
 func (s *Source) normalizeRepo(repo string) (string, error) {
 	// If there's a '/', assume it's a URL and try to normalize it.
 	if strings.ContainsRune(repo, '/') {
diff --git a/pkg/sources/sources.go b/pkg/sources/sources.go
index 5aa719c19..968644bb3 100644
--- a/pkg/sources/sources.go
+++ b/pkg/sources/sources.go
@@ -42,7 +42,7 @@ type Chunk struct {
 // without processing the entire dataset.
 type ChunkingTarget struct {
 	// QueryCriteria represents specific parameters or conditions to target the chunking process.
-	QueryCriteria source_metadatapb.MetaData
+	QueryCriteria *source_metadatapb.MetaData
 }
 
 // Source defines the interface required to implement a source chunker.