Validate S3 source (#1715)

This PR adds S3 source validation. It factors the common "bucket visiting" logic out of scanning so that scanning and validation share one role/bucket enumeration loop; a simplified sketch of the pattern follows the commit metadata below.
Cody Rose 2023-09-05 10:18:58 -04:00 committed by GitHub
parent c9e6086644
commit afe708519b
2 changed files with 249 additions and 41 deletions
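
The shape of the refactor, as a minimal runnable sketch: visitRoles owns the role/bucket enumeration loop, and scanning and validation differ only in the callback they pass it. The signatures here are simplified (the real visitor also receives a context and a default-region *s3.S3 client, and listBuckets below is a stub standing in for the AWS calls), so treat this as an illustration of the pattern rather than the code in the diff:

package main

import "fmt"

// visitor receives, for one role, the buckets that role can enumerate.
// Scanning and validation each supply their own visitor; the surrounding
// loop is shared.
type visitor func(roleArn string, buckets []string)

// visitRoles is the factored-out "bucket visiting" loop: resolve the role
// list (an empty entry means "use the base credentials"), list the buckets
// visible to each role, and hand them to the visitor.
func visitRoles(roles []string, listBuckets func(roleArn string) ([]string, error), f visitor) error {
	if len(roles) == 0 {
		roles = []string{""}
	}
	for _, role := range roles {
		buckets, err := listBuckets(role)
		if err != nil {
			return fmt.Errorf("role %q could not list any s3 buckets: %w", role, err)
		}
		f(role, buckets)
	}
	return nil
}

func main() {
	// Stub lister so the sketch runs without AWS.
	listBuckets := func(string) ([]string, error) {
		return []string{"bucket-a", "bucket-b"}, nil
	}

	scan := func(role string, buckets []string) { fmt.Println("scan:", role, buckets) }
	validate := func(role string, buckets []string) { fmt.Println("validate:", role, buckets) }

	_ = visitRoles(nil, listBuckets, scan)     // how Chunks uses the loop
	_ = visitRoles(nil, listBuckets, validate) // how Validate uses the loop
}

In the diff below, Chunks passes a visitor that calls scanBuckets, and the new Validate passes one that calls validateBucketAccess.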


@@ -31,6 +31,7 @@ import (
 )
 
 const (
+	defaultAWSRegion     = "us-east-1"
 	defaultMaxObjectSize = 250 * 1024 * 1024 // 250 MiB
 	maxObjectSizeLimit   = 250 * 1024 * 1024 // 250 MiB
 )
@@ -53,6 +54,7 @@ type Source struct {
 // Ensure the Source satisfies the interfaces at compile time
 var _ sources.Source = (*Source)(nil)
 var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
+var _ sources.Validator = (*Source)(nil)
 
 // Type returns the type of source
 func (s *Source) Type() sourcespb.SourceType {
@@ -93,6 +95,23 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64,
 	return nil
 }
 
+func (s *Source) Validate(ctx context.Context) []error {
+	var errors []error
+
+	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
+		roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
+		if len(roleErrs) > 0 {
+			errors = append(errors, roleErrs...)
+		}
+	}
+
+	err := s.visitRoles(ctx, visitor)
+	if err != nil {
+		errors = append(errors, err)
+	}
+
+	return errors
+}
+
 // setMaxObjectSize sets the maximum size of objects that will be scanned. If
 // not set, set to a negative number, or set larger than the
 // maxObjectSizeLimit, the defaultMaxObjectSize will be used.
@@ -153,7 +172,7 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
 	res, err := client.ListBuckets(&s3.ListBucketsInput{})
 	if err != nil {
-		return nil, fmt.Errorf("could not list s3 buckets: %w", err)
+		return nil, err
 	}
 
 	var bucketsToScan []string
@@ -163,32 +182,28 @@
 	return bucketsToScan, nil
 }
 
-func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) error {
-	const defaultAWSRegion = "us-east-1"
+func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) {
 	objectCount := uint64(0)
+	logger := s.log
+	if role != "" {
+		logger = logger.WithValues("roleArn", role)
+	}
 	for i, bucket := range bucketsToScan {
+		logger := logger.WithValues("bucket", bucket)
 		if common.IsDone(ctx) {
-			return nil
+			return
 		}
 		s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
-		s.log.Info("Scanning bucket", "bucket", bucket)
-		region, err := s3manager.GetBucketRegionWithClient(ctx, client, bucket)
-		if err != nil {
-			s.log.Error(err, "could not get s3 region for bucket", "bucket: ", bucket)
-			continue
-		}
+		logger.Info("Scanning bucket")
 
-		var regionalClient *s3.S3
-		if region != defaultAWSRegion {
-			regionalClient, err = s.newClient(region, role)
-			if err != nil {
-				s.log.Error(err, "could not make regional s3 client")
-				continue
-			}
-		} else {
-			regionalClient = client
+		regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
+		if err != nil {
+			logger.Error(err, "could not get regional client for bucket")
+			continue
 		}
 
 		errorCount := sync.Map{}
@@ -201,40 +216,46 @@ func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bu
 		})
 		if err != nil {
-			s.log.Error(err, "could not list objects in s3 bucket", "bucket: ", bucket)
-			continue
+			if role == "" {
+				logger.Error(err, "could not list objects in bucket")
+			} else {
+				// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
+				// often cause this to happen (because in that case the scanner tries to scan every bucket in the
+				// account, but the role probably doesn't have access to all of them). This makes it expected behavior
+				// and therefore not an error.
+				logger.V(3).Info("could not list objects in bucket",
+					"err", err)
+			}
 		}
 	}
 
 	s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "")
-	return nil
 }
 
 // Chunks emits chunks of bytes over a channel.
 func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
-	const defaultAWSRegion = "us-east-1"
-	roles := s.conn.Roles
-	if len(roles) == 0 {
-		roles = []string{""}
+	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
+		s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
 	}
-	for _, role := range roles {
-		client, err := s.newClient(defaultAWSRegion, role)
-		if err != nil {
-			return errors.WrapPrefix(err, "could not create s3 client", 0)
-		}
+	return s.visitRoles(ctx, visitor)
+}
 
-		bucketsToScan, err := s.getBucketsToScan(client)
-		if err != nil {
-			return err
-		}
-		if err := s.scanBuckets(ctx, client, role, bucketsToScan, chunksChan); err != nil {
-			return err
-		}
+func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionClient *s3.S3, role, bucket string) (*s3.S3, error) {
+	region, err := s3manager.GetBucketRegionWithClient(ctx, defaultRegionClient, bucket)
+	if err != nil {
+		return nil, errors.WrapPrefix(err, "could not get s3 region for bucket", 0)
 	}
-	return nil
+	if region == defaultAWSRegion {
+		return defaultRegionClient, nil
+	}
+	regionalClient, err := s.newClient(region, role)
+	if err != nil {
+		return nil, errors.WrapPrefix(err, "could not create regional s3 client", 0)
+	}
+	return regionalClient, nil
 }
 
 // pageChunker emits chunks onto the given channel from a page
@@ -396,6 +417,65 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
 	_ = s.jobPool.Wait()
 }
 
+func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleArn string, buckets []string) []error {
+	shouldHaveAccessToAllBuckets := roleArn == ""
+	wasAbleToListAnyBucket := false
+	var errors []error
+
+	for _, bucket := range buckets {
+		if common.IsDone(ctx) {
+			return append(errors, ctx.Err())
+		}
+
+		regionalClient, err := s.getRegionalClientForBucket(ctx, client, roleArn, bucket)
+		if err != nil {
+			errors = append(errors, fmt.Errorf("could not get regional client for bucket %q: %w", bucket, err))
+			continue
+		}
+
+		_, err = regionalClient.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: &bucket})
+		if err == nil {
+			wasAbleToListAnyBucket = true
+		} else if shouldHaveAccessToAllBuckets {
+			errors = append(errors, fmt.Errorf("could not list objects in bucket %q: %w", bucket, err))
+		}
+	}
+
+	if !wasAbleToListAnyBucket {
+		if roleArn == "" {
+			errors = append(errors, fmt.Errorf("could not list objects in any bucket"))
+		} else {
+			errors = append(errors, fmt.Errorf("role %q could not list objects in any bucket", roleArn))
+		}
+	}
+
+	return errors
+}
+
+func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string)) error {
+	roles := s.conn.Roles
+	if len(roles) == 0 {
+		roles = []string{""}
+	}
+
+	for _, role := range roles {
+		client, err := s.newClient(defaultAWSRegion, role)
+		if err != nil {
+			return errors.WrapPrefix(err, "could not create s3 client", 0)
+		}
+
+		bucketsToScan, err := s.getBucketsToScan(client)
+		if err != nil {
+			return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
+		}
+
+		f(ctx, client, role, bucketsToScan)
+	}
+
+	return nil
+}
+
 // S3 links currently have the general format of:
 // https://[bucket].s3[.region unless us-east-1].amazonaws.com/[key]
 func makeS3Link(bucket, region, key string) string {
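
With the compile-time assertion var _ sources.Validator = (*Source)(nil) in place, a caller can surface configuration problems before any scanning starts. A hypothetical caller-side sketch follows; the import paths, placeholder credentials, and wiring are assumptions based on the repo layout and the integration test below, not code from this PR:

package main

import (
	"log"

	"google.golang.org/protobuf/types/known/anypb"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
	s3source "github.com/trufflesecurity/trufflehog/v3/pkg/sources/s3"
)

func main() {
	ctx := context.Background()

	// Placeholder credentials for illustration only; real validation needs
	// working AWS credentials, as in the integration test below.
	conn, err := anypb.New(&sourcespb.S3{
		Credential: &sourcespb.S3_AccessKey{
			AccessKey: &credentialspb.KeySecret{Key: "AKIA...", Secret: "..."},
		},
		Buckets: []string{"truffletestbucket-s3-tests"},
	})
	if err != nil {
		log.Fatal(err)
	}

	// Init arguments mirror the integration test: name, jobId, sourceId,
	// verify, serialized connection, concurrency.
	src := &s3source.Source{}
	if err := src.Init(ctx, "s3-validation-demo", 0, 0, false, conn, 0); err != nil {
		log.Fatal(err)
	}

	// Validate returns a slice of errors (one per problem found), not a
	// single error; an empty slice means the configuration looks usable.
	for _, err := range src.Validate(ctx) {
		log.Printf("s3 source validation error: %v", err)
	}
}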


@@ -4,10 +4,14 @@
 package s3
 
 import (
 	"fmt"
+	"sync"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/common"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
 	"google.golang.org/protobuf/types/known/anypb"
 
 	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
@@ -45,3 +49,127 @@ func TestSource_ChunksCount(t *testing.T) {
 	}
 	assert.Greater(t, got, wantChunkCount)
 }
+
+func TestSource_Validate(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
+	defer cancel()
+
+	secret, err := common.GetTestSecret(ctx)
+	if err != nil {
+		t.Fatal(fmt.Errorf("failed to access secret: %v", err))
+	}
+
+	s3key := secret.MustGetField("AWS_S3_KEY")
+	s3secret := secret.MustGetField("AWS_S3_SECRET")
+
+	tests := []struct {
+		name         string
+		roles        []string
+		buckets      []string
+		wantErrCount int
+	}{
+		{
+			name: "buckets without roles, can access all buckets",
+			buckets: []string{
+				"truffletestbucket-s3-tests",
+			},
+			wantErrCount: 0,
+		},
+		{
+			name: "buckets without roles, one error per inaccessible bucket",
+			buckets: []string{
+				"truffletestbucket-s3-tests",
+				"truffletestbucket-s3-role-assumption",
+				"truffletestbucket-no-access",
+			},
+			wantErrCount: 2,
+		},
+		{
+			name: "roles without buckets, all can access at least one account bucket",
+			roles: []string{
+				"arn:aws:iam::619888638459:role/s3-test-assume-role",
+			},
+			wantErrCount: 0,
+		},
+		{
+			name: "roles without buckets, one error per role that cannot access any account buckets",
+			roles: []string{
+				"arn:aws:iam::619888638459:role/s3-test-assume-role",
+				"arn:aws:iam::619888638459:role/test-no-access",
+			},
+			wantErrCount: 1,
+		},
+		{
+			name: "role and buckets, can access at least one bucket",
+			roles: []string{
+				"arn:aws:iam::619888638459:role/s3-test-assume-role",
+			},
+			buckets: []string{
+				"truffletestbucket-s3-role-assumption",
+				"truffletestbucket-no-access",
+			},
+			wantErrCount: 0,
+		},
+		{
+			name: "roles and buckets, one error per role that cannot access at least one bucket",
+			roles: []string{
+				"arn:aws:iam::619888638459:role/s3-test-assume-role",
+				"arn:aws:iam::619888638459:role/test-no-access",
+			},
+			buckets: []string{
+				"truffletestbucket-s3-role-assumption",
+				"truffletestbucket-no-access",
+			},
+			wantErrCount: 1,
+		},
+		{
+			name: "role and buckets, a bucket doesn't even exist",
+			roles: []string{
+				"arn:aws:iam::619888638459:role/s3-test-assume-role",
+			},
+			buckets: []string{
+				"truffletestbucket-s3-role-assumption",
+				"not-a-real-bucket-asljdhmglasjgvklhsdaljfh", // need a bucket name that nobody is likely to ever create
+			},
+			wantErrCount: 1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
+			var cancelOnce sync.Once
+			defer cancelOnce.Do(cancel)
+
+			// These are used by the tests that assume roles
+			t.Setenv("AWS_ACCESS_KEY_ID", s3key)
+			t.Setenv("AWS_SECRET_ACCESS_KEY", s3secret)
+
+			s := &Source{}
+			conn, err := anypb.New(&sourcespb.S3{
+				// These are used by the tests that don't assume roles
+				Credential: &sourcespb.S3_AccessKey{
+					AccessKey: &credentialspb.KeySecret{
+						Key:    s3key,
+						Secret: s3secret,
+					},
+				},
+				Buckets: tt.buckets,
+				Roles:   tt.roles,
+			})
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			err = s.Init(ctx, tt.name, 0, 0, false, conn, 0)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			errs := s.Validate(ctx)
+			assert.Equal(t, tt.wantErrCount, len(errs))
+		})
+	}
+}