diff --git a/cmd/processRequest.go b/cmd/processRequest.go index cb2ce69..8080115 100644 --- a/cmd/processRequest.go +++ b/cmd/processRequest.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "io/ioutil" "log" @@ -27,11 +28,12 @@ type responseWithHeader struct { // RequestProcessor handles incoming requests. type RequestProcessor struct { - peakRequest30 sync.Map - peakRequest60 sync.Map - lruCache *lru.Cache - stats *Stats - router routing.Router + peakRequest30 sync.Map + peakRequest60 sync.Map + lruCache *lru.Cache + stats *Stats + router routing.Router + upstreamTransport *http.Transport } // NewRequestProcessor returns new RequestProcessor. @@ -41,9 +43,22 @@ func NewRequestProcessor() (*RequestProcessor, error) { return nil, err } + dialer := &net.Dialer{ + Timeout: uplinkTimeout * time.Second, + KeepAlive: uplinkTimeout * time.Second, + DualStack: true, + } + + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, network, uplinkSrvAddr) + }, + } + rp := &RequestProcessor{ - lruCache: lruCache, - stats: NewStats(), + lruCache: lruCache, + stats: NewStats(), + upstreamTransport: transport, } // Initialize routes. @@ -80,7 +95,7 @@ func (rp *RequestProcessor) ProcessRequest(r *http.Request) (*responseWithHeader if dontCache(r) { rp.stats.Inc("uncached") - return get(r) + return get(r, rp.upstreamTransport) } cacheDigest := getCacheDigest(r) @@ -127,7 +142,7 @@ func (rp *RequestProcessor) ProcessRequest(r *http.Request) (*responseWithHeader rp.lruCache.Add(cacheDigest, responseWithHeader{InProgress: true}) - response, err = get(r) + response, err = get(r, rp.upstreamTransport) if err != nil { return nil, err } @@ -141,9 +156,11 @@ func (rp *RequestProcessor) ProcessRequest(r *http.Request) (*responseWithHeader return response, nil } -func get(req *http.Request) (*responseWithHeader, error) { +func get(req *http.Request, transport *http.Transport) (*responseWithHeader, error) { - client := &http.Client{} + client := &http.Client{ + Transport: transport, + } queryURL := fmt.Sprintf("http://%s%s", req.Host, req.RequestURI) diff --git a/cmd/srv.go b/cmd/srv.go index 0db5a9d..28bd2ce 100644 --- a/cmd/srv.go +++ b/cmd/srv.go @@ -1,12 +1,10 @@ package main import ( - "context" "crypto/tls" "fmt" "io" "log" - "net" "net/http" "time" @@ -35,18 +33,6 @@ var plainTextAgents = []string{ "xh", } -func init() { - dialer := &net.Dialer{ - Timeout: uplinkTimeout * time.Second, - KeepAlive: uplinkTimeout * time.Second, - DualStack: true, - } - - http.DefaultTransport.(*http.Transport).DialContext = func(ctx context.Context, network, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, network, uplinkSrvAddr) - } -} - func copyHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv {