[THOG-332 ]Remove TokenSource interface from the init method of Source. (#539)

* Remove TokenSource interface from the init method of Source.

* Remove proto message.

* Remove proto message.

* Fix tests.

* Fix filesystem test.
This commit is contained in:
ahrav 2022-05-13 21:35:06 +00:00 committed by GitHub
parent 928b3b4d28
commit d2605354fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 508 additions and 974 deletions

View file

@ -23,7 +23,7 @@ func (e *Engine) ScanFileSystem(ctx context.Context, directories []string) error
} }
fileSystemSource := filesystem.Source{} fileSystemSource := filesystem.Source{}
err = fileSystemSource.Init(ctx, "trufflehog - filesystem", 0, int64(sourcespb.SourceType_SOURCE_TYPE_FILESYSTEM), true, &conn, runtime.NumCPU(), nil) err = fileSystemSource.Init(ctx, "trufflehog - filesystem", 0, int64(sourcespb.SourceType_SOURCE_TYPE_FILESYSTEM), true, &conn, runtime.NumCPU())
if err != nil { if err != nil {
return errors.WrapPrefix(err, "could not init filesystem source", 0) return errors.WrapPrefix(err, "could not init filesystem source", 0)
} }

View file

@ -34,7 +34,7 @@ func (e *Engine) ScanGitHub(ctx context.Context, endpoint string, repos, orgs []
logrus.WithError(err).Error("failed to marshal github connection") logrus.WithError(err).Error("failed to marshal github connection")
return err return err
} }
err = source.Init(ctx, "trufflehog - github", 0, 0, false, &conn, concurrency, nil) err = source.Init(ctx, "trufflehog - github", 0, 0, false, &conn, concurrency)
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to initialize github source") logrus.WithError(err).Error("failed to initialize github source")
return err return err

View file

@ -40,7 +40,7 @@ func (e *Engine) ScanGitLab(ctx context.Context, endpoint, token string, reposit
} }
gitlabSource := gitlab.Source{} gitlabSource := gitlab.Source{}
err = gitlabSource.Init(ctx, "trufflehog - gitlab", 0, int64(sourcespb.SourceType_SOURCE_TYPE_GITLAB), true, &conn, runtime.NumCPU(), nil) err = gitlabSource.Init(ctx, "trufflehog - gitlab", 0, int64(sourcespb.SourceType_SOURCE_TYPE_GITLAB), true, &conn, runtime.NumCPU())
if err != nil { if err != nil {
return errors.WrapPrefix(err, "could not init GitLab source", 0) return errors.WrapPrefix(err, "could not init GitLab source", 0)
} }

View file

@ -42,7 +42,7 @@ func (e *Engine) ScanS3(ctx context.Context, key, secret string, cloudCred bool,
} }
s3Source := s3.Source{} s3Source := s3.Source{}
err = s3Source.Init(ctx, "trufflehog - s3", 0, int64(sourcespb.SourceType_SOURCE_TYPE_S3), true, &conn, runtime.NumCPU(), nil) err = s3Source.Init(ctx, "trufflehog - s3", 0, int64(sourcespb.SourceType_SOURCE_TYPE_S3), true, &conn, runtime.NumCPU())
if err != nil { if err != nil {
return errors.WrapPrefix(err, "failed to init S3 source", 0) return errors.WrapPrefix(err, "failed to init S3 source", 0)
} }

View file

