[feat] - Make the client configurable (#2528)

* Make the client configurable

* add comment

* add backoff option
This commit is contained in:
ahrav 2024-03-01 13:29:25 -08:00 committed by GitHub
parent 7620906b07
commit 3da0c5e125
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 287 additions and 9 deletions

View file

@ -115,6 +115,39 @@ func ConstantResponseHttpClient(statusCode int, body string) *http.Client {
}
}
// ClientOption configures how we set up the client.
type ClientOption func(*retryablehttp.Client)
// WithCheckRetry allows setting a custom CheckRetry policy.
func WithCheckRetry(cr retryablehttp.CheckRetry) ClientOption {
return func(c *retryablehttp.Client) { c.CheckRetry = cr }
}
// WithBackoff allows setting a custom backoff policy.
func WithBackoff(b retryablehttp.Backoff) ClientOption {
return func(c *retryablehttp.Client) { c.Backoff = b }
}
// WithTimeout allows setting a custom timeout.
func WithTimeout(timeout time.Duration) ClientOption {
return func(c *retryablehttp.Client) { c.HTTPClient.Timeout = timeout }
}
// WithMaxRetries allows setting a custom maximum number of retries.
func WithMaxRetries(retries int) ClientOption {
return func(c *retryablehttp.Client) { c.RetryMax = retries }
}
// WithRetryWaitMin allows setting a custom minimum retry wait.
func WithRetryWaitMin(wait time.Duration) ClientOption {
return func(c *retryablehttp.Client) { c.RetryWaitMin = wait }
}
// WithRetryWaitMax allows setting a custom maximum retry wait.
func WithRetryWaitMax(wait time.Duration) ClientOption {
return func(c *retryablehttp.Client) { c.RetryWaitMax = wait }
}
func PinnedRetryableHttpClient() *http.Client {
httpClient := retryablehttp.NewClient()
httpClient.Logger = nil
@ -136,21 +169,28 @@ func PinnedRetryableHttpClient() *http.Client {
return httpClient.StandardClient()
}
func RetryableHttpClient() *http.Client {
func RetryableHTTPClient(opts ...ClientOption) *http.Client {
httpClient := retryablehttp.NewClient()
httpClient.RetryMax = 3
httpClient.Logger = nil
httpClient.HTTPClient.Timeout = 3 * time.Second
httpClient.HTTPClient.Transport = NewCustomTransport(nil)
for _, opt := range opts {
opt(httpClient)
}
return httpClient.StandardClient()
}
func RetryableHttpClientTimeout(timeOutSeconds int64) *http.Client {
func RetryableHTTPClientTimeout(timeOutSeconds int64, opts ...ClientOption) *http.Client {
httpClient := retryablehttp.NewClient()
httpClient.RetryMax = 3
httpClient.Logger = nil
httpClient.HTTPClient.Timeout = time.Duration(timeOutSeconds) * time.Second
httpClient.HTTPClient.Transport = NewCustomTransport(nil)
for _, opt := range opts {
opt(httpClient)
}
return httpClient.StandardClient()
}

237
pkg/common/http_test.go Normal file
View file

@ -0,0 +1,237 @@
package common
import (
"context"
"math"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/hashicorp/go-retryablehttp"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
)
func TestRetryableHTTPClientCheckRetry(t *testing.T) {
testCases := []struct {
name string
responseStatus int
checkRetry retryablehttp.CheckRetry
expectedRetries int
}{
{
name: "Retry on 500 status, give up after 3 retries",
responseStatus: http.StatusInternalServerError, // Server error status
checkRetry: func(ctx context.Context, resp *http.Response, err error) (bool, error) {
// The underlying transport will retry on 500 status.
if resp.StatusCode == http.StatusInternalServerError {
return true, nil
}
return false, nil
},
expectedRetries: 3,
},
{
name: "No retry on 400 status",
responseStatus: http.StatusBadRequest, // Client error status
checkRetry: func(ctx context.Context, resp *http.Response, err error) (bool, error) {
// Do not retry on client errors.
return false, nil
},
expectedRetries: 0,
},
{
name: "Retry on 429 status, give up after 3 retries",
responseStatus: http.StatusTooManyRequests,
checkRetry: func(ctx context.Context, resp *http.Response, err error) (bool, error) {
// The underlying transport will retry on 429 status.
if resp.StatusCode == http.StatusTooManyRequests {
return true, nil
}
return false, nil
},
expectedRetries: 3,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var retryCount int
// Do not count the initial request as a retry.
i := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if i != 0 {
retryCount++
}
i++
w.WriteHeader(tc.responseStatus)
}))
defer server.Close()
ctx := context.Background()
client := RetryableHTTPClient(WithCheckRetry(tc.checkRetry), WithTimeout(10*time.Millisecond), WithRetryWaitMin(1*time.Millisecond))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
assert.NoError(t, err)
// Bad linter, there is no body to close.
_, err = client.Do(req) //nolint:bodyclose
if err != nil && slices.Contains([]int{http.StatusInternalServerError, http.StatusTooManyRequests}, tc.responseStatus) {
// The underlying transport will retry on 500 and 429 status.
assert.Error(t, err)
}
assert.Equal(t, tc.expectedRetries, retryCount, "Retry count does not match expected")
})
}
}
func TestRetryableHTTPClientMaxRetry(t *testing.T) {
testCases := []struct {
name string
responseStatus int
maxRetries int
expectedRetries int
}{
{
name: "Max retries with 500 status",
responseStatus: http.StatusInternalServerError,
maxRetries: 2,
expectedRetries: 2,
},
{
name: "Max retries with 429 status",
responseStatus: http.StatusTooManyRequests,
maxRetries: 1,
expectedRetries: 1,
},
{
name: "Max retries with 200 status",
responseStatus: http.StatusOK,
maxRetries: 3,
expectedRetries: 0,
},
{
name: "Max retries with 400 status",
responseStatus: http.StatusBadRequest,
maxRetries: 3,
expectedRetries: 0,
},
{
name: "Max retries with 401 status",
responseStatus: http.StatusUnauthorized,
maxRetries: 3,
expectedRetries: 0,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var retryCount int
i := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if i != 0 {
retryCount++
}
i++
w.WriteHeader(tc.responseStatus)
}))
defer server.Close()
client := RetryableHTTPClient(
WithMaxRetries(tc.maxRetries),
WithTimeout(10*time.Millisecond),
WithRetryWaitMin(1*time.Millisecond),
)
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
assert.NoError(t, err)
// Bad linter, there is no body to close.
_, err = client.Do(req) //nolint:bodyclose
if err != nil && tc.responseStatus == http.StatusOK {
assert.Error(t, err)
}
assert.Equal(t, tc.expectedRetries, retryCount, "Retry count does not match expected")
})
}
}
func TestRetryableHTTPClientBackoff(t *testing.T) {
testCases := []struct {
name string
responseStatus int
expectedRetries int
backoffPolicy retryablehttp.Backoff
expectedBackoffs []time.Duration
}{
{
name: "Custom backoff on 500 status",
responseStatus: http.StatusInternalServerError,
expectedRetries: 3,
backoffPolicy: func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
switch attemptNum {
case 1:
return 1 * time.Millisecond
case 2:
return 2 * time.Millisecond
case 3:
return 4 * time.Millisecond
default:
return max
}
},
expectedBackoffs: []time.Duration{1 * time.Millisecond, 2 * time.Millisecond, 4 * time.Millisecond},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var actualBackoffs []time.Duration
var lastTime time.Time
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
now := time.Now()
if !lastTime.IsZero() {
actualBackoffs = append(actualBackoffs, now.Sub(lastTime))
}
lastTime = now
w.WriteHeader(tc.responseStatus)
}))
defer server.Close()
ctx := context.Background()
client := RetryableHTTPClient(
WithBackoff(tc.backoffPolicy),
WithTimeout(10*time.Millisecond),
WithRetryWaitMin(1*time.Millisecond),
WithRetryWaitMax(10*time.Millisecond),
)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
assert.NoError(t, err)
_, err = client.Do(req) //nolint:bodyclose
assert.Error(t, err, "Expected error due to 500 status")
assert.Len(t, actualBackoffs, tc.expectedRetries, "Unexpected number of backoffs")
for i, expectedBackoff := range tc.expectedBackoffs {
if i < len(actualBackoffs) {
// Allow some deviation in timing due to processing delays.
assert.Less(t, math.Abs(float64(actualBackoffs[i]-expectedBackoff)), float64(10*time.Millisecond), "Unexpected backoff duration")
}
}
})
}
}

