Add user agent suffix feature flag (#3297)

* Add user agent suffix feature flag

* unecessary concat
This commit is contained in:
Dustin Decker 2024-09-13 15:20:43 -07:00 committed by GitHub
parent 213bf7e4fd
commit 7e78ca385f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 55 additions and 15 deletions

View file

@ -10,6 +10,7 @@ import (
"time"
"github.com/hashicorp/go-retryablehttp"
"github.com/trufflesecurity/trufflehog/v3/pkg/feature"
)
var caCerts = []string{
@ -88,8 +89,15 @@ type CustomTransport struct {
T http.RoundTripper
}
func userAgent() string {
if len(feature.UserAgentSuffix.Load()) > 0 {
return "TruffleHog " + feature.UserAgentSuffix.Load()
}
return "TruffleHog"
}
func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("User-Agent", "TruffleHog")
req.Header.Add("User-Agent", userAgent())
return t.T.RoundTrip(req)
}

View file

@ -6,22 +6,30 @@ import (
"net"
"net/http"
"time"
"github.com/trufflesecurity/trufflehog/v3/pkg/feature"
)
var DetectorHttpClientWithNoLocalAddresses *http.Client
var DetectorHttpClientWithLocalAddresses *http.Client
const DefaultResponseTimeout = 5 * time.Second
const DefaultUserAgent = "TruffleHog"
func userAgent() string {
if len(feature.UserAgentSuffix.Load()) > 0 {
return "TruffleHog " + feature.UserAgentSuffix.Load()
}
return "TruffleHog"
}
func init() {
DetectorHttpClientWithLocalAddresses = NewDetectorHttpClient(
WithTransport(NewDetectorTransport(DefaultUserAgent, nil)),
WithTransport(NewDetectorTransport(nil)),
WithTimeout(DefaultResponseTimeout),
WithNoFollowRedirects(),
)
DetectorHttpClientWithNoLocalAddresses = NewDetectorHttpClient(
WithTransport(NewDetectorTransport(DefaultUserAgent, nil)),
WithTransport(NewDetectorTransport(nil)),
WithTimeout(DefaultResponseTimeout),
WithNoFollowRedirects(),
WithNoLocalIP(),
@ -41,12 +49,11 @@ func WithNoFollowRedirects() ClientOption {
}
type detectorTransport struct {
T http.RoundTripper
userAgent string
T http.RoundTripper
}
func (t *detectorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("User-Agent", t.userAgent)
req.Header.Add("User-Agent", userAgent())
return t.T.RoundTrip(req)
}
@ -55,7 +62,7 @@ var defaultDialer = &net.Dialer{
KeepAlive: 5 * time.Second,
}
func NewDetectorTransport(userAgent string, T http.RoundTripper) http.RoundTripper {
func NewDetectorTransport(T http.RoundTripper) http.RoundTripper {
if T == nil {
T = &http.Transport{
Proxy: http.ProxyFromEnvironment,
@ -67,7 +74,7 @@ func NewDetectorTransport(userAgent string, T http.RoundTripper) http.RoundTripp
ExpectContinueTimeout: 1 * time.Second,
}
}
return &detectorTransport{T: T, userAgent: userAgent}
return &detectorTransport{T: T}
}
func isLocalIP(ip net.IP) bool {
@ -143,7 +150,7 @@ func WithTimeout(timeout time.Duration) ClientOption {
func NewDetectorHttpClient(opts ...ClientOption) *http.Client {
httpClient := &http.Client{
Transport: NewDetectorTransport(DefaultUserAgent, nil),
Transport: NewDetectorTransport(nil),
Timeout: DefaultResponseTimeout,
}

View file

@ -3,7 +3,32 @@ package feature
import "sync/atomic"
var (
ForceSkipBinaries = atomic.Bool{}
ForceSkipArchives = atomic.Bool{}
SkipAdditionalRefs = atomic.Bool{}
ForceSkipBinaries atomic.Bool
ForceSkipArchives atomic.Bool
SkipAdditionalRefs atomic.Bool
UserAgentSuffix AtomicString
)
type AtomicString struct {
value atomic.Value
}
// Load returns the current value of the atomic string
func (as *AtomicString) Load() string {
if v := as.value.Load(); v != nil {
return v.(string)
}
return ""
}
// Store sets the value of the atomic string
func (as *AtomicString) Store(newValue string) {
as.value.Store(newValue)
}
// Swap atomically swaps the current string with a new one and returns the old value
func (as *AtomicString) Swap(newValue string) string {
oldValue := as.Load()
as.Store(newValue)
return oldValue
}

View file

@ -63,7 +63,7 @@ func TestSource_Token(t *testing.T) {
s.Init(ctx, "github integration test source", 0, 0, false, conn, 1)
s.filteredRepoCache = s.newFilteredRepoCache(ctx, memory.New[string](), nil, nil)
err = s.enumerateWithApp(ctx, s.connector.(*appConnector).InstallationClient())
err = s.enumerateWithApp(ctx, s.connector.(*appConnector).InstallationClient(), noopReporter())
assert.NoError(t, err)
_, _, err = s.cloneRepo(ctx, "https://github.com/truffle-test-integration-org/another-test-repo.git")
@ -631,7 +631,7 @@ func TestSource_paginateGists(t *testing.T) {
}
chunksCh := make(chan *sources.Chunk, 5)
go func() {
assert.NoError(t, s.addUserGistsToCache(ctx, tt.user))
assert.NoError(t, s.addUserGistsToCache(ctx, tt.user, noopReporter()))
chunksCh <- &sources.Chunk{}
}()
var wantedRepo string