mirror of
https://github.com/dstotijn/hetty
synced 2024-11-10 06:04:19 +00:00
Add tests for intercept.ModifyResponse
This commit is contained in:
parent
1a45ea36a4
commit
a9f6701145
1 changed files with 139 additions and 0 deletions
|
@ -129,3 +129,142 @@ func TestRequestModifier(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponseModifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("modify response that's not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: false,
|
||||
ResponsesEnabled: true,
|
||||
})
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
|
||||
err := svc.ModifyResponse(reqID, nil)
|
||||
if !errors.Is(err, intercept.ErrRequestNotFound) {
|
||||
t.Fatalf("expected `intercept.ErrRequestNotFound`, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modify response of request that's done", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: false,
|
||||
ResponsesEnabled: true,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
*req = *req.WithContext(ctx)
|
||||
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
res := &http.Response{
|
||||
Request: req,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
res.Header.Add("X-Foo", "foo")
|
||||
|
||||
var modErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
next := func(res *http.Response) error { return nil }
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
modErr = svc.ResponseModifier(next)(res)
|
||||
}()
|
||||
|
||||
// Wait shortly, to allow the res modifier goroutine to add `res` to the
|
||||
// array of intercepted responses.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
modRes := *res
|
||||
modRes.Header = make(http.Header)
|
||||
modRes.Header.Set("X-Foo", "bar")
|
||||
|
||||
err := svc.ModifyResponse(reqID, &modRes)
|
||||
if !errors.Is(err, intercept.ErrRequestDone) {
|
||||
t.Fatalf("expected `intercept.ErrRequestDone`, got: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if !errors.Is(modErr, context.Canceled) {
|
||||
t.Fatalf("expected `context.Canceled`, got: %v", modErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modify intercepted response", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com/foo", nil)
|
||||
req.Header.Set("X-Foo", "foo")
|
||||
|
||||
reqID := ulid.MustNew(ulid.Timestamp(time.Now()), ulidEntropy)
|
||||
*req = *req.WithContext(proxy.WithRequestID(req.Context(), reqID))
|
||||
|
||||
res := &http.Response{
|
||||
Request: req,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
res.Header.Add("X-Foo", "foo")
|
||||
|
||||
modRes := *res
|
||||
modRes.Header = make(http.Header)
|
||||
modRes.Header.Set("X-Foo", "bar")
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
svc := intercept.NewService(intercept.Config{
|
||||
Logger: logger.Sugar(),
|
||||
RequestsEnabled: false,
|
||||
ResponsesEnabled: true,
|
||||
})
|
||||
|
||||
var gotHeader string
|
||||
|
||||
var next proxy.ResponseModifyFunc = func(res *http.Response) error {
|
||||
gotHeader = res.Header.Get("X-Foo")
|
||||
return nil
|
||||
}
|
||||
|
||||
var modErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
modErr = svc.ResponseModifier(next)(res)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Wait shortly, to allow the res modifier goroutine to add `req` to the
|
||||
// array of intercepted reqs.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := svc.ModifyResponse(reqID, &modRes)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if modErr != nil {
|
||||
t.Fatalf("unexpected error: %v", modErr)
|
||||
}
|
||||
|
||||
if exp := "bar"; exp != gotHeader {
|
||||
t.Fatalf("incorrect modified request header value (expected: %v, got: %v)", exp, gotHeader)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue