Move global transport to RequestProcessor

This commit is contained in:
Igor Chubin 2022-12-02 19:47:36 +01:00
parent 7b8c6665e8
commit 28f1fd9aae
2 changed files with 28 additions and 25 deletions

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -27,11 +28,12 @@ type responseWithHeader struct {
// RequestProcessor handles incoming requests. // RequestProcessor handles incoming requests.
type RequestProcessor struct { type RequestProcessor struct {
peakRequest30 sync.Map peakRequest30 sync.Map
peakRequest60 sync.Map peakRequest60 sync.Map
lruCache *lru.Cache lruCache *lru.Cache
stats *Stats stats *Stats
router routing.Router router routing.Router
upstreamTransport *http.Transport
} }
// NewRequestProcessor returns new RequestProcessor. // NewRequestProcessor returns new RequestProcessor.
@ -41,9 +43,22 @@ func NewRequestProcessor() (*RequestProcessor, error) {
return nil, err return nil, err
} }
dialer := &net.Dialer{
Timeout: uplinkTimeout * time.Second,
KeepAlive: uplinkTimeout * time.Second,
DualStack: true,
}
transport := &http.Transport{
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, network, uplinkSrvAddr)
},
}
rp := &RequestProcessor{ rp := &RequestProcessor{
lruCache: lruCache, lruCache: lruCache,
stats: NewStats(), stats: NewStats(),
upstreamTransport: transport,
} }
// Initialize routes. // Initialize routes.
@ -80,7 +95,7 @@ func (rp *RequestProcessor) ProcessRequest(r *http.Request) (*responseWithHeader
if dontCache(r) { if dontCache(r) {
rp.stats.Inc("uncached") rp.stats.Inc("uncached")
return get(r) return get(r, rp.upstreamTransport)
} }
cacheDigest := getCacheDigest(r) cacheDigest := getCacheDigest(r)
@ -127,7 +142,7 @@ func (rp *RequestProcessor) ProcessRequest(r *http.Request) (*responseWithHeader
rp.lruCache.Add(cacheDigest, responseWithHeader{InProgress: true}) rp.lruCache.Add(cacheDigest, responseWithHeader{InProgress: true})
response, err = get(r) response, err = get(r, rp.upstreamTransport)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,9 +156,11 @@ func (rp *RequestProcessor) ProcessRequest(r *http.Request) (*responseWithHeader
return response, nil return response, nil
} }
func get(req *http.Request) (*responseWithHeader, error) { func get(req *http.Request, transport *http.Transport) (*responseWithHeader, error) {
client := &http.Client{} client := &http.Client{
Transport: transport,
}
queryURL := fmt.Sprintf("http://%s%s", req.Host, req.RequestURI) queryURL := fmt.Sprintf("http://%s%s", req.Host, req.RequestURI)

View file

@ -1,12 +1,10 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"time" "time"
@ -35,18 +33,6 @@ var plainTextAgents = []string{
"xh", "xh",
} }
func init() {
dialer := &net.Dialer{
Timeout: uplinkTimeout * time.Second,
KeepAlive: uplinkTimeout * time.Second,
DualStack: true,
}
http.DefaultTransport.(*http.Transport).DialContext = func(ctx context.Context, network, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, network, uplinkSrvAddr)
}
}
func copyHeader(dst, src http.Header) { func copyHeader(dst, src http.Header) {
for k, vv := range src { for k, vv := range src {
for _, v := range vv { for _, v := range vv {