trufflehog/pkg/sources/filesystem/filesystem_test.go
Miccah c60443891b
Add Display method to SourceUnit and Kind member to the CommonSourceUnit (#2450)
* Add Display method to SourceUnit and Kind member to the CommonSourceUnit

* Make SourceUnitID return the ID and a kind

These two values together uniquely represent a unit.
2024-02-20 11:24:13 -08:00

339 lines
9 KiB
Go

package filesystem
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/go-logr/logr"
"github.com/kylelemons/godebug/pretty"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
"github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest"
)
// TestSource_Scan initializes a filesystem source over the package directory
// (".") and drains Chunks, checking that exactly one chunk for filesystem.go is
// emitted and that its source metadata matches the expected value.
func TestSource_Scan(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
	defer cancel()

	type initArgs struct {
		name       string
		verify     bool
		connection *sourcespb.Filesystem
	}
	testCases := []struct {
		name               string
		init               initArgs
		wantSourceMetadata *source_metadatapb.MetaData
		wantErr            bool
	}{
		{
			name: "get a chunk",
			init: initArgs{
				name:   "this repo",
				verify: true,
				connection: &sourcespb.Filesystem{
					Paths: []string{"."},
				},
			},
			wantSourceMetadata: &source_metadatapb.MetaData{
				Data: &source_metadatapb.MetaData_Filesystem{
					Filesystem: &source_metadatapb.Filesystem{
						File: "filesystem.go",
					},
				},
			},
			wantErr: false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var src Source
			conn, err := anypb.New(tc.init.connection)
			if err != nil {
				t.Fatal(err)
			}
			if err := src.Init(ctx, tc.init.name, 0, 0, tc.init.verify, conn, 5); (err != nil) != tc.wantErr {
				t.Errorf("Source.Init() error = %v, wantErr %v", err, tc.wantErr)
				return
			}

			chunks := make(chan *sources.Chunk, 1)
			// TODO: Chunks runs in its own goroutine, so an immediate failure
			// there is easy to miss; when debugging, set a breakpoint inside
			// this goroutine.
			go func() {
				defer close(chunks)
				if err := src.Chunks(ctx, chunks); (err != nil) != tc.wantErr {
					t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tc.wantErr)
				}
			}()

			matches := 0
			for chunk := range chunks {
				if chunk.SourceMetadata.GetFilesystem().GetFile() != "filesystem.go" {
					continue
				}
				matches++
				if diff := pretty.Compare(chunk.SourceMetadata, tc.wantSourceMetadata); diff != "" {
					t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tc.name, diff)
				}
			}
			assert.Equal(t, 1, matches)
		})
	}
}
func TestScanFile(t *testing.T) {
chunkSize := sources.ChunkSize
secretPart1 := "SECRET"
secretPart2 := "SPLIT"
// Split the secret into two parts and pad the rest of the chunk with A's.
data := strings.Repeat("A", chunkSize-len(secretPart1)) + secretPart1 + secretPart2 + strings.Repeat("A", chunkSize-len(secretPart2))
tmpfile, cleanup, err := createTempFile("", data)
assert.Nil(t, err)
defer cleanup()
source := &Source{}
chunksChan := make(chan *sources.Chunk, 2)
ctx := context.WithLogger(context.Background(), logr.Discard())
go func() {
defer close(chunksChan)
err = source.scanFile(ctx, tmpfile.Name(), chunksChan)
assert.Nil(t, err)
}()
// Read from the channel and validate the secrets.
foundSecret := ""
for chunkCh := range chunksChan {
foundSecret += string(chunkCh.Data)
}
assert.Contains(t, foundSecret, secretPart1+secretPart2)
}
func TestEnumerate(t *testing.T) {
// TODO: refactor to allow a virtual filesystem.
t.Parallel()
ctx := context.Background()
// Setup the connection to test enumeration.
dir, err := os.MkdirTemp("", "trufflehog-test-enumerate")
assert.NoError(t, err)
defer os.RemoveAll(dir)
units := []string{
"/one", "/two", "/three",
"/path/to/dir/", "/path/to/another/dir/",
}
// Prefix the units with the tempdir and create files on disk.
for i, unit := range units {
fullPath := filepath.Join(dir, unit)
units[i] = fullPath
if i < 3 {
f, err := os.Create(fullPath)
assert.NoError(t, err)
f.Close()
} else {
assert.NoError(t, os.MkdirAll(fullPath, 0755))
// Create a file in the directory for enumeration to find.
f, err := os.CreateTemp(fullPath, "file")
assert.NoError(t, err)
units[i] = f.Name()
f.Close()
}
}
conn, err := anypb.New(&sourcespb.Filesystem{
Paths: units[0:3],
Directories: units[3:],
})
assert.NoError(t, err)
// Initialize the source.
s := Source{}
err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1)
assert.NoError(t, err)
reporter := sourcestest.TestReporter{}
err = s.Enumerate(ctx, &reporter)
assert.NoError(t, err)
assert.Equal(t, len(units), len(reporter.Units))
assert.Equal(t, 0, len(reporter.UnitErrs))
for _, unit := range reporter.Units {
path, _ := unit.SourceUnitID()
assert.Contains(t, units, path)
}
for _, unit := range units {
assert.Contains(t, reporter.Units, sources.CommonSourceUnit{ID: unit})
}
}
// TestChunkUnit exercises ChunkUnit on a single file, on a directory holding
// three files, and on a missing path. It expects four chunks in total (one per
// file) and exactly one chunk error (for the missing path).
func TestChunkUnit(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	// Setup test file to chunk. Helper results are nil on error, so fail fast
	// instead of panicking on tmpfile.Name() or a deferred nil cleanup.
	fileContents := "TestChunkUnit"
	tmpfile, cleanupFile, err := createTempFile("", fileContents)
	if err != nil {
		t.Fatal(err)
	}
	defer cleanupFile()

	tmpdir, cleanupDir, err := createTempDir("", "foo", "bar", "baz")
	if err != nil {
		t.Fatal(err)
	}
	defer cleanupDir()

	conn, err := anypb.New(&sourcespb.Filesystem{})
	assert.NoError(t, err)

	// Initialize the source.
	s := Source{}
	err = s.Init(ctx, "test chunk unit", 0, 0, true, conn, 1)
	assert.NoError(t, err)

	// Happy path single file.
	reporter := sourcestest.TestReporter{}
	err = s.ChunkUnit(ctx, sources.CommonSourceUnit{
		ID: tmpfile.Name(),
	}, &reporter)
	assert.NoError(t, err)

	// Happy path directory.
	err = s.ChunkUnit(ctx, sources.CommonSourceUnit{
		ID: tmpdir,
	}, &reporter)
	assert.NoError(t, err)

	// Error path: the failure is delivered to the reporter, not returned.
	err = s.ChunkUnit(ctx, sources.CommonSourceUnit{
		ID: "/file/not/found",
	}, &reporter)
	assert.NoError(t, err)

	assert.Len(t, reporter.Chunks, 4)
	assert.Len(t, reporter.ChunkErrs, 1)
	dataFound := make(map[string]struct{}, 4)
	for _, chunk := range reporter.Chunks {
		dataFound[string(chunk.Data)] = struct{}{}
	}
	assert.Contains(t, dataFound, fileContents)
	assert.Contains(t, dataFound, "foo")
	assert.Contains(t, dataFound, "bar")
	assert.Contains(t, dataFound, "baz")
}
func TestEnumerateReporterErr(t *testing.T) {
t.Parallel()
ctx := context.Background()
// Setup the connection to test enumeration.
units := []string{
"/one", "/two", "/three",
"/path/to/dir/", "/path/to/another/dir/",
}
conn, err := anypb.New(&sourcespb.Filesystem{
Paths: units[0:3],
Directories: units[3:],
})
assert.NoError(t, err)
// Initialize the source.
s := Source{}
err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1)
assert.NoError(t, err)
// Enumerate should always return an error if the reporter returns an
// error.
reporter := sourcestest.ErrReporter{}
err = s.Enumerate(ctx, &reporter)
assert.Error(t, err)
}
func TestChunkUnitReporterErr(t *testing.T) {
t.Parallel()
ctx := context.Background()
// Setup test file to chunk.
tmpfile, err := os.CreateTemp("", "example.txt")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
fileContents := []byte("TestChunkUnit")
_, err = tmpfile.Write(fileContents)
assert.NoError(t, err)
assert.NoError(t, tmpfile.Close())
conn, err := anypb.New(&sourcespb.Filesystem{})
assert.NoError(t, err)
// Initialize the source.
s := Source{}
err = s.Init(ctx, "test chunk unit", 0, 0, true, conn, 1)
assert.NoError(t, err)
// Happy path. ChunkUnit should always return an error if the reporter
// returns an error.
reporter := sourcestest.ErrReporter{}
err = s.ChunkUnit(ctx, sources.CommonSourceUnit{
ID: tmpfile.Name(),
}, &reporter)
assert.Error(t, err)
// Error path. ChunkUnit should always return an error if the reporter
// returns an error.
err = s.ChunkUnit(ctx, sources.CommonSourceUnit{
ID: "/file/not/found",
}, &reporter)
assert.Error(t, err)
}
// createTempFile is a helper function to create a temporary file in the given
// directory with the provided contents. If dir is "", the operating system's
// temp directory is used. The file name starts with "trufflehogtest".
//
// The returned *os.File is already closed; use its Name method to locate the
// file. The returned func removes the file. On any error the partially
// created file is closed and removed, and nil values are returned.
func createTempFile(dir string, contents string) (*os.File, func(), error) {
	tmpfile, err := os.CreateTemp(dir, "trufflehogtest")
	if err != nil {
		return nil, nil, err
	}
	remove := func() { _ = os.Remove(tmpfile.Name()) }

	if _, err := tmpfile.WriteString(contents); err != nil {
		// Close before removing: otherwise the descriptor leaks, and on
		// Windows an open file generally cannot be deleted.
		_ = tmpfile.Close()
		remove()
		return nil, nil, err
	}
	if err := tmpfile.Close(); err != nil {
		remove()
		return nil, nil, err
	}
	return tmpfile, remove, nil
}
// createTempDir creates a fresh temporary directory inside dir (the operating
// system's temp directory when dir is "") and writes one temporary file into
// it for each element of contents. It returns the directory path and a func
// that removes the directory and everything in it. If any file cannot be
// created, the directory is removed and the error is returned.
func createTempDir(dir string, contents ...string) (string, func(), error) {
	root, err := os.MkdirTemp(dir, "trufflehogtest")
	if err != nil {
		return "", nil, err
	}
	removeAll := func() { _ = os.RemoveAll(root) }

	for i := range contents {
		if _, _, err := createTempFile(root, contents[i]); err != nil {
			removeAll()
			return "", nil, err
		}
	}
	return root, removeAll, nil
}