@ -39,7 +39,7 @@ func (e *Engine) ScanSyslog(ctx context.Context, address, protocol, certPath, ke
return errors.WrapPrefix(err, "error unmarshalling connection", 0) return errors.WrapPrefix(err, "error unmarshalling connection", 0)
} }
source := syslog.Source{} source := syslog.Source{}
err = source.Init(ctx, "trufflehog - syslog", 0, 0, false, &conn, concurrency, nil) err = source.Init(ctx, "trufflehog - syslog", 0, 0, false, &conn, concurrency)
source.InjectConnection(connection) source.InjectConnection(connection)
if err != nil { if err != nil {
logrus.WithError(err).Error("failed to initialize syslog source") logrus.WithError(err).Error("failed to initialize syslog source")

File diff suppressed because it is too large Load diff

View file

@ -35,236 +35,6 @@ var (
_ = sort.Sort _ = sort.Sort
) )
// Validate checks the field values on TokenRequest with the rules defined in
// the proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations.
func (m *TokenRequest) Validate() error {
return m.validate(false)
}
// ValidateAll checks the field values on TokenRequest with the rules defined
// in the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in TokenRequestMultiError, or
// nil if none found.
func (m *TokenRequest) ValidateAll() error {
return m.validate(true)
}
func (m *TokenRequest) validate(all bool) error {
if m == nil {
return nil
}
var errors []error
// no validation rules for SourceId
if len(errors) > 0 {
return TokenRequestMultiError(errors)
}
return nil
}
// TokenRequestMultiError is an error wrapping multiple validation errors
// returned by TokenRequest.ValidateAll() if the designated constraints aren't met.
type TokenRequestMultiError []error
// Error returns a concatenation of all the error messages it wraps.
func (m TokenRequestMultiError) Error() string {
var msgs []string
for _, err := range m {
msgs = append(msgs, err.Error())
}
return strings.Join(msgs, "; ")
}
// AllErrors returns a list of validation violation errors.
func (m TokenRequestMultiError) AllErrors() []error { return m }
// TokenRequestValidationError is the validation error returned by
// TokenRequest.Validate if the designated constraints aren't met.
type TokenRequestValidationError struct {
field string
reason string
cause error
key bool
}
// Field function returns field value.
func (e TokenRequestValidationError) Field() string { return e.field }
// Reason function returns reason value.
func (e TokenRequestValidationError) Reason() string { return e.reason }
// Cause function returns cause value.
func (e TokenRequestValidationError) Cause() error { return e.cause }
// Key function returns key value.
func (e TokenRequestValidationError) Key() bool { return e.key }
// ErrorName returns error name.
func (e TokenRequestValidationError) ErrorName() string { return "TokenRequestValidationError" }
// Error satisfies the builtin error interface
func (e TokenRequestValidationError) Error() string {
cause := ""
if e.cause != nil {
cause = fmt.Sprintf(" | caused by: %v", e.cause)
}
key := ""
if e.key {
key = "key for "
}
return fmt.Sprintf(
"invalid %sTokenRequest.%s: %s%s",
key,
e.field,
e.reason,
cause)
}
var _ error = TokenRequestValidationError{}
var _ interface {
Field() string
Reason() string
Key() bool
Cause() error
ErrorName() string
} = TokenRequestValidationError{}
// Validate checks the field values on TokenResponse with the rules defined in
// the proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations.
func (m *TokenResponse) Validate() error {
return m.validate(false)
}
// ValidateAll checks the field values on TokenResponse with the rules defined
// in the proto definition for this message. If any rules are violated, the
// result is a list of violation errors wrapped in TokenResponseMultiError, or
// nil if none found.
func (m *TokenResponse) ValidateAll() error {
return m.validate(true)
}
func (m *TokenResponse) validate(all bool) error {
if m == nil {
return nil
}
var errors []error
if all {
switch v := interface{}(m.GetTokenSource()).(type) {
case interface{ ValidateAll() error }:
if err := v.ValidateAll(); err != nil {
errors = append(errors, TokenResponseValidationError{
field: "TokenSource",
reason: "embedded message failed validation",
cause: err,
})
}
case interface{ Validate() error }:
if err := v.Validate(); err != nil {
errors = append(errors, TokenResponseValidationError{
field: "TokenSource",
reason: "embedded message failed validation",
cause: err,
})
}
}
} else if v, ok := interface{}(m.GetTokenSource()).(interface{ Validate() error }); ok {
if err := v.Validate(); err != nil {
return TokenResponseValidationError{
field: "TokenSource",
reason: "embedded message failed validation",
cause: err,
}
}
}
if len(errors) > 0 {
return TokenResponseMultiError(errors)
}
return nil
}
// TokenResponseMultiError is an error wrapping multiple validation errors
// returned by TokenResponse.ValidateAll() if the designated constraints
// aren't met.
type TokenResponseMultiError []error
// Error returns a concatenation of all the error messages it wraps.
func (m TokenResponseMultiError) Error() string {
var msgs []string
for _, err := range m {
msgs = append(msgs, err.Error())
}
return strings.Join(msgs, "; ")
}
// AllErrors returns a list of validation violation errors.
func (m TokenResponseMultiError) AllErrors() []error { return m }
// TokenResponseValidationError is the validation error returned by
// TokenResponse.Validate if the designated constraints aren't met.
type TokenResponseValidationError struct {
field string
reason string
cause error
key bool
}
// Field function returns field value.
func (e TokenResponseValidationError) Field() string { return e.field }
// Reason function returns reason value.
func (e TokenResponseValidationError) Reason() string { return e.reason }
// Cause function returns cause value.
func (e TokenResponseValidationError) Cause() error { return e.cause }
// Key function returns key value.
func (e TokenResponseValidationError) Key() bool { return e.key }
// ErrorName returns error name.
func (e TokenResponseValidationError) ErrorName() string { return "TokenResponseValidationError" }
// Error satisfies the builtin error interface
func (e TokenResponseValidationError) Error() string {
cause := ""
if e.cause != nil {
cause = fmt.Sprintf(" | caused by: %v", e.cause)
}
key := ""
if e.key {
key = "key for "
}
return fmt.Sprintf(
"invalid %sTokenResponse.%s: %s%s",
key,
e.field,
e.reason,
cause)
}
var _ error = TokenResponseValidationError{}
var _ interface {
Field() string
Reason() string
Key() bool
Cause() error
ErrorName() string
} = TokenResponseValidationError{}
// Validate checks the field values on LocalSource with the rules defined in // Validate checks the field values on LocalSource with the rules defined in
// the proto definition for this message. If any rules are violated, the first // the proto definition for this message. If any rules are violated, the first
// error encountered is returned, or nil if there are no violations. // error encountered is returned, or nil if there are no violations.

View file

@ -57,7 +57,7 @@ func (s *Source) JobID() int64 {
} }
// Init returns an initialized Filesystem source. // Init returns an initialized Filesystem source.
func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int, _ *sourcespb.TokenServiceClient) error { func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error {
s.log = log.WithField("source", s.Type()).WithField("name", name) s.log = log.WithField("source", s.Type()).WithField("name", name)
s.aCtx = aCtx s.aCtx = aCtx

View file

@ -60,7 +60,7 @@ func TestSource_Scan(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 5, nil) err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 5)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -84,7 +84,7 @@ func (s *Source) JobID() int64 {
} }
// Init returns an initialized GitHub source. // Init returns an initialized GitHub source.
func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int, _ *sourcespb.TokenServiceClient) error { func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error {
s.aCtx = aCtx s.aCtx = aCtx
s.name = name s.name = name

View file

@ -130,7 +130,7 @@ func TestSource_Scan(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, tt.init.concurrency, nil) err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, tt.init.concurrency)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -246,7 +246,7 @@ func TestSource_Chunks_Integration(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4, nil) err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -387,7 +387,7 @@ func TestSource_Chunks_Edge_Cases(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4, nil) err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 4)
if err != nil { if err != nil {
t.Errorf("Source.Init() error = %v", err) t.Errorf("Source.Init() error = %v", err)
return return

View file

@ -94,7 +94,7 @@ func (s *Source) Token(ctx context.Context, installationClient *github.Client) (
} }
// Init returns an initialized GitHub source. // Init returns an initialized GitHub source.
func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64, verify bool, connection *anypb.Any, concurrency int, _ *sourcespb.TokenServiceClient) error { func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64, verify bool, connection *anypb.Any, concurrency int) error {
s.log = log.WithField("source", s.Type()).WithField("name", name) s.log = log.WithField("source", s.Type()).WithField("name", name)
s.aCtx = aCtx s.aCtx = aCtx

View file

@ -33,7 +33,7 @@ func createTestSource(src *sourcespb.GitHub) (*Source, *anypb.Any) {
func initTestSource(src *sourcespb.GitHub) *Source { func initTestSource(src *sourcespb.GitHub) *Source {
s, conn := createTestSource(src) s, conn := createTestSource(src)
if err := s.Init(context.TODO(), "test - github", 0, 1337, false, conn, 1, nil); err != nil { if err := s.Init(context.TODO(), "test - github", 0, 1337, false, conn, 1); err != nil {
panic(err) panic(err)
} }
return s return s
@ -47,7 +47,7 @@ func TestInit(t *testing.T) {
}, },
}) })
err := source.Init(context.TODO(), "test - github", 0, 1337, false, conn, 1, nil) err := source.Init(context.TODO(), "test - github", 0, 1337, false, conn, 1)
assert.Nil(t, err) assert.Nil(t, err)
// TODO: test error case // TODO: test error case

