Add user agent suffix feature flag (#3297)

* Add user agent suffix feature flag

* unecessary concat
This commit is contained in:
Dustin Decker 2024-09-13 15:20:43 -07:00 committed by GitHub
parent 213bf7e4fd
commit 7e78ca385f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 55 additions and 15 deletions

View file

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/go-retryablehttp"
"github.com/trufflesecurity/trufflehog/v3/pkg/feature"
) )
var caCerts = []string{ var caCerts = []string{
@ -88,8 +89,15 @@ type CustomTransport struct {
T http.RoundTripper T http.RoundTripper
} }
func userAgent() string {
if len(feature.UserAgentSuffix.Load()) > 0 {
return "TruffleHog " + feature.UserAgentSuffix.Load()
}
return "TruffleHog"
}
func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("User-Agent", "TruffleHog") req.Header.Add("User-Agent", userAgent())
return t.T.RoundTrip(req) return t.T.RoundTrip(req)
} }

View file

@ -6,22 +6,30 @@ import (
"net" "net"
"net/http" "net/http"
"time" "time"
"github.com/trufflesecurity/trufflehog/v3/pkg/feature"
) )
var DetectorHttpClientWithNoLocalAddresses *http.Client var DetectorHttpClientWithNoLocalAddresses *http.Client
var DetectorHttpClientWithLocalAddresses *http.Client var DetectorHttpClientWithLocalAddresses *http.Client
const DefaultResponseTimeout = 5 * time.Second const DefaultResponseTimeout = 5 * time.Second
const DefaultUserAgent = "TruffleHog"
func userAgent() string {
if len(feature.UserAgentSuffix.Load()) > 0 {
return "TruffleHog " + feature.UserAgentSuffix.Load()
}
return "TruffleHog"
}
func init() { func init() {
DetectorHttpClientWithLocalAddresses = NewDetectorHttpClient( DetectorHttpClientWithLocalAddresses = NewDetectorHttpClient(
WithTransport(NewDetectorTransport(DefaultUserAgent, nil)), WithTransport(NewDetectorTransport(nil)),
WithTimeout(DefaultResponseTimeout), WithTimeout(DefaultResponseTimeout),
WithNoFollowRedirects(), WithNoFollowRedirects(),
) )
DetectorHttpClientWithNoLocalAddresses = NewDetectorHttpClient( DetectorHttpClientWithNoLocalAddresses = NewDetectorHttpClient(
WithTransport(NewDetectorTransport(DefaultUserAgent, nil)), WithTransport(NewDetectorTransport(nil)),
WithTimeout(DefaultResponseTimeout), WithTimeout(DefaultResponseTimeout),
WithNoFollowRedirects(), WithNoFollowRedirects(),
WithNoLocalIP(), WithNoLocalIP(),
@ -41,12 +49,11 @@ func WithNoFollowRedirects() ClientOption {
} }
type detectorTransport struct { type detectorTransport struct {
T http.RoundTripper T http.RoundTripper
userAgent string
} }
func (t *detectorTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (t *detectorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("User-Agent", t.userAgent) req.Header.Add("User-Agent", userAgent())
return t.T.RoundTrip(req) return t.T.RoundTrip(req)
} }
@ -55,7 +62,7 @@ var defaultDialer = &net.Dialer{
KeepAlive: 5 * time.Second, KeepAlive: 5 * time.Second,
} }
func NewDetectorTransport(userAgent string, T http.RoundTripper) http.RoundTripper { func NewDetectorTransport(T http.RoundTripper) http.RoundTripper {
if T == nil { if T == nil {
T = &http.Transport{ T = &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
@ -67,7 +74,7 @@ func NewDetectorTransport(userAgent string, T http.RoundTripper) http.RoundTripp
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
} }
return &detectorTransport{T: T, userAgent: userAgent} return &detectorTransport{T: T}
} }
func isLocalIP(ip net.IP) bool { func isLocalIP(ip net.IP) bool {
@ -143,7 +150,7 @@ func WithTimeout(timeout time.Duration) ClientOption {
func NewDetectorHttpClient(opts ...ClientOption) *http.Client { func NewDetectorHttpClient(opts ...ClientOption) *http.Client {
httpClient := &http.Client{ httpClient := &http.Client{
Transport: NewDetectorTransport(DefaultUserAgent, nil), Transport: NewDetectorTransport(nil),
Timeout: DefaultResponseTimeout, Timeout: DefaultResponseTimeout,
} }

View file

@ -3,7 +3,32 @@ package feature
import "sync/atomic" import "sync/atomic"
var ( var (
ForceSkipBinaries = atomic.Bool{} ForceSkipBinaries atomic.Bool
ForceSkipArchives = atomic.Bool{} ForceSkipArchives atomic.Bool
SkipAdditionalRefs = atomic.Bool{} SkipAdditionalRefs atomic.Bool
UserAgentSuffix AtomicString
) )
type AtomicString struct {
value atomic.Value
}
// Load returns the current value of the atomic string
func (as *AtomicString) Load() string {
if v := as.value.Load(); v != nil {
return v.(string)
}
return ""
}
// Store sets the value of the atomic string
func (as *AtomicString) Store(newValue string) {
as.value.Store(newValue)
}
// Swap atomically swaps the current string with a new one and returns the old value
func (as *AtomicString) Swap(newValue string) string {
oldValue := as.Load()
as.Store(newValue)
return oldValue
}

View file

@ -63,7 +63,7 @@ func TestSource_Token(t *testing.T) {
s.Init(ctx, "github integration test source", 0, 0, false, conn, 1) s.Init(ctx, "github integration test source", 0, 0, false, conn, 1)
s.filteredRepoCache = s.newFilteredRepoCache(ctx, memory.New[string](), nil, nil) s.filteredRepoCache = s.newFilteredRepoCache(ctx, memory.New[string](), nil, nil)
err = s.enumerateWithApp(ctx, s.connector.(*appConnector).InstallationClient()) err = s.enumerateWithApp(ctx, s.connector.(*appConnector).InstallationClient(), noopReporter())
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = s.cloneRepo(ctx, "https://github.com/truffle-test-integration-org/another-test-repo.git") _, _, err = s.cloneRepo(ctx, "https://github.com/truffle-test-integration-org/another-test-repo.git")
@ -631,7 +631,7 @@ func TestSource_paginateGists(t *testing.T) {
} }
chunksCh := make(chan *sources.Chunk, 5) chunksCh := make(chan *sources.Chunk, 5)
go func() { go func() {
assert.NoError(t, s.addUserGistsToCache(ctx, tt.user)) assert.NoError(t, s.addUserGistsToCache(ctx, tt.user, noopReporter()))
chunksCh <- &sources.Chunk{} chunksCh <- &sources.Chunk{}
}() }()
var wantedRepo string var wantedRepo string