mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
Validate S3 source (#1715)
This PR adds S3 source validation. This is accomplished by factoring out common "bucket visiting" logic to be used by both scanning and validation.
This commit is contained in:
parent
c9e6086644
commit
afe708519b
2 changed files with 249 additions and 41 deletions
|
@ -31,6 +31,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
defaultAWSRegion = "us-east-1"
|
||||
defaultMaxObjectSize = 250 * 1024 * 1024 // 250 MiB
|
||||
maxObjectSizeLimit = 250 * 1024 * 1024 // 250 MiB
|
||||
)
|
||||
|
@ -53,6 +54,7 @@ type Source struct {
|
|||
// Ensure the Source satisfies the interfaces at compile time
|
||||
var _ sources.Source = (*Source)(nil)
|
||||
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
|
||||
var _ sources.Validator = (*Source)(nil)
|
||||
|
||||
// Type returns the type of source
|
||||
func (s *Source) Type() sourcespb.SourceType {
|
||||
|
@ -93,6 +95,23 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) Validate(ctx context.Context) []error {
|
||||
var errors []error
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
|
||||
roleErrs := s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)
|
||||
if len(roleErrs) > 0 {
|
||||
errors = append(errors, roleErrs...)
|
||||
}
|
||||
}
|
||||
|
||||
err := s.visitRoles(ctx, visitor)
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// setMaxObjectSize sets the maximum size of objects that will be scanned. If
|
||||
// not set, set to a negative number, or set larger than the
|
||||
// maxObjectSizeLimit, defaultMaxObjectSize will be used.
|
||||
|
@ -153,7 +172,7 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
|
|||
|
||||
res, err := client.ListBuckets(&s3.ListBucketsInput{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not list s3 buckets: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var bucketsToScan []string
|
||||
|
@ -163,32 +182,28 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
|
|||
return bucketsToScan, nil
|
||||
}
|
||||
|
||||
func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) error {
|
||||
const defaultAWSRegion = "us-east-1"
|
||||
func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) {
|
||||
objectCount := uint64(0)
|
||||
|
||||
logger := s.log
|
||||
if role != "" {
|
||||
logger = logger.WithValues("roleArn", role)
|
||||
}
|
||||
|
||||
for i, bucket := range bucketsToScan {
|
||||
logger := logger.WithValues("bucket", bucket)
|
||||
|
||||
if common.IsDone(ctx) {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
|
||||
s.log.Info("Scanning bucket", "bucket", bucket)
|
||||
region, err := s3manager.GetBucketRegionWithClient(ctx, client, bucket)
|
||||
if err != nil {
|
||||
s.log.Error(err, "could not get s3 region for bucket", "bucket: ", bucket)
|
||||
continue
|
||||
}
|
||||
logger.Info("Scanning bucket")
|
||||
|
||||
var regionalClient *s3.S3
|
||||
if region != defaultAWSRegion {
|
||||
regionalClient, err = s.newClient(region, role)
|
||||
if err != nil {
|
||||
s.log.Error(err, "could not make regional s3 client")
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
regionalClient = client
|
||||
regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
|
||||
if err != nil {
|
||||
logger.Error(err, "could not get regional client for bucket")
|
||||
continue
|
||||
}
|
||||
|
||||
errorCount := sync.Map{}
|
||||
|
@ -201,40 +216,46 @@ func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bu
|
|||
})
|
||||
|
||||
if err != nil {
|
||||
s.log.Error(err, "could not list objects in s3 bucket", "bucket: ", bucket)
|
||||
continue
|
||||
if role == "" {
|
||||
logger.Error(err, "could not list objects in bucket")
|
||||
} else {
|
||||
// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
|
||||
// often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
|
||||
// account, but the role probably doesn't have access to all of them). This makes it expected behavior
|
||||
// and therefore not an error.
|
||||
logger.V(3).Info("could not list objects in bucket",
|
||||
"err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chunks emits chunks of bytes over a channel.
|
||||
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
|
||||
const defaultAWSRegion = "us-east-1"
|
||||
|
||||
roles := s.conn.Roles
|
||||
if len(roles) == 0 {
|
||||
roles = []string{""}
|
||||
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
|
||||
s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
client, err := s.newClient(defaultAWSRegion, role)
|
||||
if err != nil {
|
||||
return errors.WrapPrefix(err, "could not create s3 client", 0)
|
||||
}
|
||||
return s.visitRoles(ctx, visitor)
|
||||
}
|
||||
|
||||
bucketsToScan, err := s.getBucketsToScan(client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.scanBuckets(ctx, client, role, bucketsToScan, chunksChan); err != nil {
|
||||
return err
|
||||
}
|
||||
func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionClient *s3.S3, role, bucket string) (*s3.S3, error) {
|
||||
region, err := s3manager.GetBucketRegionWithClient(ctx, defaultRegionClient, bucket)
|
||||
if err != nil {
|
||||
return nil, errors.WrapPrefix(err, "could not get s3 region for bucket", 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
if region == defaultAWSRegion {
|
||||
return defaultRegionClient, nil
|
||||
}
|
||||
|
||||
regionalClient, err := s.newClient(region, role)
|
||||
if err != nil {
|
||||
return nil, errors.WrapPrefix(err, "could not create regional s3 client", 0)
|
||||
}
|
||||
|
||||
return regionalClient, nil
|
||||
}
|
||||
|
||||
// pageChunker emits chunks onto the given channel from a page
|
||||
|
@ -396,6 +417,65 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
|
|||
_ = s.jobPool.Wait()
|
||||
}
|
||||
|
||||
func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleArn string, buckets []string) []error {
|
||||
shouldHaveAccessToAllBuckets := roleArn == ""
|
||||
wasAbleToListAnyBucket := false
|
||||
var errors []error
|
||||
|
||||
for _, bucket := range buckets {
|
||||
if common.IsDone(ctx) {
|
||||
return append(errors, ctx.Err())
|
||||
}
|
||||
|
||||
regionalClient, err := s.getRegionalClientForBucket(ctx, client, roleArn, bucket)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Errorf("could not get regional client for bucket %q: %w", bucket, err))
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = regionalClient.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: &bucket})
|
||||
|
||||
if err == nil {
|
||||
wasAbleToListAnyBucket = true
|
||||
} else if shouldHaveAccessToAllBuckets {
|
||||
errors = append(errors, fmt.Errorf("could not list objects in bucket %q: %w", bucket, err))
|
||||
}
|
||||
}
|
||||
|
||||
if !wasAbleToListAnyBucket {
|
||||
if roleArn == "" {
|
||||
errors = append(errors, fmt.Errorf("could not list objects in any bucket"))
|
||||
} else {
|
||||
errors = append(errors, fmt.Errorf("role %q could not list objects in any bucket", roleArn))
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string)) error {
|
||||
roles := s.conn.Roles
|
||||
if len(roles) == 0 {
|
||||
roles = []string{""}
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
client, err := s.newClient(defaultAWSRegion, role)
|
||||
if err != nil {
|
||||
return errors.WrapPrefix(err, "could not create s3 client", 0)
|
||||
}
|
||||
|
||||
bucketsToScan, err := s.getBucketsToScan(client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", role, err)
|
||||
}
|
||||
|
||||
f(ctx, client, role, bucketsToScan)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// S3 links currently have the general format of:
|
||||
// https://[bucket].s3[.region unless us-east-1].amazonaws.com/[key]
|
||||
func makeS3Link(bucket, region, key string) string {
|
||||
|
|
|
@ -4,10 +4,14 @@
|
|||
package s3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
|
@ -45,3 +49,127 @@ func TestSource_ChunksCount(t *testing.T) {
|
|||
}
|
||||
assert.Greater(t, got, wantChunkCount)
|
||||
}
|
||||
|
||||
func TestSource_Validate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
|
||||
defer cancel()
|
||||
|
||||
secret, err := common.GetTestSecret(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to access secret: %v", err))
|
||||
}
|
||||
|
||||
s3key := secret.MustGetField("AWS_S3_KEY")
|
||||
s3secret := secret.MustGetField("AWS_S3_SECRET")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roles []string
|
||||
buckets []string
|
||||
wantErrCount int
|
||||
}{
|
||||
{
|
||||
name: "buckets without roles, can access all buckets",
|
||||
buckets: []string{
|
||||
"truffletestbucket-s3-tests",
|
||||
},
|
||||
wantErrCount: 0,
|
||||
},
|
||||
{
|
||||
name: "buckets without roles, one error per inaccessible bucket",
|
||||
buckets: []string{
|
||||
"truffletestbucket-s3-tests",
|
||||
"truffletestbucket-s3-role-assumption",
|
||||
"truffletestbucket-no-access",
|
||||
},
|
||||
wantErrCount: 2,
|
||||
},
|
||||
{
|
||||
name: "roles without buckets, all can access at least one account bucket",
|
||||
roles: []string{
|
||||
"arn:aws:iam::619888638459:role/s3-test-assume-role",
|
||||
},
|
||||
wantErrCount: 0,
|
||||
},
|
||||
{
|
||||
name: "roles without buckets, one error per role that cannot access any account buckets",
|
||||
roles: []string{
|
||||
"arn:aws:iam::619888638459:role/s3-test-assume-role",
|
||||
"arn:aws:iam::619888638459:role/test-no-access",
|
||||
},
|
||||
wantErrCount: 1,
|
||||
},
|
||||
{
|
||||
name: "role and buckets, can access at least one bucket",
|
||||
roles: []string{
|
||||
"arn:aws:iam::619888638459:role/s3-test-assume-role",
|
||||
},
|
||||
buckets: []string{
|
||||
"truffletestbucket-s3-role-assumption",
|
||||
"truffletestbucket-no-access",
|
||||
},
|
||||
wantErrCount: 0,
|
||||
},
|
||||
{
|
||||
name: "roles and buckets, one error per role that cannot access at least one bucket",
|
||||
roles: []string{
|
||||
"arn:aws:iam::619888638459:role/s3-test-assume-role",
|
||||
"arn:aws:iam::619888638459:role/test-no-access",
|
||||
},
|
||||
buckets: []string{
|
||||
"truffletestbucket-s3-role-assumption",
|
||||
"truffletestbucket-no-access",
|
||||
},
|
||||
wantErrCount: 1,
|
||||
},
|
||||
{
|
||||
name: "role and buckets, a bucket doesn't even exist",
|
||||
roles: []string{
|
||||
"arn:aws:iam::619888638459:role/s3-test-assume-role",
|
||||
},
|
||||
buckets: []string{
|
||||
"truffletestbucket-s3-role-assumption",
|
||||
"not-a-real-bucket-asljdhmglasjgvklhsdaljfh", // need a bucket name that nobody is likely to ever create
|
||||
},
|
||||
wantErrCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
|
||||
var cancelOnce sync.Once
|
||||
defer cancelOnce.Do(cancel)
|
||||
|
||||
// These are used by the tests that assume roles
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", s3key)
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", s3secret)
|
||||
|
||||
s := &Source{}
|
||||
|
||||
conn, err := anypb.New(&sourcespb.S3{
|
||||
// These are used by the tests that don't assume roles
|
||||
Credential: &sourcespb.S3_AccessKey{
|
||||
AccessKey: &credentialspb.KeySecret{
|
||||
Key: s3key,
|
||||
Secret: s3secret,
|
||||
},
|
||||
},
|
||||
Buckets: tt.buckets,
|
||||
Roles: tt.roles,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = s.Init(ctx, tt.name, 0, 0, false, conn, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
errs := s.Validate(ctx)
|
||||
|
||||
assert.Equal(t, tt.wantErrCount, len(errs))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue