diff --git a/pkg/engine/git.go b/pkg/engine/git.go index 370b4642a..d6786e190 100644 --- a/pkg/engine/git.go +++ b/pkg/engine/git.go @@ -4,11 +4,13 @@ import ( "fmt" "runtime" + "github.com/go-errors/errors" gogit "github.com/go-git/go-git/v5" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" - "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/git" @@ -22,16 +24,6 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) error { git.ScanOptionLogOptions(logOptions), } - options := &gogit.PlainOpenOptions{ - DetectDotGit: true, - EnableDotGitCommonDir: true, - } - - repo, err := gogit.PlainOpenWithOptions(c.RepoPath, options) - if err != nil { - return fmt.Errorf("could not open repo: %s: %w", c.RepoPath, err) - } - if c.MaxDepth != 0 { opts = append(opts, git.ScanOptionMaxDepth(int64(c.MaxDepth))) } @@ -46,21 +38,28 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) error { } scanOptions := git.NewScanOptions(opts...) - gitSource := git.NewGit(sourcespb.SourceType_SOURCE_TYPE_GIT, 0, 0, "trufflehog - git", true, runtime.NumCPU(), - func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { - return &source_metadatapb.MetaData{ - Data: &source_metadatapb.MetaData_Git{ - Git: &source_metadatapb.Git{ - Commit: commit, - File: file, - Email: email, - Repository: repository, - Timestamp: timestamp, - Line: line, - }, - }, - } - }) + connection := &sourcespb.Git{ + // Using Directories here allows us to not pass any + // authentication. Also by this point, the c.RepoPath should + // still have been prepared and downloaded to a temporary + // directory if it was a URL. + Directories: []string{c.RepoPath}, + } + var conn anypb.Any + err := anypb.MarshalFrom(&conn, connection, proto.MarshalOptions{}) + if err != nil { + ctx.Logger().Error(err, "failed to marshal git connection") + return err + } + + gitSource := git.Source{} + if err := gitSource.Init(ctx, "trufflehog - git", 0, 0, true, &conn, runtime.NumCPU()); err != nil { + return errors.WrapPrefix(err, "could not init git source", 0) + } + gitSource.WithScanOptions(scanOptions) + // Don't try to clean up the provided directory. That's handled by the + // caller of ScanGit. + gitSource.WithPreserveTempDirs(true) ctx = context.WithValues(ctx, "source_type", sourcespb.SourceType_SOURCE_TYPE_GIT.String(), @@ -68,7 +67,7 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) error { ) e.sourcesWg.Go(func() error { defer common.RecoverWithExit(ctx) - err := gitSource.ScanRepo(ctx, repo, c.RepoPath, scanOptions, e.ChunksChan()) + err := gitSource.Chunks(ctx, e.ChunksChan()) if err != nil { return fmt.Errorf("could not scan repo: %w", err) } diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index 80417b50a..d3edb87b7 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -42,7 +42,11 @@ type Source struct { verify bool git *Git sources.Progress - conn *sourcespb.Git + conn *sourcespb.Git + scanOptions *ScanOptions + // Kludge to preserve engine.ScanGit functionality which doesn't expect + // the scanning to clean up the directory. + preserveTempDirs bool } type Git struct { @@ -92,6 +96,19 @@ func (s *Source) JobID() int64 { return s.jobId } +// WithScanOptions sets the scan options. +func (s *Source) WithScanOptions(scanOptions *ScanOptions) { + s.scanOptions = scanOptions +} + +// WithPreserveTempDirs sets whether to preserve temp directories when scanning +// the provided list of s.conn.Directories. NOTE: This is *only* for +// s.conn.Directories, not all temp directories created. This is also a kludge +// and should be refactored away. +func (s *Source) WithPreserveTempDirs(preserve bool) { + s.preserveTempDirs = preserve +} + // Init returns an initialized GitHub source. func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error { s.name = name @@ -135,8 +152,29 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, // Chunks emits chunks of bytes over a channel. func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error { - // TODO: refactor to remove duplicate code + if err := s.scanRepos(ctx, chunksChan); err != nil { + return err + } + if err := s.scanDirs(ctx, chunksChan); err != nil { + return err + } + totalRepos := len(s.conn.Repositories) + len(s.conn.Directories) + ctx.Logger().V(1).Info("Git source finished scanning", "repo_count", totalRepos) + s.SetProgressComplete( + totalRepos, totalRepos, + fmt.Sprintf("Completed scanning source %s", s.name), "", + ) + return nil +} + +// scanRepos scans the configured repositories in s.conn.Repositories. +func (s *Source) scanRepos(ctx context.Context, chunksChan chan *sources.Chunk) error { + if len(s.conn.Repositories) == 0 { + return nil + } + totalRepos := len(s.conn.Repositories) + len(s.conn.Directories) + // TODO: refactor to remove duplicate code switch cred := s.conn.GetCredential().(type) { case *sourcespb.Git_BasicAuth: user := cred.BasicAuth.Username @@ -153,7 +191,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { return err } - return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan) + return s.git.ScanRepo(ctx, repo, path, s.scanOptions, chunksChan) }(repoURI) if err != nil { ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) @@ -172,7 +210,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { return err } - return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan) + return s.git.ScanRepo(ctx, repo, path, s.scanOptions, chunksChan) }(repoURI) if err != nil { ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) @@ -191,7 +229,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { return err } - return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan) + return s.git.ScanRepo(ctx, repo, path, s.scanOptions, chunksChan) }(repoURI) if err != nil { ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) @@ -201,41 +239,42 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err default: return errors.New("invalid connection type for git source") } + return nil +} +// scanDirs scans the configured directories in s.conn.Directories. +func (s *Source) scanDirs(ctx context.Context, chunksChan chan *sources.Chunk) error { + totalRepos := len(s.conn.Repositories) + len(s.conn.Directories) for i, gitDir := range s.conn.Directories { s.SetProgressComplete(len(s.conn.Repositories)+i, totalRepos, fmt.Sprintf("Repo: %s", gitDir), "") if len(gitDir) == 0 { continue } - if !strings.HasSuffix(gitDir, "git") { - // try paths instead of url - repo, err := RepoFromPath(gitDir) - if err != nil { - ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err) - continue + if strings.HasSuffix(gitDir, "git") { + // TODO: Figure out why we skip directories ending in "git". + continue + } + // try paths instead of url + repo, err := RepoFromPath(gitDir) + if err != nil { + ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err) + continue + } + + err = func(repoPath string) error { + if !s.preserveTempDirs && strings.HasPrefix(repoPath, filepath.Join(os.TempDir(), "trufflehog")) { + defer os.RemoveAll(repoPath) } - err = func(repoPath string) error { - if strings.HasPrefix(repoPath, filepath.Join(os.TempDir(), "trufflehog")) { - defer os.RemoveAll(repoPath) - } - - return s.git.ScanRepo(ctx, repo, repoPath, NewScanOptions(), chunksChan) - }(gitDir) - if err != nil { - ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err) - continue - } + return s.git.ScanRepo(ctx, repo, repoPath, s.scanOptions, chunksChan) + }(gitDir) + if err != nil { + ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err) + continue } } - - ctx.Logger().V(1).Info("Git source finished scanning", "repo-count", totalRepos) - s.SetProgressComplete( - totalRepos, totalRepos, - fmt.Sprintf("Completed scanning source %s", s.name), "", - ) return nil }