View file

@ -60,7 +60,7 @@ func (s *Source) JobID() int64 {
} }
// Init returns an initialized Gitlab source. // Init returns an initialized Gitlab source.
func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int, _ *sourcespb.TokenServiceClient) error { func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error {
s.aCtx = aCtx s.aCtx = aCtx
s.name = name s.name = name

View file

@ -130,7 +130,7 @@ func TestSource_Scan(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 10, nil) err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 10)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -55,7 +55,7 @@ func (s *Source) JobID() int64 {
} }
// Init returns an initialized AWS source // Init returns an initialized AWS source
func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int, _ *sourcespb.TokenServiceClient) error { func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error {
s.log = log.WithField("source", s.Type()).WithField("name", name) s.log = log.WithField("source", s.Type()).WithField("name", name)
s.aCtx = aCtx s.aCtx = aCtx

View file

@ -73,7 +73,7 @@ func TestSource_Chunks(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 10, nil) err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 10)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -35,7 +35,7 @@ type Source interface {
// JobID returns the initialized job ID used for tracking relationships in the DB. // JobID returns the initialized job ID used for tracking relationships in the DB.
JobID() int64 JobID() int64
// Init initializes the source. // Init initializes the source.
Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int, tokenSrc *sourcespb.TokenServiceClient) error Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error
// Chunks emits data over a channel that is decoded and scanned for secrets. // Chunks emits data over a channel that is decoded and scanned for secrets.
Chunks(ctx context.Context, chunksChan chan *Chunk) error Chunks(ctx context.Context, chunksChan chan *Chunk) error
// Completion Percentage for Scanned Source // Completion Percentage for Scanned Source

View file

@ -83,7 +83,7 @@ func (s *Source) InjectConnection(conn *sourcespb.Syslog) {
} }
// Init returns an initialized Syslog source. // Init returns an initialized Syslog source.
func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int, _ *sourcespb.TokenServiceClient) error { func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, verify bool, connection *anypb.Any, concurrency int) error {
s.aCtx = aCtx s.aCtx = aCtx
s.name = name s.name = name

View file

@ -8,21 +8,8 @@ import "validate/validate.proto";
import "credentials.proto"; import "credentials.proto";
import "google/protobuf/any.proto"; import "google/protobuf/any.proto";
import "google/protobuf/duration.proto"; import "google/protobuf/duration.proto";
import "google/protobuf/struct.proto";
message TokenRequest {
int64 source_id = 1;
}
message TokenResponse {
google.protobuf.Struct token_source = 1;
}
service TokenService {
rpc Token(TokenRequest) returns (TokenResponse);
}
enum SourceType { enum SourceType {
SOURCE_TYPE_AZURE_STORAGE = 0; SOURCE_TYPE_AZURE_STORAGE = 0;
SOURCE_TYPE_BITBUCKET = 1; SOURCE_TYPE_BITBUCKET = 1;