mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
[feat] - Make the client configurable (#2528)
* Make the client configurable * add comment * add backoff option
This commit is contained in:
parent
7620906b07
commit
3da0c5e125
6 changed files with 287 additions and 9 deletions
|
@ -115,6 +115,39 @@ func ConstantResponseHttpClient(statusCode int, body string) *http.Client {
|
|||
}
|
||||
}
|
||||
|
||||
// ClientOption configures how we set up the client.
|
||||
type ClientOption func(*retryablehttp.Client)
|
||||
|
||||
// WithCheckRetry allows setting a custom CheckRetry policy.
|
||||
func WithCheckRetry(cr retryablehttp.CheckRetry) ClientOption {
|
||||
return func(c *retryablehttp.Client) { c.CheckRetry = cr }
|
||||
}
|
||||
|
||||
// WithBackoff allows setting a custom backoff policy.
|
||||
func WithBackoff(b retryablehttp.Backoff) ClientOption {
|
||||
return func(c *retryablehttp.Client) { c.Backoff = b }
|
||||
}
|
||||
|
||||
// WithTimeout allows setting a custom timeout.
|
||||
func WithTimeout(timeout time.Duration) ClientOption {
|
||||
return func(c *retryablehttp.Client) { c.HTTPClient.Timeout = timeout }
|
||||
}
|
||||
|
||||
// WithMaxRetries allows setting a custom maximum number of retries.
|
||||
func WithMaxRetries(retries int) ClientOption {
|
||||
return func(c *retryablehttp.Client) { c.RetryMax = retries }
|
||||
}
|
||||
|
||||
// WithRetryWaitMin allows setting a custom minimum retry wait.
|
||||
func WithRetryWaitMin(wait time.Duration) ClientOption {
|
||||
return func(c *retryablehttp.Client) { c.RetryWaitMin = wait }
|
||||
}
|
||||
|
||||
// WithRetryWaitMax allows setting a custom maximum retry wait.
|
||||
func WithRetryWaitMax(wait time.Duration) ClientOption {
|
||||
return func(c *retryablehttp.Client) { c.RetryWaitMax = wait }
|
||||
}
|
||||
|
||||
func PinnedRetryableHttpClient() *http.Client {
|
||||
httpClient := retryablehttp.NewClient()
|
||||
httpClient.Logger = nil
|
||||
|
@ -136,21 +169,28 @@ func PinnedRetryableHttpClient() *http.Client {
|
|||
return httpClient.StandardClient()
|
||||
}
|
||||
|
||||
func RetryableHttpClient() *http.Client {
|
||||
func RetryableHTTPClient(opts ...ClientOption) *http.Client {
|
||||
httpClient := retryablehttp.NewClient()
|
||||
httpClient.RetryMax = 3
|
||||
httpClient.Logger = nil
|
||||
httpClient.HTTPClient.Timeout = 3 * time.Second
|
||||
httpClient.HTTPClient.Transport = NewCustomTransport(nil)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(httpClient)
|
||||
}
|
||||
return httpClient.StandardClient()
|
||||
}
|
||||
|
||||
func RetryableHttpClientTimeout(timeOutSeconds int64) *http.Client {
|
||||
func RetryableHTTPClientTimeout(timeOutSeconds int64, opts ...ClientOption) *http.Client {
|
||||
httpClient := retryablehttp.NewClient()
|
||||
httpClient.RetryMax = 3
|
||||
httpClient.Logger = nil
|
||||
httpClient.HTTPClient.Timeout = time.Duration(timeOutSeconds) * time.Second
|
||||
httpClient.HTTPClient.Transport = NewCustomTransport(nil)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(httpClient)
|
||||
}
|
||||
return httpClient.StandardClient()
|
||||
}
|
||||
|
||||
|
|
237
pkg/common/http_test.go
Normal file
237
pkg/common/http_test.go
Normal file
|
@ -0,0 +1,237 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func TestRetryableHTTPClientCheckRetry(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
responseStatus int
|
||||
checkRetry retryablehttp.CheckRetry
|
||||
expectedRetries int
|
||||
}{
|
||||
{
|
||||
name: "Retry on 500 status, give up after 3 retries",
|
||||
responseStatus: http.StatusInternalServerError, // Server error status
|
||||
checkRetry: func(ctx context.Context, resp *http.Response, err error) (bool, error) {
|
||||
// The underlying transport will retry on 500 status.
|
||||
if resp.StatusCode == http.StatusInternalServerError {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
},
|
||||
expectedRetries: 3,
|
||||
},
|
||||
{
|
||||
name: "No retry on 400 status",
|
||||
responseStatus: http.StatusBadRequest, // Client error status
|
||||
checkRetry: func(ctx context.Context, resp *http.Response, err error) (bool, error) {
|
||||
// Do not retry on client errors.
|
||||
return false, nil
|
||||
},
|
||||
expectedRetries: 0,
|
||||
},
|
||||
{
|
||||
name: "Retry on 429 status, give up after 3 retries",
|
||||
responseStatus: http.StatusTooManyRequests,
|
||||
checkRetry: func(ctx context.Context, resp *http.Response, err error) (bool, error) {
|
||||
// The underlying transport will retry on 429 status.
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
},
|
||||
expectedRetries: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var retryCount int
|
||||
|
||||
// Do not count the initial request as a retry.
|
||||
i := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if i != 0 {
|
||||
retryCount++
|
||||
}
|
||||
i++
|
||||
w.WriteHeader(tc.responseStatus)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
client := RetryableHTTPClient(WithCheckRetry(tc.checkRetry), WithTimeout(10*time.Millisecond), WithRetryWaitMin(1*time.Millisecond))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Bad linter, there is no body to close.
|
||||
_, err = client.Do(req) //nolint:bodyclose
|
||||
if err != nil && slices.Contains([]int{http.StatusInternalServerError, http.StatusTooManyRequests}, tc.responseStatus) {
|
||||
// The underlying transport will retry on 500 and 429 status.
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedRetries, retryCount, "Retry count does not match expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryableHTTPClientMaxRetry(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
responseStatus int
|
||||
maxRetries int
|
||||
expectedRetries int
|
||||
}{
|
||||
{
|
||||
name: "Max retries with 500 status",
|
||||
responseStatus: http.StatusInternalServerError,
|
||||
maxRetries: 2,
|
||||
expectedRetries: 2,
|
||||
},
|
||||
{
|
||||
name: "Max retries with 429 status",
|
||||
responseStatus: http.StatusTooManyRequests,
|
||||
maxRetries: 1,
|
||||
expectedRetries: 1,
|
||||
},
|
||||
{
|
||||
name: "Max retries with 200 status",
|
||||
responseStatus: http.StatusOK,
|
||||
maxRetries: 3,
|
||||
expectedRetries: 0,
|
||||
},
|
||||
{
|
||||
name: "Max retries with 400 status",
|
||||
responseStatus: http.StatusBadRequest,
|
||||
maxRetries: 3,
|
||||
expectedRetries: 0,
|
||||
},
|
||||
{
|
||||
name: "Max retries with 401 status",
|
||||
responseStatus: http.StatusUnauthorized,
|
||||
maxRetries: 3,
|
||||
expectedRetries: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var retryCount int
|
||||
|
||||
i := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if i != 0 {
|
||||
retryCount++
|
||||
}
|
||||
i++
|
||||
w.WriteHeader(tc.responseStatus)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := RetryableHTTPClient(
|
||||
WithMaxRetries(tc.maxRetries),
|
||||
WithTimeout(10*time.Millisecond),
|
||||
WithRetryWaitMin(1*time.Millisecond),
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Bad linter, there is no body to close.
|
||||
_, err = client.Do(req) //nolint:bodyclose
|
||||
if err != nil && tc.responseStatus == http.StatusOK {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedRetries, retryCount, "Retry count does not match expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryableHTTPClientBackoff(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
responseStatus int
|
||||
expectedRetries int
|
||||
backoffPolicy retryablehttp.Backoff
|
||||
expectedBackoffs []time.Duration
|
||||
}{
|
||||
{
|
||||
name: "Custom backoff on 500 status",
|
||||
responseStatus: http.StatusInternalServerError,
|
||||
expectedRetries: 3,
|
||||
backoffPolicy: func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
|
||||
switch attemptNum {
|
||||
case 1:
|
||||
return 1 * time.Millisecond
|
||||
case 2:
|
||||
return 2 * time.Millisecond
|
||||
case 3:
|
||||
return 4 * time.Millisecond
|
||||
default:
|
||||
return max
|
||||
}
|
||||
},
|
||||
expectedBackoffs: []time.Duration{1 * time.Millisecond, 2 * time.Millisecond, 4 * time.Millisecond},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualBackoffs []time.Duration
|
||||
var lastTime time.Time
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
now := time.Now()
|
||||
if !lastTime.IsZero() {
|
||||
actualBackoffs = append(actualBackoffs, now.Sub(lastTime))
|
||||
}
|
||||
lastTime = now
|
||||
w.WriteHeader(tc.responseStatus)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
client := RetryableHTTPClient(
|
||||
WithBackoff(tc.backoffPolicy),
|
||||
WithTimeout(10*time.Millisecond),
|
||||
WithRetryWaitMin(1*time.Millisecond),
|
||||
WithRetryWaitMax(10*time.Millisecond),
|
||||
)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = client.Do(req) //nolint:bodyclose
|
||||
assert.Error(t, err, "Expected error due to 500 status")
|
||||
|
||||
assert.Len(t, actualBackoffs, tc.expectedRetries, "Unexpected number of backoffs")
|
||||
|
||||
for i, expectedBackoff := range tc.expectedBackoffs {
|
||||
if i < len(actualBackoffs) {
|
||||
// Allow some deviation in timing due to processing delays.
|
||||
assert.Less(t, math.Abs(float64(actualBackoffs[i]-expectedBackoff)), float64(10*time.Millisecond), "Unexpected backoff duration")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,12 +4,13 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
regexp "github.com/wasilibs/go-re2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
regexp "github.com/wasilibs/go-re2"
|
||||
|
||||
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
|
@ -27,7 +28,7 @@ var _ detectors.Detector = (*Scanner)(nil)
|
|||
|
||||
var (
|
||||
// TODO: add base64 encoded key support
|
||||
client = common.RetryableHttpClient()
|
||||
client = common.RetryableHTTPClient()
|
||||
keyPat = regexp.MustCompile(`(?i)-----\s*?BEGIN[ A-Z0-9_-]*?PRIVATE KEY\s*?-----[\s\S]*?----\s*?END[ A-Z0-9_-]*? PRIVATE KEY\s*?-----`)
|
||||
)
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourc
|
|||
s.verify = verify
|
||||
s.jobPool = &errgroup.Group{}
|
||||
s.jobPool.SetLimit(concurrency)
|
||||
s.client = common.RetryableHttpClientTimeout(3)
|
||||
s.client = common.RetryableHTTPClientTimeout(3)
|
||||
|
||||
var conn sourcespb.CircleCI
|
||||
if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil {
|
||||
|
|
|
@ -218,7 +218,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
|
|||
s.jobPool = &errgroup.Group{}
|
||||
s.jobPool.SetLimit(concurrency)
|
||||
|
||||
s.httpClient = common.RetryableHttpClientTimeout(60)
|
||||
s.httpClient = common.RetryableHTTPClientTimeout(60)
|
||||
s.apiClient = github.NewClient(s.httpClient)
|
||||
|
||||
var conn sourcespb.GitHub
|
||||
|
@ -694,7 +694,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
|
|||
appItr.BaseURL = apiEndpoint
|
||||
|
||||
// Does this need to be separate from |s.httpClient|?
|
||||
instHTTPClient := common.RetryableHttpClientTimeout(60)
|
||||
instHTTPClient := common.RetryableHTTPClientTimeout(60)
|
||||
instHTTPClient.Transport = appItr
|
||||
installationClient, err = github.NewClient(instHTTPClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
|
||||
if err != nil {
|
||||
|
|
|
@ -75,7 +75,7 @@ func (s *Source) Init(ctx context.Context, name string, jobId sources.JobID, sou
|
|||
return errors.New("token is empty")
|
||||
}
|
||||
s.client = travis.NewClient(baseURL, conn.GetToken())
|
||||
s.client.HTTPClient = common.RetryableHttpClientTimeout(3)
|
||||
s.client.HTTPClient = common.RetryableHTTPClientTimeout(3)
|
||||
|
||||
user, _, err := s.client.User.Current(ctx, nil)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue