package s3

import (
	"context"
	"encoding/base64"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/kylelemons/godebug/pretty"
	log "github.com/sirupsen/logrus"
	"google.golang.org/protobuf/types/known/anypb"

	"github.com/trufflesecurity/trufflehog/v3/pkg/common"
	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

// TestSource_Chunks initialises a Source from the key/secret pair returned by
// common.GetTestSecret, starts Source.Chunks for the "thog-tmp-test" bucket,
// and compares the Data of the first chunk received with the base64-encoded
// fixture in wantChunkData.
//
// NOTE(review): this presumably talks to a live bucket and therefore needs
// network access and valid test credentials — confirm before running in CI.
func TestSource_Chunks(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
	defer cancel()

	secret, err := common.GetTestSecret(ctx)
	if err != nil {
		t.Fatal(fmt.Errorf("failed to access secret: %v", err))
	}
	s3key := secret.MustGetField("AWS_S3_KEY")
	s3secret := secret.MustGetField("AWS_S3_SECRET")

	type init struct {
		name       string
		verify     bool
		connection *sourcespb.S3
	}
	tests := []struct {
		name          string
		init          init
		wantErr       bool
		wantChunkData string
	}{
		{
			name: "gets chunks",
			init: init{
				connection: &sourcespb.S3{
					Credential: &sourcespb.S3_AccessKey{
						AccessKey: &credentialspb.KeySecret{
							Key:    s3key,
							Secret: s3secret,
						},
					},
					Buckets: []string{"thog-tmp-test"},
				},
			},
			wantErr:       false,
			wantChunkData: `W2RlZmF1bHRdCmF3c19hY2Nlc3Nfa2V5X2lkID0gQUtJQTM1T0hYMkRTT1pHNjQ3TkgKYXdzX3NlY3JldF9hY2Nlc3Nfa2V5ID0gUXk5OVMrWkIvQ1dsRk50eFBBaWQ3Z0d6dnNyWGhCQjd1ckFDQUxwWgpvdXRwdXQgPSBqc29uCnJlZ2lvbiA9IHVzLWVhc3QtMg==`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			log.SetFormatter(&log.TextFormatter{ForceColors: true})
			log.SetLevel(log.DebugLevel)

			// Each subtest gets its own deadline; it bounds both Init/Chunks
			// and the wait for the first chunk below.
			ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
			var cancelOnce sync.Once
			defer cancelOnce.Do(cancel)

			// Decode the fixture up front so a malformed fixture fails loudly
			// instead of silently comparing against truncated data.
			wantData, err := base64.StdEncoding.DecodeString(tt.wantChunkData)
			if err != nil {
				t.Fatalf("decoding wantChunkData: %v", err)
			}

			conn, err := anypb.New(tt.init.connection)
			if err != nil {
				t.Fatal(err)
			}

			s := Source{}
			err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 10)
			if (err != nil) != tt.wantErr {
				t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				// Init failed as expected; there is no usable source to read from.
				return
			}

			chunksCh := make(chan *sources.Chunk, 100)
			// Buffered so the producer can always deliver its result and exit,
			// even after this subtest has stopped listening.
			chunksErrCh := make(chan error, 1)
			go func() {
				// Only hand the result back; all assertions run on the test
				// goroutine, so nothing touches t after the subtest returns.
				chunksErrCh <- s.Chunks(ctx, chunksCh)
			}()

			var gotChunk *sources.Chunk
			select {
			case gotChunk = <-chunksCh:
				// The error Chunks eventually returns is intentionally not
				// asserted on this path: the deferred cancel fires as soon as
				// the comparison is done, so a late error may only reflect
				// that cancellation.
			case chunksErr := <-chunksErrCh:
				if (chunksErr != nil) != tt.wantErr {
					t.Fatalf("Source.Chunks() error = %v, wantErr %v", chunksErr, tt.wantErr)
				}
				// Chunks has returned, so anything it emitted is already
				// sitting in the channel buffer; do not block waiting for more.
				select {
				case gotChunk = <-chunksCh:
				default:
					if tt.wantErr {
						return
					}
					t.Fatal("Source.Chunks() returned without emitting a chunk")
				}
			case <-ctx.Done():
				t.Fatalf("timed out waiting for a chunk: %v", ctx.Err())
			}
			if gotChunk == nil {
				t.Fatal("Source.Chunks() produced a nil chunk")
			}

			if diff := pretty.Compare(gotChunk.Data, wantData); diff != "" {
				t.Errorf("%s: Source.Chunks() diff: (-got +want)\n%s", tt.name, diff)
			}
		})
	}
}