mirror of
synced 2025-02-16 20:28:27 +00:00
452 lines
12 KiB
452 lines
12 KiB
package intercept
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"sort"
	"sync"

	// NOTE(review): the import paths of the non-stdlib packages referenced in this file (ulid, filter, log, proxy)
	// are not visible in this copy; restore them here.
)
// Sentinel errors returned by the intercept service. Compare against them with errors.Is.
var (
	// ErrRequestAborted is returned to the routine awaiting an intercepted item when a nil replacement was sent,
	// i.e. the item was dropped.
	ErrRequestAborted = errors.New("intercept: request was aborted")

	// ErrRequestNotFound is returned when no pending item is registered under the given ID.
	ErrRequestNotFound = errors.New("intercept: request not found")

	// ErrRequestDone is returned when a replacement is sent for an item whose done channel has already fired.
	ErrRequestDone = errors.New("intercept: request is done")

	// ErrResponseNotFound reports a missing response.
	// NOTE(review): not returned by any code visible in this file; response lookups return ErrRequestNotFound.
	ErrResponseNotFound = errors.New("intercept: response not found")
)
// contextKey is the unexported type of the context keys defined by this package, which keeps them from colliding
// with keys defined elsewhere.
type contextKey int

// interceptResponseKey is the context key under which WithInterceptResponse stores its boolean value.
const interceptResponseKey = contextKey(0)
// Request represents a server received HTTP request, alongside a channel for sending a modified version of it to the
// routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
type Request struct {
	// req is the request as it was handed to InterceptRequest.
	req *http.Request

	// ch is the send side of the channel the awaiting routine receives the replacement request on. A nil value
	// sent here is treated as an abort.
	ch chan<- *http.Request

	// done is selected on alongside ch by senders, so that they can give up with ErrRequestDone instead of
	// blocking on ch.
	done <-chan struct{}
}
// Response represents an HTTP response from a proxied request, alongside a channel for sending a modified version of it
// to the routine that's awaiting it. Also contains a channel for receiving a cancellation signal.
type Response struct {
	// res is the response as it was handed to InterceptResponse.
	res *http.Response

	// ch is the send side of the channel the awaiting routine receives the replacement response on. A nil value
	// sent here is treated as an abort.
	ch chan<- *http.Response

	// done is selected on alongside ch by senders, so that they can give up with ErrRequestDone instead of
	// blocking on ch.
	done <-chan struct{}
}
// Item is a pending intercepted item as exposed to callers of Items and ItemByID. Either field is nil when no
// pending request or response, respectively, is registered under the item's ID.
type Item struct {
	// Request is the pending intercepted request, if any.
	Request *http.Request

	// Response is the pending intercepted response, if any.
	Response *http.Response
}
type Service struct {
reqMu *sync.RWMutex
resMu *sync.RWMutex
requests map[ulid.ULID]Request
responses map[ulid.ULID]Response
logger log.Logger
requestsEnabled bool
responsesEnabled bool
reqFilter filter.Expression
resFilter filter.Expression
type Config struct {
Logger log.Logger
RequestsEnabled bool
ResponsesEnabled bool
RequestFilter filter.Expression
ResponseFilter filter.Expression
// RequestIDs is a slice of request IDs that implements sort.Interface. Its Less method orders two IDs by the result
// of their Compare method.
type RequestIDs []ulid.ULID
func NewService(cfg Config) *Service {
s := &Service{
reqMu: &sync.RWMutex{},
resMu: &sync.RWMutex{},
requests: make(map[ulid.ULID]Request),
responses: make(map[ulid.ULID]Response),
logger: cfg.Logger,
requestsEnabled: cfg.RequestsEnabled,
responsesEnabled: cfg.ResponsesEnabled,
reqFilter: cfg.RequestFilter,
resFilter: cfg.ResponseFilter,
if s.logger == nil {
s.logger = log.NewNopLogger()
return s
// RequestModifier is a proxy.RequestModifyMiddleware for intercepting HTTP requests.
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
// This is a blocking operation, that gets unblocked when either a modified request is returned or an error
// (typically `context.Canceled`).
modifiedReq, err := svc.InterceptRequest(req.Context(), req)
switch {
case errors.Is(err, ErrRequestAborted):
svc.logger.Debugw("Stopping intercept, request was aborted.")
// Prevent further processing by replacing req.Context with a cancelled context value.
// This will cause the http.Roundtripper in the `proxy` package to
// handle this request as an error.
ctx, cancel := context.WithCancel(context.Background())
*req = *req.WithContext(ctx)
case errors.Is(err, context.Canceled):
svc.logger.Debugw("Stopping intercept, context was cancelled.")
case err != nil:
svc.logger.Errorw("Failed to intercept request.",
"error", err)
*req = *modifiedReq
// InterceptRequest adds an HTTP request to an array of pending intercepted requests, alongside channels used for
// sending a cancellation signal and receiving a modified request. It's safe for concurrent use.
func (svc *Service) InterceptRequest(ctx context.Context, req *http.Request) (*http.Request, error) {
reqID, ok := proxy.RequestIDFromContext(ctx)
if !ok {
svc.logger.Errorw("Failed to intercept: context doesn't have an ID.")
return req, nil
if !svc.requestsEnabled {
// If request intercept is disabled, return the incoming request as-is.
svc.logger.Debugw("Bypassed request interception: feature disabled.")
return req, nil
if svc.reqFilter != nil {
match, err := MatchRequestFilter(req, svc.reqFilter)
if err != nil {
return nil, fmt.Errorf("intercept: failed to match request rules for request (id: %v): %w",
reqID.String(), err,
if !match {
svc.logger.Debugw("Bypassed request interception: request rules don't match.")
return req, nil
ch := make(chan *http.Request)
done := make(chan struct{})
svc.requests[reqID] = Request{
req: req,
ch: ch,
done: done,
// Whatever happens next (modified request returned, or a context cancelled error), any blocked channel senders
// should be unblocked, and the request should be removed from the requests queue.
defer func() {
defer svc.reqMu.Unlock()
delete(svc.requests, reqID)
select {
case modReq := <-ch:
if modReq == nil {
return nil, ErrRequestAborted
return modReq, nil
case <-ctx.Done():
return nil, ctx.Err()
// ModifyRequest sends a modified HTTP request to the related channel, or returns ErrRequestDone when the request was
// cancelled. It's safe for concurrent use.
func (svc *Service) ModifyRequest(reqID ulid.ULID, modReq *http.Request, modifyResponse *bool) error {
req, ok := svc.requests[reqID]
if !ok {
return ErrRequestNotFound
*modReq = *modReq.WithContext(req.req.Context())
if modifyResponse != nil {
*modReq = *modReq.WithContext(WithInterceptResponse(modReq.Context(), *modifyResponse))
select {
case <-req.done:
return ErrRequestDone
case req.ch <- modReq:
return nil
// CancelRequest ensures an intercepted request is dropped.
func (svc *Service) CancelRequest(reqID ulid.ULID) error {
return svc.ModifyRequest(reqID, nil, nil)
func (svc *Service) ClearRequests() {
defer svc.reqMu.Unlock()
for _, req := range svc.requests {
select {
case <-req.done:
case req.ch <- nil:
func (svc *Service) ClearResponses() {
defer svc.resMu.Unlock()
for _, res := range svc.responses {
select {
case <-res.done:
case res.ch <- nil:
// Items returns a list of pending items (requests and responses). It's safe for concurrent use.
func (svc *Service) Items() []Item {
defer svc.reqMu.RUnlock()
defer svc.resMu.RUnlock()
reqIDs := make([]ulid.ULID, 0, len(svc.requests)+len(svc.responses))
for id := range svc.requests {
reqIDs = append(reqIDs, id)
for id := range svc.responses {
reqIDs = append(reqIDs, id)
items := make([]Item, len(reqIDs))
for i, id := range reqIDs {
item := Item{}
if req, ok := svc.requests[id]; ok {
item.Request = req.req
if res, ok := svc.responses[id]; ok {
item.Response = res.res
items[i] = item
return items
func (svc *Service) UpdateSettings(settings Settings) {
// When updating from requests `enabled` -> `disabled`, clear any pending reqs.
if svc.requestsEnabled && !settings.RequestsEnabled {
// When updating from responses `enabled` -> `disabled`, clear any pending responses.
if svc.responsesEnabled && !settings.ResponsesEnabled {
svc.requestsEnabled = settings.RequestsEnabled
svc.responsesEnabled = settings.ResponsesEnabled
svc.reqFilter = settings.RequestFilter
svc.resFilter = settings.ResponseFilter
// ItemByID returns an intercepted item (request and possible response) by ID. It's safe for concurrent use.
func (svc *Service) ItemByID(id ulid.ULID) (Item, error) {
defer svc.reqMu.RUnlock()
defer svc.resMu.RUnlock()
item := Item{}
found := false
if req, ok := svc.requests[id]; ok {
item.Request = req.req
found = true
if res, ok := svc.responses[id]; ok {
item.Response = res.res
found = true
if !found {
return Item{}, ErrRequestNotFound
return item, nil
func (ids RequestIDs) Len() int {
return len(ids)
func (ids RequestIDs) Less(i, j int) bool {
return ids[i].Compare(ids[j]) == -1
func (ids RequestIDs) Swap(i, j int) {
ids[i], ids[j] = ids[j], ids[i]
func WithInterceptResponse(ctx context.Context, value bool) context.Context {
return context.WithValue(ctx, interceptResponseKey, value)
func ShouldInterceptResponseFromContext(ctx context.Context) (bool, bool) {
shouldIntercept, ok := ctx.Value(interceptResponseKey).(bool)
return shouldIntercept, ok
// ResponseModifier is a proxy.ResponseModifyMiddleware for intercepting HTTP responses.
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
return func(res *http.Response) error {
// This is a blocking operation, that gets unblocked when either a modified response is returned or an error.
modifiedRes, err := svc.InterceptResponse(res.Request.Context(), res)
if err != nil {
return fmt.Errorf("failed to intercept response: %w", err)
*res = *modifiedRes
return next(res)
// InterceptResponse adds an HTTP response to an array of pending intercepted responses, alongside channels used for
// sending a cancellation signal and receiving a modified response. It's safe for concurrent use.
func (svc *Service) InterceptResponse(ctx context.Context, res *http.Response) (*http.Response, error) {
reqID, ok := proxy.RequestIDFromContext(ctx)
if !ok {
svc.logger.Errorw("Failed to intercept: context doesn't have an ID.")
return res, nil
shouldIntercept, ok := ShouldInterceptResponseFromContext(ctx)
if ok && !shouldIntercept {
// If the related request explicitly disabled response intercept, return the response as-is.
svc.logger.Debugw("Bypassed response interception: related request explicitly disabled response intercept.")
return res, nil
// If global response intercept is disabled and interception is *not* explicitly enabled for this response: bypass.
if !svc.responsesEnabled && !(ok && shouldIntercept) {
svc.logger.Debugw("Bypassed response interception: feature disabled.")
return res, nil
if svc.resFilter != nil {
match, err := MatchResponseFilter(res, svc.resFilter)
if err != nil {
return nil, fmt.Errorf("intercept: failed to match response rules for response (id: %v): %w",
reqID.String(), err,
if !match {
svc.logger.Debugw("Bypassed response interception: response rules don't match.")
return res, nil
ch := make(chan *http.Response)
done := make(chan struct{})
svc.responses[reqID] = Response{
res: res,
ch: ch,
done: done,
// Whatever happens next (modified response returned, or a context cancelled error), any blocked channel senders
// should be unblocked, and the response should be removed from the responses queue.
defer func() {
defer svc.resMu.Unlock()
delete(svc.responses, reqID)
select {
case modRes := <-ch:
if modRes == nil {
return nil, ErrRequestAborted
return modRes, nil
case <-ctx.Done():
return nil, ctx.Err()
// ModifyResponse sends a modified HTTP response to the related channel, or returns ErrRequestDone when the related
// request was cancelled. It's safe for concurrent use.
func (svc *Service) ModifyResponse(reqID ulid.ULID, modRes *http.Response) error {
res, ok := svc.responses[reqID]
if !ok {
return ErrRequestNotFound
if modRes != nil {
modRes.Request = res.res.Request
select {
case <-res.done:
return ErrRequestDone
case res.ch <- modRes:
return nil
// CancelResponse ensures an intercepted response is dropped.
func (svc *Service) CancelResponse(reqID ulid.ULID) error {
return svc.ModifyResponse(reqID, nil)