mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-14 08:57:40 +00:00
Merge branch 'main' into docker-tar-identification-from-filesystem-2
This commit is contained in:
commit
f623c0686f
26 changed files with 502 additions and 129 deletions
|
@ -32,7 +32,7 @@ Have questions? Feedback? Jump in slack or discord and hang out with us
|
|||
|
||||
Join our [Slack Community](https://join.slack.com/t/trufflehog-community/shared_invite/zt-pw2qbi43-Aa86hkiimstfdKH9UCpPzQ)
|
||||
|
||||
Join the [Secret Scanning Discord](https://discord.gg/sydS6AHTUP)
|
||||
Join the [Secret Scanning Discord](https://discord.gg/8Hzbrnkr7E)
|
||||
|
||||
# :tv: Demo
|
||||
|
||||
|
|
|
@ -26,9 +26,10 @@ type Context interface {
|
|||
Logger() logr.Logger
|
||||
}
|
||||
|
||||
// CancelFunc is a type alias to context.CancelFunc to allow use as if they are
|
||||
// the same types.
|
||||
// CancelFunc and CancelCauseFunc are type aliases to allow use as if they are
|
||||
// the same types as the standard library variants.
|
||||
type CancelFunc = context.CancelFunc
|
||||
type CancelCauseFunc = context.CancelCauseFunc
|
||||
|
||||
// logCtx implements Context.
|
||||
type logCtx struct {
|
||||
|
@ -68,6 +69,16 @@ func WithCancel(parent Context) (Context, context.CancelFunc) {
|
|||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// WithCancelCause returns context.WithCancelCause with the log object propagated.
|
||||
func WithCancelCause(parent Context) (Context, context.CancelCauseFunc) {
|
||||
ctx, cancel := context.WithCancelCause(parent)
|
||||
lCtx := logCtx{
|
||||
log: parent.Logger(),
|
||||
Context: ctx,
|
||||
}
|
||||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// WithDeadline returns context.WithDeadline with the log object propagated and
|
||||
// the deadline added to the structured log values.
|
||||
func WithDeadline(parent Context, d time.Time) (Context, context.CancelFunc) {
|
||||
|
@ -79,6 +90,17 @@ func WithDeadline(parent Context, d time.Time) (Context, context.CancelFunc) {
|
|||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// WithDeadlineCause returns context.WithDeadlineCause with the log object
|
||||
// propagated and the deadline added to the structured log values.
|
||||
func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithDeadlineCause(parent, d, cause)
|
||||
lCtx := logCtx{
|
||||
log: parent.Logger().WithValues("deadline", d),
|
||||
Context: ctx,
|
||||
}
|
||||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// WithTimeout returns context.WithTimeout with the log object propagated and
|
||||
// the timeout added to the structured log values.
|
||||
func WithTimeout(parent Context, timeout time.Duration) (Context, context.CancelFunc) {
|
||||
|
@ -90,6 +112,22 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, context.Cancel
|
|||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// WithTimeoutCause returns context.WithTimeoutCause with the log object
|
||||
// propagated and the timeout added to the structured log values.
|
||||
func WithTimeoutCause(parent Context, timeout time.Duration, cause error) (Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithTimeoutCause(parent, timeout, cause)
|
||||
lCtx := logCtx{
|
||||
log: parent.Logger().WithValues("timeout", timeout),
|
||||
Context: ctx,
|
||||
}
|
||||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// Cause returns the context.Cause of the context.
|
||||
func Cause(ctx context.Context) error {
|
||||
return context.Cause(ctx)
|
||||
}
|
||||
|
||||
// WithValue returns context.WithValue with the log object propagated and
|
||||
// the value added to the structured log values (if the key is a string).
|
||||
func WithValue(parent Context, key, val any) Context {
|
||||
|
|
|
@ -3,6 +3,7 @@ package context
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -177,3 +178,10 @@ func TestRace(t *testing.T) {
|
|||
cancel()
|
||||
_ = ctx.Err()
|
||||
}
|
||||
|
||||
func TestCause(t *testing.T) {
|
||||
ctx, cancel := WithCancelCause(Background())
|
||||
err := fmt.Errorf("oh no")
|
||||
cancel(err)
|
||||
assert.Equal(t, err, Cause(ctx))
|
||||
}
|
||||
|
|
|
@ -43,8 +43,6 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result
|
|||
instanceMatches := instancePat.FindAllStringSubmatch(dataStr, -1)
|
||||
tokenMatches := accessTokenPat.FindAllStringSubmatch(dataStr, -1)
|
||||
|
||||
fmt.Printf("instanceMatches: %v\n", instanceMatches)
|
||||
|
||||
for _, instance := range instanceMatches {
|
||||
if len(instance) != 1 {
|
||||
continue
|
||||
|
|
|
@ -2,6 +2,10 @@ package engine
|
|||
|
||||
import (
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/envoyapikey"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/huggingface"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/salesforce"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/snowflake"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/trufflehogenterprise"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/abbysale"
|
||||
|
@ -287,6 +291,7 @@ import (
|
|||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/getresponse"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/getsandbox"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/github"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/github_oauth2"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/github_old"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/githubapp"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/gitlab"
|
||||
|
@ -556,7 +561,6 @@ import (
|
|||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/satismeterwritekey"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/saucelabs"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/scalewaykey"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/github_oauth2"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/scalr"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/scrapeowl"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/scraperapi"
|
||||
|
@ -1532,6 +1536,10 @@ func DefaultDetectors() []detectors.Detector {
|
|||
couchbase.Scanner{},
|
||||
envoyapikey.Scanner{},
|
||||
github_oauth2.Scanner{},
|
||||
snowflake.Scanner{},
|
||||
huggingface.Scanner{},
|
||||
trufflehogenterprise.Scanner{},
|
||||
salesforce.Scanner{},
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package giturl
|
|||
|
||||
import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -13,8 +14,29 @@ const (
|
|||
providerGithub provider = "Github"
|
||||
providerGitlab provider = "Gitlab"
|
||||
providerBitbucket provider = "Bitbucket"
|
||||
providerAzure provider = "Azure"
|
||||
|
||||
urlGithub = "github.com/"
|
||||
urlGitlab = "gitlab.com/"
|
||||
urlBitbucket = "bitbucket.org/"
|
||||
urlAzure = "dev.azure.com/"
|
||||
)
|
||||
|
||||
func determineProvider(repo string) provider {
|
||||
switch {
|
||||
case strings.Contains(repo, urlGithub):
|
||||
return providerGithub
|
||||
case strings.Contains(repo, urlGitlab):
|
||||
return providerGitlab
|
||||
case strings.Contains(repo, urlBitbucket):
|
||||
return providerBitbucket
|
||||
case strings.Contains(repo, urlAzure):
|
||||
return providerAzure
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func NormalizeBitbucketRepo(repoURL string) (string, error) {
|
||||
if !strings.HasPrefix(repoURL, "https") {
|
||||
return "", errors.New("Bitbucket requires https repo urls: e.g. https://bitbucket.org/org/repo.git")
|
||||
|
@ -88,3 +110,35 @@ func NormalizeOrgRepoURL(provider provider, repoURL string) (string, error) {
|
|||
parsed.Path += ".git"
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
// GenerateLink crafts a link to the specific file from a commit.
|
||||
// Supports GitHub, GitLab, Bitbucket, and Azure Repos.
|
||||
// If the provider supports hyperlinks to specific lines, the line number will be included.
|
||||
func GenerateLink(repo, commit, file string, line int64) string {
|
||||
switch determineProvider(repo) {
|
||||
case providerBitbucket:
|
||||
return repo[:len(repo)-4] + "/commits/" + commit
|
||||
|
||||
case providerGithub, providerGitlab:
|
||||
var baseLink string
|
||||
if file == "" {
|
||||
baseLink = repo[:len(repo)-4] + "/commit/" + commit
|
||||
} else {
|
||||
baseLink = repo[:len(repo)-4] + "/blob/" + commit + "/" + file
|
||||
if line > 0 {
|
||||
baseLink += "#L" + strconv.FormatInt(line, 10)
|
||||
}
|
||||
}
|
||||
return baseLink
|
||||
|
||||
case providerAzure:
|
||||
baseLink := repo + "?path=" + file + "&version=GB" + commit
|
||||
if line > 0 {
|
||||
baseLink += "&line=" + strconv.FormatInt(line, 10)
|
||||
}
|
||||
return baseLink
|
||||
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
)
|
||||
|
||||
func Test_NormalizeOrgRepoURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := map[string]struct {
|
||||
Provider provider
|
||||
Repo string
|
||||
|
@ -43,6 +45,8 @@ func Test_NormalizeOrgRepoURL(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NormalizeBitbucketRepo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := map[string]struct {
|
||||
Repo string
|
||||
Out string
|
||||
|
@ -69,6 +73,8 @@ func Test_NormalizeBitbucketRepo(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_NormalizeGitlabRepo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := map[string]struct {
|
||||
Repo string
|
||||
Out string
|
||||
|
@ -93,3 +99,73 @@ func Test_NormalizeGitlabRepo(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
repo string
|
||||
commit string
|
||||
file string
|
||||
line int64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "github link gen",
|
||||
args: args{
|
||||
repo: "https://github.com/trufflesec-julian/confluence-go-api.git",
|
||||
commit: "047b4a2ba42fc5b6c0bd535c5307434a666db5ec",
|
||||
file: ".gitignore",
|
||||
},
|
||||
want: "https://github.com/trufflesec-julian/confluence-go-api/blob/047b4a2ba42fc5b6c0bd535c5307434a666db5ec/.gitignore",
|
||||
},
|
||||
{
|
||||
name: "github link gen with line",
|
||||
args: args{
|
||||
repo: "https://github.com/trufflesec-julian/confluence-go-api.git",
|
||||
commit: "047b4a2ba42fc5b6c0bd535c5307434a666db5ec",
|
||||
file: ".gitignore",
|
||||
line: int64(4),
|
||||
},
|
||||
want: "https://github.com/trufflesec-julian/confluence-go-api/blob/047b4a2ba42fc5b6c0bd535c5307434a666db5ec/.gitignore#L4",
|
||||
},
|
||||
{
|
||||
name: "github link gen - no file",
|
||||
args: args{
|
||||
repo: "https://github.com/trufflesec-julian/confluence-go-api.git",
|
||||
commit: "047b4a2ba42fc5b6c0bd535c5307434a666db5ec",
|
||||
},
|
||||
want: "https://github.com/trufflesec-julian/confluence-go-api/commit/047b4a2ba42fc5b6c0bd535c5307434a666db5ec",
|
||||
},
|
||||
{
|
||||
name: "Azure link gen",
|
||||
args: args{
|
||||
repo: "https://dev.azure.com/org/project/_git/repo",
|
||||
commit: "abcdef",
|
||||
file: "main.go",
|
||||
},
|
||||
want: "https://dev.azure.com/org/project/_git/repo?path=main.go&version=GBabcdef",
|
||||
},
|
||||
{
|
||||
name: "Azure link gen with line",
|
||||
args: args{
|
||||
repo: "https://dev.azure.com/org/project/_git/repo",
|
||||
commit: "abcdef",
|
||||
file: "main.go",
|
||||
line: int64(20),
|
||||
},
|
||||
want: "https://dev.azure.com/org/project/_git/repo?path=main.go&version=GBabcdef&line=20",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := GenerateLink(tt.args.repo, tt.args.commit, tt.args.file, tt.args.line); got != tt.want {
|
||||
t.Errorf("generateLink() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,6 +31,8 @@ var (
|
|||
maxDepth = 5
|
||||
maxSize = 250 * 1024 * 1024 // 20MB
|
||||
maxTimeout = time.Duration(30) * time.Second
|
||||
|
||||
defaultBufferSize = 512
|
||||
)
|
||||
|
||||
// Ensure the Archive satisfies the interfaces at compile time.
|
||||
|
@ -85,7 +87,7 @@ func SetArchiveMaxTimeout(timeout time.Duration) {
|
|||
|
||||
// FromFile extracts the files from an archive.
|
||||
func (a *Archive) FromFile(originalCtx context.Context, data io.Reader) chan []byte {
|
||||
archiveChan := make(chan []byte, 512)
|
||||
archiveChan := make(chan []byte, defaultBufferSize)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(originalCtx, maxTimeout)
|
||||
logger := logContext.AddLogger(ctx).Logger()
|
||||
|
@ -206,29 +208,28 @@ func (a *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte,
|
|||
logger.Error(err, "Panic occurred when reading archive")
|
||||
}
|
||||
}()
|
||||
fileContent := bytes.Buffer{}
|
||||
logger.V(5).Info("Remaining buffer capacity", "bytes", maxSize-a.size)
|
||||
for i := 0; i <= maxSize/512; i++ {
|
||||
|
||||
if common.IsDone(ctx) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
fileChunk := make([]byte, 512)
|
||||
bRead, err := reader.Read(fileChunk)
|
||||
|
||||
var fileContent bytes.Buffer
|
||||
// Create a limited reader to ensure we don't read more than the max size.
|
||||
lr := io.LimitReader(reader, int64(maxSize))
|
||||
|
||||
// Using io.CopyBuffer for performance advantages. Though buf is mandatory
|
||||
// for the method, due to the internal implementation of io.CopyBuffer, when
|
||||
// *bytes.Buffer implements io.WriterTo or io.ReaderFrom, the provided buf
|
||||
// is simply ignored. Thus, we can pass nil for the buf parameter.
|
||||
_, err = io.CopyBuffer(&fileContent, lr, nil)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return []byte{}, err
|
||||
return nil, err
|
||||
}
|
||||
a.size += bRead
|
||||
if len(fileChunk) > 0 {
|
||||
fileContent.Write(fileChunk[0:bRead])
|
||||
}
|
||||
if bRead < 512 {
|
||||
return fileContent.Bytes(), nil
|
||||
}
|
||||
if a.size >= maxSize && bRead == 512 {
|
||||
|
||||
if fileContent.Len() == maxSize {
|
||||
logger.V(2).Info("Max archive size reached.")
|
||||
return fileContent.Bytes(), nil
|
||||
}
|
||||
}
|
||||
|
||||
return fileContent.Bytes(), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -127,6 +128,58 @@ func TestHandleFile(t *testing.T) {
|
|||
assert.Equal(t, 1, len(ch))
|
||||
}
|
||||
|
||||
func TestReadToMax(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "read full content within maxSize",
|
||||
input: []byte("abcdefg"),
|
||||
expected: []byte("abcdefg"),
|
||||
},
|
||||
{
|
||||
name: "read content larger than maxSize",
|
||||
input: make([]byte, maxSize+10), // this creates a byte slice 10 bytes larger than maxSize
|
||||
expected: make([]byte, maxSize),
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte(""),
|
||||
expected: []byte(""),
|
||||
},
|
||||
}
|
||||
|
||||
a := &Archive{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := bytes.NewReader(tt.input)
|
||||
output, err := a.ReadToMax(context.Background(), reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, tt.expected, output)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReadToMax(b *testing.B) {
|
||||
data := bytes.Repeat([]byte("a"), 1024*1000) // 1MB of data.
|
||||
reader := bytes.NewReader(data)
|
||||
a := &Archive{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StartTimer()
|
||||
_, _ = a.ReadToMax(context.Background(), reader)
|
||||
b.StopTimer()
|
||||
|
||||
_, _ = reader.Seek(0, 0) // Reset the reader position.
|
||||
a.size = 0 // Reset archive size.
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDebContent(t *testing.T) {
|
||||
// Open the sample .deb file from the testdata folder.
|
||||
file, err := os.Open("testdata/test.deb")
|
||||
|
|
|
@ -133,7 +133,7 @@ func BranchHeads(repo *gogit.Repository) (map[string]*object.Commit, error) {
|
|||
}
|
||||
headCommit, err := repo.CommitObject(*headHash)
|
||||
if err != nil {
|
||||
logger.Error(err, "unable to get commit", "commit", headCommit.String())
|
||||
logger.Error(err, "unable to get commit", "head_hash", headHash.String())
|
||||
return nil
|
||||
}
|
||||
branches[branchName] = headCommit
|
||||
|
|
|
@ -20,7 +20,7 @@ const (
|
|||
|
||||
// Chunker takes a chunk and splits it into chunks of ChunkSize.
|
||||
func Chunker(originalChunk *Chunk) chan *Chunk {
|
||||
chunkChan := make(chan *Chunk)
|
||||
chunkChan := make(chan *Chunk, 1)
|
||||
go func() {
|
||||
defer close(chunkChan)
|
||||
if len(originalChunk.Data) <= TotalChunkSize {
|
||||
|
|
|
@ -238,6 +238,7 @@ func (s *Source) chunkAction(ctx context.Context, proj project, bld build, act a
|
|||
SourceType: s.Type(),
|
||||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
JobID: s.JobID(),
|
||||
Data: removeCircleSha1Line(data.Bytes()),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Circleci{
|
||||
|
|
|
@ -119,6 +119,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
|
|||
SourceType: s.Type(),
|
||||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
JobID: s.JobID(),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Docker{
|
||||
Docker: &source_metadatapb.Docker{
|
||||
|
|
|
@ -161,6 +161,7 @@ func (s *Source) scanFile(ctx context.Context, path string, chunksChan chan *sou
|
|||
SourceType: s.Type(),
|
||||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
JobID: s.JobID(),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Filesystem{
|
||||
Filesystem: &source_metadatapb.Filesystem{
|
||||
|
@ -191,6 +192,7 @@ func (s *Source) scanFile(ctx context.Context, path string, chunksChan chan *sou
|
|||
SourceType: s.Type(),
|
||||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
JobID: s.JobID(),
|
||||
Data: data.Bytes(),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Filesystem{
|
||||
|
|
|
@ -324,6 +324,7 @@ func (s *Source) processObject(ctx context.Context, o object) error {
|
|||
chunkSkel := &sources.Chunk{
|
||||
SourceName: s.name,
|
||||
SourceType: s.Type(),
|
||||
JobID: s.JobID(),
|
||||
SourceID: s.sourceId,
|
||||
Verify: s.verify,
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -447,6 +446,7 @@ func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string
|
|||
chunkSkel := &sources.Chunk{
|
||||
SourceName: s.sourceName,
|
||||
SourceID: s.sourceID,
|
||||
JobID: s.jobID,
|
||||
SourceType: s.sourceType,
|
||||
SourceMetadata: metadata,
|
||||
Verify: s.verify,
|
||||
|
@ -465,6 +465,7 @@ func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string
|
|||
chunksChan <- &sources.Chunk{
|
||||
SourceName: s.sourceName,
|
||||
SourceID: s.sourceID,
|
||||
JobID: s.jobID,
|
||||
SourceType: s.sourceType,
|
||||
SourceMetadata: metadata,
|
||||
Data: diff.Content.Bytes(),
|
||||
|
@ -491,6 +492,7 @@ func (s *Git) gitChunk(ctx context.Context, diff gitparse.Diff, fileName, email,
|
|||
chunksChan <- &sources.Chunk{
|
||||
SourceName: s.sourceName,
|
||||
SourceID: s.sourceID,
|
||||
JobID: s.jobID,
|
||||
SourceType: s.sourceType,
|
||||
SourceMetadata: metadata,
|
||||
Data: append([]byte{}, newChunkBuffer.Bytes()...),
|
||||
|
@ -505,6 +507,7 @@ func (s *Git) gitChunk(ctx context.Context, diff gitparse.Diff, fileName, email,
|
|||
chunksChan <- &sources.Chunk{
|
||||
SourceName: s.sourceName,
|
||||
SourceID: s.sourceID,
|
||||
JobID: s.jobID,
|
||||
SourceType: s.sourceType,
|
||||
SourceMetadata: metadata,
|
||||
Data: line,
|
||||
|
@ -524,6 +527,7 @@ func (s *Git) gitChunk(ctx context.Context, diff gitparse.Diff, fileName, email,
|
|||
chunksChan <- &sources.Chunk{
|
||||
SourceName: s.sourceName,
|
||||
SourceID: s.sourceID,
|
||||
JobID: s.jobID,
|
||||
SourceType: s.sourceType,
|
||||
SourceMetadata: metadata,
|
||||
Data: append([]byte{}, newChunkBuffer.Bytes()...),
|
||||
|
@ -589,6 +593,7 @@ func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string,
|
|||
chunkSkel := &sources.Chunk{
|
||||
SourceName: s.sourceName,
|
||||
SourceID: s.sourceID,
|
||||
JobID: s.jobID,
|
||||
SourceType: s.sourceType,
|
||||
SourceMetadata: metadata,
|
||||
Verify: s.verify,
|
||||
|
@ -603,6 +608,7 @@ func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string,
|
|||
chunksChan <- &sources.Chunk{
|
||||
SourceName: s.sourceName,
|
||||
SourceID: s.sourceID,
|
||||
JobID: s.jobID,
|
||||
SourceType: s.sourceType,
|
||||
SourceMetadata: metadata,
|
||||
Data: diff.Content.Bytes(),
|
||||
|
@ -693,28 +699,6 @@ func normalizeConfig(scanOptions *ScanOptions, repo *git.Repository) (err error)
|
|||
return nil
|
||||
}
|
||||
|
||||
// GenerateLink crafts a link to the specific file from a commit. This works in most major git providers (Github/Gitlab)
|
||||
func GenerateLink(repo, commit, file string, line int64) string {
|
||||
// bitbucket links are commits not commit...
|
||||
if strings.Contains(repo, "bitbucket.org/") {
|
||||
return repo[:len(repo)-4] + "/commits/" + commit
|
||||
}
|
||||
var link string
|
||||
if file == "" {
|
||||
link = repo[:len(repo)-4] + "/commit/" + commit
|
||||
} else {
|
||||
link = repo[:len(repo)-4] + "/blob/" + commit + "/" + file
|
||||
|
||||
// Both GitHub and Gitlab support hyperlinking to a specific line with #L<number>, e.g.:
|
||||
// https://github.com/trufflesecurity/trufflehog/blob/e856a6890d0da5a218f4f9283500b80043884641/go.mod#L169
|
||||
// https://gitlab.com/pdftk-java/pdftk/-/blob/88559a08f34175b6fae76c40a88f0377f64a12d7/java/com/gitlab/pdftk_java/report.java#L893
|
||||
if line > 0 && (strings.Contains(repo, "github") || strings.Contains(repo, "gitlab")) {
|
||||
link += "#L" + strconv.FormatInt(line, 10)
|
||||
}
|
||||
}
|
||||
return link
|
||||
}
|
||||
|
||||
func stripPassword(u string) (string, error) {
|
||||
if strings.HasPrefix(u, "git@") {
|
||||
return u, nil
|
||||
|
|
|
@ -151,55 +151,6 @@ func TestSource_Scan(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_generateLink(t *testing.T) {
|
||||
type args struct {
|
||||
repo string
|
||||
commit string
|
||||
file string
|
||||
line int64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "test link gen",
|
||||
args: args{
|
||||
repo: "https://github.com/trufflesec-julian/confluence-go-api.git",
|
||||
commit: "047b4a2ba42fc5b6c0bd535c5307434a666db5ec",
|
||||
file: ".gitignore",
|
||||
},
|
||||
want: "https://github.com/trufflesec-julian/confluence-go-api/blob/047b4a2ba42fc5b6c0bd535c5307434a666db5ec/.gitignore",
|
||||
},
|
||||
{
|
||||
name: "test link gen",
|
||||
args: args{
|
||||
repo: "https://github.com/trufflesec-julian/confluence-go-api.git",
|
||||
commit: "047b4a2ba42fc5b6c0bd535c5307434a666db5ec",
|
||||
file: ".gitignore",
|
||||
line: int64(4),
|
||||
},
|
||||
want: "https://github.com/trufflesec-julian/confluence-go-api/blob/047b4a2ba42fc5b6c0bd535c5307434a666db5ec/.gitignore#L4",
|
||||
},
|
||||
{
|
||||
name: "test link gen - no file",
|
||||
args: args{
|
||||
repo: "https://github.com/trufflesec-julian/confluence-go-api.git",
|
||||
commit: "047b4a2ba42fc5b6c0bd535c5307434a666db5ec",
|
||||
},
|
||||
want: "https://github.com/trufflesec-julian/confluence-go-api/commit/047b4a2ba42fc5b6c0bd535c5307434a666db5ec",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := GenerateLink(tt.args.repo, tt.args.commit, tt.args.file, tt.args.line); got != tt.want {
|
||||
t.Errorf("generateLink() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// We ran into an issue where upgrading a dependency caused the git patch chunking to break
|
||||
// So this test exists to make sure that when something changes, we know about it.
|
||||
func TestSource_Chunks_Integration(t *testing.T) {
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/trufflesecurity/trufflehog/v3/pkg/cache/memory"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
||||
|
@ -148,14 +149,16 @@ func (s *Source) newFilteredRepoCache(c cache.Cache, include, exclude []string)
|
|||
for _, ig := range include {
|
||||
g, err := glob.Compile(ig)
|
||||
if err != nil {
|
||||
s.log.V(1).Info("invalid include glob", "glob", g, "err", err)
|
||||
s.log.V(1).Info("invalid include glob", "include_value", ig, "err", err)
|
||||
continue
|
||||
}
|
||||
includeGlobs = append(includeGlobs, g)
|
||||
}
|
||||
for _, eg := range exclude {
|
||||
g, err := glob.Compile(eg)
|
||||
if err != nil {
|
||||
s.log.V(1).Info("invalid exclude glob", "glob", g, "err", err)
|
||||
s.log.V(1).Info("invalid exclude glob", "exclude_value", eg, "err", err)
|
||||
continue
|
||||
}
|
||||
excludeGlobs = append(excludeGlobs, g)
|
||||
}
|
||||
|
@ -261,7 +264,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64,
|
|||
File: sanitizer.UTF8(file),
|
||||
Email: sanitizer.UTF8(email),
|
||||
Repository: sanitizer.UTF8(repository),
|
||||
Link: git.GenerateLink(repository, commit, file, line),
|
||||
Link: giturl.GenerateLink(repository, commit, file, line),
|
||||
Timestamp: sanitizer.UTF8(timestamp),
|
||||
Line: line,
|
||||
Visibility: s.visibilityOf(aCtx, repository),
|
||||
|
@ -1222,6 +1225,7 @@ func (s *Source) chunkIssueComments(ctx context.Context, repo, repoPath string,
|
|||
chunk := &sources.Chunk{
|
||||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
JobID: s.JobID(),
|
||||
SourceType: s.Type(),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Github{
|
||||
|
@ -1255,6 +1259,7 @@ func (s *Source) chunkPullRequestComments(ctx context.Context, repo string, comm
|
|||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
SourceType: s.Type(),
|
||||
JobID: s.JobID(),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Github{
|
||||
Github: &source_metadatapb.Github{
|
||||
|
@ -1286,6 +1291,7 @@ func (s *Source) chunkGistComments(ctx context.Context, gistUrl string, comments
|
|||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
SourceType: s.Type(),
|
||||
JobID: s.JobID(),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_Github{
|
||||
Github: &source_metadatapb.Github{
|
||||
|
|
|
@ -127,7 +127,7 @@ func (s *Source) Init(_ context.Context, name string, jobId, sourceId int64, ver
|
|||
File: sanitizer.UTF8(file),
|
||||
Email: sanitizer.UTF8(email),
|
||||
Repository: sanitizer.UTF8(repository),
|
||||
Link: git.GenerateLink(repository, commit, file, line),
|
||||
Link: giturl.GenerateLink(repository, commit, file, line),
|
||||
Timestamp: sanitizer.UTF8(timestamp),
|
||||
Line: line,
|
||||
},
|
||||
|
|
|
@ -40,9 +40,12 @@ type JobProgressHook interface {
|
|||
}
|
||||
|
||||
// JobProgressRef is a wrapper of a JobProgress for read-only access to its state.
|
||||
// If the job supports it, the reference can also be used to cancel running via
|
||||
// CancelRun.
|
||||
type JobProgressRef struct {
|
||||
SourceID int64
|
||||
JobID int64
|
||||
SourceID int64
|
||||
SourceName string
|
||||
jobProgress *JobProgress
|
||||
}
|
||||
|
||||
|
@ -65,6 +68,16 @@ func (r *JobProgressRef) Done() <-chan struct{} {
|
|||
return r.jobProgress.Done()
|
||||
}
|
||||
|
||||
// CancelRun requests that the job this is referencing is cancelled and stops
|
||||
// running. This method will have no effect if the job does not allow
|
||||
// cancellation.
|
||||
func (r *JobProgressRef) CancelRun(cause error) {
|
||||
if r.jobProgress == nil || r.jobProgress.jobCancel == nil {
|
||||
return
|
||||
}
|
||||
r.jobProgress.jobCancel(cause)
|
||||
}
|
||||
|
||||
// Fatal is a wrapper around error to differentiate non-fatal errors from fatal
|
||||
// ones. A fatal error is typically from a finished context or any error
|
||||
// returned from a source's Init, Chunks, Enumerate, or ChunkUnit methods.
|
||||
|
@ -88,14 +101,19 @@ func (f ChunkError) Unwrap() error { return f.err }
|
|||
// JobProgress aggregates information about a run of a Source.
|
||||
type JobProgress struct {
|
||||
// Unique identifiers for this job.
|
||||
SourceID int64
|
||||
JobID int64
|
||||
SourceID int64
|
||||
SourceName string
|
||||
// Tracks whether the job is finished or not.
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Requests to cancel the job.
|
||||
jobCancel context.CancelCauseFunc
|
||||
// Metrics.
|
||||
metrics JobProgressMetrics
|
||||
metricsLock sync.Mutex
|
||||
// Progress reported by the source.
|
||||
progress *Progress
|
||||
// Coarse grained hooks for adding extra functionality when events trigger.
|
||||
hooks []JobProgressHook
|
||||
}
|
||||
|
@ -111,8 +129,19 @@ type JobProgressMetrics struct {
|
|||
// Total number of chunks produced. This metric updates before the
|
||||
// chunk is sent on the output channel.
|
||||
TotalChunks uint64
|
||||
// All errors encountered.
|
||||
Errors []error
|
||||
// Set to true if the source supports enumeration and has finished
|
||||
// enumerating. If the source does not support enumeration, this field
|
||||
// is always false.
|
||||
DoneEnumerating bool
|
||||
|
||||
// Progress information reported by the source.
|
||||
SourcePercent int64
|
||||
SourceMessage string
|
||||
SourceEncodedResumeInfo string
|
||||
SourceSectionsCompleted int32
|
||||
SourceSectionsRemaining int32
|
||||
}
|
||||
|
||||
// WithHooks adds hooks to be called when an event triggers.
|
||||
|
@ -120,12 +149,18 @@ func WithHooks(hooks ...JobProgressHook) func(*JobProgress) {
|
|||
return func(jp *JobProgress) { jp.hooks = append(jp.hooks, hooks...) }
|
||||
}
|
||||
|
||||
// WithCancel allows cancelling the job by the JobProgressRef.
|
||||
func WithCancel(cancel context.CancelCauseFunc) func(*JobProgress) {
|
||||
return func(jp *JobProgress) { jp.jobCancel = cancel }
|
||||
}
|
||||
|
||||
// NewJobProgress creates a new job report for the given source and job ID.
|
||||
func NewJobProgress(sourceID, jobID int64, opts ...func(*JobProgress)) *JobProgress {
|
||||
func NewJobProgress(jobID, sourceID int64, sourceName string, opts ...func(*JobProgress)) *JobProgress {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
jp := &JobProgress{
|
||||
SourceID: sourceID,
|
||||
JobID: jobID,
|
||||
SourceID: sourceID,
|
||||
SourceName: sourceName,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
@ -135,6 +170,12 @@ func NewJobProgress(sourceID, jobID int64, opts ...func(*JobProgress)) *JobProgr
|
|||
return jp
|
||||
}
|
||||
|
||||
// TrackProgress informs the JobProgress of a Progress object and safely
|
||||
// exposes its information in the Snapshots.
|
||||
func (jp *JobProgress) TrackProgress(progress *Progress) {
|
||||
jp.progress = progress
|
||||
}
|
||||
|
||||
// executeHooks is a helper method to execute all the hooks for the given
|
||||
// closure.
|
||||
func (jp *JobProgress) executeHooks(todo func(hook JobProgressHook)) {
|
||||
|
@ -210,6 +251,16 @@ func (jp *JobProgress) Snapshot() JobProgressMetrics {
|
|||
metrics.Errors = make([]error, len(metrics.Errors))
|
||||
copy(metrics.Errors, jp.metrics.Errors)
|
||||
|
||||
if jp.progress != nil {
|
||||
jp.progress.mut.Lock()
|
||||
defer jp.progress.mut.Unlock()
|
||||
metrics.SourcePercent = jp.progress.PercentComplete
|
||||
metrics.SourceMessage = jp.progress.Message
|
||||
metrics.SourceEncodedResumeInfo = jp.progress.EncodedResumeInfo
|
||||
metrics.SourceSectionsCompleted = jp.progress.SectionsCompleted
|
||||
metrics.SourceSectionsRemaining = jp.progress.SectionsRemaining
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
|
@ -231,6 +282,7 @@ func (jp *JobProgress) Ref() JobProgressRef {
|
|||
return JobProgressRef{
|
||||
SourceID: jp.SourceID,
|
||||
JobID: jp.JobID,
|
||||
SourceName: jp.SourceName,
|
||||
jobProgress: jp,
|
||||
}
|
||||
}
|
||||
|
@ -280,7 +332,22 @@ func (m JobProgressMetrics) PercentComplete() int {
|
|||
num := m.FinishedUnits
|
||||
den := m.TotalUnits
|
||||
if num == 0 && den == 0 {
|
||||
return 0
|
||||
// Fallback to the source's self-reported percent complete if
|
||||
// the unit information isn't available.
|
||||
return int(m.SourcePercent)
|
||||
}
|
||||
return int(num * 100 / den)
|
||||
}
|
||||
|
||||
// ElapsedTime is a convenience method that provides the elapsed time the job
|
||||
// has been running. If it hasn't started yet, 0 is returned. If it has
|
||||
// finished, the total time is returned.
|
||||
func (m JobProgressMetrics) ElapsedTime() time.Duration {
|
||||
if m.StartTime.IsZero() {
|
||||
return 0
|
||||
}
|
||||
if m.EndTime.IsZero() {
|
||||
return time.Since(m.StartTime)
|
||||
}
|
||||
return m.EndTime.Sub(m.StartTime)
|
||||
}
|
||||
|
|
|
@ -36,10 +36,10 @@ func TestJobProgressFatalErrors(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestJobProgressRef(t *testing.T) {
|
||||
jp := NewJobProgress(123, 456)
|
||||
jp := NewJobProgress(123, 456, "source name")
|
||||
ref := jp.Ref()
|
||||
assert.Equal(t, int64(123), ref.SourceID)
|
||||
assert.Equal(t, int64(456), ref.JobID)
|
||||
assert.Equal(t, int64(123), ref.JobID)
|
||||
assert.Equal(t, int64(456), ref.SourceID)
|
||||
|
||||
// Test Done() blocks until Finish() is called.
|
||||
select {
|
||||
|
@ -61,7 +61,7 @@ func TestJobProgressHook(t *testing.T) {
|
|||
defer ctrl.Finish()
|
||||
|
||||
hook := NewMockJobProgressHook(ctrl)
|
||||
jp := NewJobProgress(123, 456, WithHooks(hook))
|
||||
jp := NewJobProgress(123, 456, "source name", WithHooks(hook))
|
||||
|
||||
// Start(JobProgressRef, time.Time)
|
||||
// End(JobProgressRef, time.Time)
|
||||
|
@ -115,3 +115,14 @@ func TestJobProgressDone(t *testing.T) {
|
|||
assert.FailNow(t, "done should not block for a nil job")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobProgressElapsedTime(t *testing.T) {
|
||||
metrics := JobProgressMetrics{}
|
||||
assert.Equal(t, time.Duration(0), metrics.ElapsedTime())
|
||||
|
||||
metrics.StartTime = time.Now()
|
||||
assert.Greater(t, metrics.ElapsedTime(), time.Duration(0))
|
||||
|
||||
metrics.EndTime = metrics.StartTime.Add(1 * time.Hour)
|
||||
assert.Equal(t, metrics.ElapsedTime(), 1*time.Hour)
|
||||
}
|
||||
|
|
|
@ -376,6 +376,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
|
|||
SourceType: s.Type(),
|
||||
SourceName: s.name,
|
||||
SourceID: s.SourceID(),
|
||||
JobID: s.JobID(),
|
||||
SourceMetadata: &source_metadatapb.MetaData{
|
||||
Data: &source_metadatapb.MetaData_S3{
|
||||
S3: &source_metadatapb.S3{
|
||||
|
|
|
@ -10,13 +10,12 @@ import (
|
|||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
func TestSource_Chunks(t *testing.T) {
|
||||
|
@ -35,6 +34,7 @@ func TestSource_Chunks(t *testing.T) {
|
|||
name string
|
||||
verify bool
|
||||
connection *sourcespb.S3
|
||||
setEnv map[string]string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -58,6 +58,23 @@ func TestSource_Chunks(t *testing.T) {
|
|||
wantErr: false,
|
||||
wantChunkData: `W2RlZmF1bHRdCmF3c19hY2Nlc3Nfa2V5X2lkID0gQUtJQTM1T0hYMkRTT1pHNjQ3TkgKYXdzX3NlY3JldF9hY2Nlc3Nfa2V5ID0gUXk5OVMrWkIvQ1dsRk50eFBBaWQ3Z0d6dnNyWGhCQjd1ckFDQUxwWgpvdXRwdXQgPSBqc29uCnJlZ2lvbiA9IHVzLWVhc3QtMg==`,
|
||||
},
|
||||
{
|
||||
name: "gets chunks after assuming role",
|
||||
// This test will attempt to scan every bucket in the account, but the role policy blocks access to every
|
||||
// bucket except the one we want. This (expected behavior) causes errors in the test log output, but these
|
||||
// errors shouldn't actually cause test failures.
|
||||
init: init{
|
||||
connection: &sourcespb.S3{
|
||||
Roles: []string{"arn:aws:iam::619888638459:role/s3-test-assume-role"},
|
||||
},
|
||||
setEnv: map[string]string{
|
||||
"AWS_ACCESS_KEY_ID": s3key,
|
||||
"AWS_SECRET_ACCESS_KEY": s3secret,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantChunkData: `W2RlZmF1bHRdCmF3c19zZWNyZXRfYWNjZXNzX2tleSA9IFF5OTlTK1pCL0NXbEZOdHhQQWlkN2dHenZzclhoQkI3dXJBQ0FMcFoKYXdzX2FjY2Vzc19rZXlfaWQgPSBBS0lBMzVPSFgyRFNPWkc2NDdOSApvdXRwdXQgPSBqc29uCnJlZ2lvbiA9IHVzLWVhc3QtMg==`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -65,6 +82,10 @@ func TestSource_Chunks(t *testing.T) {
|
|||
var cancelOnce sync.Once
|
||||
defer cancelOnce.Do(cancel)
|
||||
|
||||
for k, v := range tt.init.setEnv {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
s := Source{}
|
||||
conn, err := anypb.New(tt.init.connection)
|
||||
if err != nil {
|
||||
|
|
|
@ -37,6 +37,9 @@ type SourceManager struct {
|
|||
handlesLock sync.Mutex
|
||||
// Pool limiting the amount of concurrent sources running.
|
||||
pool errgroup.Group
|
||||
poolLimit int
|
||||
currentRunningCount int32
|
||||
// Max number of units to scan concurrently per source.
|
||||
concurrentUnits int
|
||||
// Run the sources using source unit enumeration / chunking if available.
|
||||
useSourceUnits bool
|
||||
|
@ -68,7 +71,10 @@ func WithReportHook(hook JobProgressHook) func(*SourceManager) {
|
|||
|
||||
// WithConcurrentSources limits the concurrent number of sources a manager can run.
|
||||
func WithConcurrentSources(concurrency int) func(*SourceManager) {
|
||||
return func(mgr *SourceManager) { mgr.pool.SetLimit(concurrency) }
|
||||
return func(mgr *SourceManager) {
|
||||
mgr.pool.SetLimit(concurrency)
|
||||
mgr.poolLimit = concurrency
|
||||
}
|
||||
}
|
||||
|
||||
// WithBufferedOutput sets the size of the buffer used for the Chunks() channel.
|
||||
|
@ -151,20 +157,30 @@ func (s *SourceManager) asyncRun(ctx context.Context, handle handle) (JobProgres
|
|||
if err := s.preflightChecks(ctx, handle); err != nil {
|
||||
return JobProgressRef{}, err
|
||||
}
|
||||
// Get the name. Should never fail due to preflight checks.
|
||||
sourceInfo, ok := s.getSourceInfo(handle)
|
||||
if !ok {
|
||||
return JobProgressRef{SourceID: int64(handle)}, fmt.Errorf("unrecognized handle")
|
||||
}
|
||||
sourceName := sourceInfo.name
|
||||
// Get a Job ID.
|
||||
ctx = context.WithValue(ctx, "source_id", int64(handle))
|
||||
jobID, err := s.api.GetJobID(ctx, int64(handle))
|
||||
if err != nil {
|
||||
return JobProgressRef{SourceID: int64(handle)}, err
|
||||
return JobProgressRef{SourceID: int64(handle), SourceName: sourceName}, err
|
||||
}
|
||||
// Create a JobProgress object for tracking progress.
|
||||
progress := NewJobProgress(int64(handle), jobID, WithHooks(s.hooks...))
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
progress := NewJobProgress(jobID, int64(handle), sourceName, WithHooks(s.hooks...), WithCancel(cancel))
|
||||
s.pool.Go(func() error {
|
||||
atomic.AddInt32(&s.currentRunningCount, 1)
|
||||
defer atomic.AddInt32(&s.currentRunningCount, -1)
|
||||
ctx := context.WithValues(ctx,
|
||||
"job_id", jobID,
|
||||
"source_manager_worker_id", common.RandomID(5),
|
||||
)
|
||||
defer common.Recover(ctx)
|
||||
defer cancel(nil)
|
||||
return s.run(ctx, handle, jobID, progress)
|
||||
})
|
||||
return progress.Ref(), nil
|
||||
|
@ -199,6 +215,16 @@ func (s *SourceManager) ScanChunk(chunk *Chunk) {
|
|||
s.outputChunks <- chunk
|
||||
}
|
||||
|
||||
// AvailableCapacity returns the number of concurrent jobs the manager can
|
||||
// accommodate at this time. If there is no limit, -1 is returned.
|
||||
func (s *SourceManager) AvailableCapacity() int {
|
||||
if s.poolLimit == 0 {
|
||||
return -1
|
||||
}
|
||||
runCount := atomic.LoadInt32(&s.currentRunningCount)
|
||||
return s.poolLimit - int(runCount)
|
||||
}
|
||||
|
||||
// preflightChecks is a helper method to check the Manager or the context isn't
|
||||
// done and that the handle is valid.
|
||||
func (s *SourceManager) preflightChecks(ctx context.Context, handle handle) error {
|
||||
|
@ -221,6 +247,12 @@ func (s *SourceManager) run(ctx context.Context, handle handle, jobID int64, rep
|
|||
report.Start(time.Now())
|
||||
defer func() { report.End(time.Now()) }()
|
||||
|
||||
defer func() {
|
||||
if err := context.Cause(ctx); err != nil {
|
||||
report.ReportError(Fatal{err})
|
||||
}
|
||||
}()
|
||||
|
||||
// Initialize the source.
|
||||
sourceInfo, ok := s.getSourceInfo(handle)
|
||||
if !ok {
|
||||
|
@ -234,6 +266,7 @@ func (s *SourceManager) run(ctx context.Context, handle handle, jobID int64, rep
|
|||
report.ReportError(Fatal{err})
|
||||
return Fatal{err}
|
||||
}
|
||||
report.TrackProgress(source.GetProgress())
|
||||
ctx = context.WithValues(ctx,
|
||||
"source_type", source.Type().String(),
|
||||
"source_name", sourceInfo.name,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package sources
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
|
@ -280,7 +281,62 @@ func TestSourceManagerJobAndSourceIDs(t *testing.T) {
|
|||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = mgr.Run(context.Background(), handle)
|
||||
ref, _ := mgr.Run(context.Background(), handle)
|
||||
assert.Equal(t, int64(1337), initializedSourceID)
|
||||
assert.Equal(t, int64(1337), ref.SourceID)
|
||||
assert.Equal(t, int64(9001), initializedJobID)
|
||||
assert.Equal(t, int64(9001), ref.JobID)
|
||||
assert.Equal(t, "dummy", ref.SourceName)
|
||||
}
|
||||
|
||||
// Chunk method that has a custom callback for the Chunks method.
|
||||
type callbackChunker struct {
|
||||
cb func(context.Context, chan *Chunk) error
|
||||
}
|
||||
|
||||
func (c callbackChunker) Chunks(ctx context.Context, ch chan *Chunk) error { return c.cb(ctx, ch) }
|
||||
func (c callbackChunker) Enumerate(context.Context, UnitReporter) error { return nil }
|
||||
func (c callbackChunker) ChunkUnit(context.Context, SourceUnit, ChunkReporter) error { return nil }
|
||||
|
||||
func TestSourceManagerCancelRun(t *testing.T) {
|
||||
mgr := NewManager(WithBufferedOutput(8))
|
||||
var returnedErr error
|
||||
handle, err := enrollDummy(mgr, callbackChunker{func(ctx context.Context, _ chan *Chunk) error {
|
||||
// The context passed to Chunks should get cancelled when ref.CancelRun() is called.
|
||||
<-ctx.Done()
|
||||
returnedErr = fmt.Errorf("oh no: %w", ctx.Err())
|
||||
return returnedErr
|
||||
}})
|
||||
assert.NoError(t, err)
|
||||
|
||||
ref, err := mgr.ScheduleRun(context.Background(), handle)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cancelErr := fmt.Errorf("abort! abort!")
|
||||
ref.CancelRun(cancelErr)
|
||||
<-ref.Done()
|
||||
assert.Error(t, ref.Snapshot().FatalError())
|
||||
assert.True(t, errors.Is(ref.Snapshot().FatalError(), returnedErr))
|
||||
assert.True(t, errors.Is(ref.Snapshot().FatalErrors(), cancelErr))
|
||||
}
|
||||
|
||||
func TestSourceManagerAvailableCapacity(t *testing.T) {
|
||||
mgr := NewManager(WithConcurrentSources(1337))
|
||||
start, end := make(chan struct{}), make(chan struct{})
|
||||
handle, err := enrollDummy(mgr, callbackChunker{func(context.Context, chan *Chunk) error {
|
||||
start <- struct{}{} // Send start signal.
|
||||
<-end // Wait for end signal.
|
||||
return nil
|
||||
}})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1337, mgr.AvailableCapacity())
|
||||
ref, err := mgr.ScheduleRun(context.Background(), handle)
|
||||
assert.NoError(t, err)
|
||||
|
||||
<-start // Wait for start signal.
|
||||
assert.Equal(t, 1336, mgr.AvailableCapacity())
|
||||
end <- struct{}{} // Send end signal.
|
||||
<-ref.Done() // Wait for the job to finish.
|
||||
assert.Equal(t, 1337, mgr.AvailableCapacity())
|
||||
}
|
||||
|
|
|
@ -272,6 +272,7 @@ func (s *Source) monitorConnection(ctx context.Context, conn net.Conn, chunksCha
|
|||
SourceName: s.syslog.sourceName,
|
||||
SourceID: s.syslog.sourceID,
|
||||
SourceType: s.syslog.sourceType,
|
||||
JobID: s.JobID(),
|
||||
SourceMetadata: metadata,
|
||||
Data: input,
|
||||
Verify: s.verify,
|
||||
|
@ -313,6 +314,7 @@ func (s *Source) acceptUDPConnections(ctx context.Context, netListener net.Packe
|
|||
chunksChan <- &sources.Chunk{
|
||||
SourceName: s.syslog.sourceName,
|
||||
SourceID: s.syslog.sourceID,
|
||||
JobID: s.JobID(),
|
||||
SourceType: s.syslog.sourceType,
|
||||
SourceMetadata: metadata,
|
||||
Data: input,
|
||||
|
|
Loading…
Reference in a new issue