trufflehog/pkg/sources/s3/s3_test.go
ahrav 8be89a593b
Handle errors in a thread safe manner (#1052)
* Handle errors in a thread safe manner.

* fix test.

* fix linter.

* address comments.
2023-02-02 11:05:33 -08:00

101 lines
2.7 KiB
Go

package s3
import (
	"encoding/base64"
	"fmt"
	"testing"
	"time"

	"github.com/kylelemons/godebug/pretty"
	log "github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"google.golang.org/protobuf/types/known/anypb"

	"github.com/trufflesecurity/trufflehog/v3/pkg/common"
	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)
// TestSource_Chunks is an integration test for the S3 source. It fetches AWS
// credentials via common.GetTestSecret, scans the configured bucket, and checks
// that the first chunk received carries the expected (base64-encoded) data and
// that, once Chunks has returned, progress reports no resume info and 100%
// completion.
func TestSource_Chunks(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
	defer cancel()
	secret, err := common.GetTestSecret(ctx)
	if err != nil {
		t.Fatal(fmt.Errorf("failed to access secret: %v", err))
	}
	s3key := secret.MustGetField("AWS_S3_KEY")
	s3secret := secret.MustGetField("AWS_S3_SECRET")

	type init struct {
		name       string
		verify     bool
		connection *sourcespb.S3
	}
	tests := []struct {
		name          string
		init          init
		wantErr       bool
		wantChunkData string
	}{
		{
			name: "gets chunks",
			init: init{
				connection: &sourcespb.S3{
					Credential: &sourcespb.S3_AccessKey{
						AccessKey: &credentialspb.KeySecret{
							Key:    s3key,
							Secret: s3secret,
						},
					},
					Buckets: []string{"thog-tmp-test"},
				},
			},
			wantErr:       false,
			wantChunkData: `W2RlZmF1bHRdCmF3c19hY2Nlc3Nfa2V5X2lkID0gQUtJQTM1T0hYMkRTT1pHNjQ3TkgKYXdzX3NlY3JldF9hY2Nlc3Nfa2V5ID0gUXk5OVMrWkIvQ1dsRk50eFBBaWQ3Z0d6dnNyWGhCQjd1ckFDQUxwWgpvdXRwdXQgPSBqc29uCnJlZ2lvbiA9IHVzLWVhc3QtMg==`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			log.SetFormatter(&log.TextFormatter{ForceColors: true})
			log.SetLevel(log.DebugLevel)

			ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
			defer cancel()

			s := Source{}
			conn, err := anypb.New(tt.init.connection)
			if err != nil {
				t.Fatal(err)
			}
			err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 8)
			if (err != nil) != tt.wantErr {
				t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				// Init failed as expected; there is no initialized source to scan.
				return
			}

			chunksCh := make(chan *sources.Chunk)
			// Buffered so the goroutine can always deliver its result and exit,
			// even if the test has already stopped listening.
			chunksErr := make(chan error, 1)
			go func() {
				// Only the test goroutine may call t.Fatal/t.FailNow, so the
				// result is handed back over a channel rather than reported here.
				chunksErr <- s.Chunks(ctx, chunksCh)
			}()

			var gotChunk *sources.Chunk
			select {
			case gotChunk = <-chunksCh:
			case err := <-chunksErr:
				// Chunks returned before producing anything.
				if (err != nil) != tt.wantErr {
					t.Fatalf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
				}
				if err == nil {
					t.Fatal("Source.Chunks() returned without producing a chunk")
				}
				return
			case <-ctx.Done():
				t.Fatalf("timed out waiting for a chunk: %v", ctx.Err())
			}
			if gotChunk == nil {
				t.Fatal("Source.Chunks() sent a nil chunk")
			}

			wantData, err := base64.StdEncoding.DecodeString(tt.wantChunkData)
			if err != nil {
				t.Fatalf("decoding wantChunkData: %v", err)
			}
			if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
				t.Errorf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
			}

			// Progress is only final once Chunks has returned, so keep
			// receiving (and discarding) chunks until it does.
		drain:
			for {
				select {
				case <-chunksCh:
				case err := <-chunksErr:
					if (err != nil) != tt.wantErr {
						t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
					}
					break drain
				case <-ctx.Done():
					t.Fatalf("timed out waiting for Source.Chunks() to return: %v", ctx.Err())
				}
			}

			assert.Equal(t, "", s.GetProgress().EncodedResumeInfo)
			assert.Equal(t, int64(100), s.GetProgress().PercentComplete)
		})
	}
}