View file

@ -4,12 +4,13 @@ import (
"context"
"encoding/json"
"fmt"
regexp "github.com/wasilibs/go-re2"
"net/http"
"strings"
"sync"
"time"
regexp "github.com/wasilibs/go-re2"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"golang.org/x/crypto/ssh"
@ -27,7 +28,7 @@ var _ detectors.Detector = (*Scanner)(nil)
var (
// TODO: add base64 encoded key support
client = common.RetryableHttpClient()
client = common.RetryableHTTPClient()
keyPat = regexp.MustCompile(`(?i)-----\s*?BEGIN[ A-Z0-9_-]*?PRIVATE KEY\s*?-----[\s\S]*?----\s*?END[ A-Z0-9_-]*? PRIVATE KEY\s*?-----`)
)

View file

@ -63,7 +63,7 @@ func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourc
s.verify = verify
s.jobPool = &errgroup.Group{}
s.jobPool.SetLimit(concurrency)
s.client = common.RetryableHttpClientTimeout(3)
s.client = common.RetryableHTTPClientTimeout(3)
var conn sourcespb.CircleCI
if err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}); err != nil {

View file

@ -218,7 +218,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
s.jobPool = &errgroup.Group{}
s.jobPool.SetLimit(concurrency)
s.httpClient = common.RetryableHttpClientTimeout(60)
s.httpClient = common.RetryableHTTPClientTimeout(60)
s.apiClient = github.NewClient(s.httpClient)
var conn sourcespb.GitHub
@ -694,7 +694,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
appItr.BaseURL = apiEndpoint
// Does this need to be separate from |s.httpClient|?
instHTTPClient := common.RetryableHttpClientTimeout(60)
instHTTPClient := common.RetryableHTTPClientTimeout(60)
instHTTPClient.Transport = appItr
installationClient, err = github.NewClient(instHTTPClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
if err != nil {

View file

@ -75,7 +75,7 @@ func (s *Source) Init(ctx context.Context, name string, jobId sources.JobID, sou
return errors.New("token is empty")
}
s.client = travis.NewClient(baseURL, conn.GetToken())
s.client.HTTPClient = common.RetryableHttpClientTimeout(3)
s.client.HTTPClient = common.RetryableHTTPClientTimeout(3)
user, _, err := s.client.User.Current(ctx, nil)
if err != nil {