Remove capturing the cancel callstack in the context package (#1595)

* Fix race condition in context package

* Remove capturing the cancel callstack
This commit is contained in:
Miccah 2023-08-01 21:34:00 -05:00 committed by GitHub
parent 0ad46381d9
commit 160fd830dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 86 deletions

View file

@ -2,10 +2,7 @@ package context
import (
"context"
"fmt"
"os"
"runtime/debug"
"sync"
"time"
"github.com/go-logr/logr"
@ -37,9 +34,7 @@ type CancelFunc = context.CancelFunc
type logCtx struct {
// Embed context.Context to get all methods for free.
context.Context
log logr.Logger
err *error
errLock *sync.Mutex
log logr.Logger
}
// Logger returns a structured logger.
@ -47,13 +42,6 @@ func (l logCtx) Logger() logr.Logger {
return l.log
}
func (l logCtx) Err() error {
if l.err != nil && *l.err != nil {
return *l.err
}
return l.Context.Err()
}
// Background returns context.Background with a default logger.
func Background() Context {
return logCtx{
@ -77,7 +65,7 @@ func WithCancel(parent Context) (Context, context.CancelFunc) {
log: parent.Logger(),
Context: ctx,
}
return captureCancelCallstack(lCtx, cancel)
return lCtx, cancel
}
// WithDeadline returns context.WithDeadline with the log object propagated and
@ -88,7 +76,7 @@ func WithDeadline(parent Context, d time.Time) (Context, context.CancelFunc) {
log: parent.Logger().WithValues("deadline", d),
Context: ctx,
}
return captureCancelCallstack(lCtx, cancel)
return lCtx, cancel
}
// WithTimeout returns context.WithTimeout with the log object propagated and
@ -99,7 +87,7 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, context.Cancel
log: parent.Logger().WithValues("timeout", timeout),
Context: ctx,
}
return captureCancelCallstack(lCtx, cancel)
return lCtx, cancel
}
// WithValue returns context.WithValue with the log object propagated and
@ -150,29 +138,3 @@ func AddLogger(parent context.Context) Context {
func SetDefaultLogger(l logr.Logger) {
defaultLogger = l
}
// captureCancelCallstack is a helper function to capture the callstack where
// the cancel function was first called.
func captureCancelCallstack(ctx logCtx, f context.CancelFunc) (Context, context.CancelFunc) {
if ctx.err == nil {
var err error
ctx.err = &err
ctx.errLock = &sync.Mutex{}
}
return ctx, func() {
ctx.errLock.Lock()
defer ctx.errLock.Unlock()
// We must check Err() before calling f() since f() sets the error.
// If there's already an error, do nothing special.
if ctx.Err() != nil {
f()
return
}
f()
// Set the error with the stacktrace if the err pointer is non-nil.
*ctx.err = fmt.Errorf(
"%w (canceled at %v\n%s)",
ctx.Err(), time.Now(), string(debug.Stack()),
)
}
}

View file

@ -170,51 +170,10 @@ func TestDefaultLogger(t *testing.T) {
ctx.Logger().Info("this shouldn't panic")
}
func TestErrCallstack(t *testing.T) {
c, cancel := WithCancel(Background())
ctx := c.(logCtx)
cancel()
select {
case <-ctx.Done():
assert.Contains(t, ctx.Err().Error(), "TestErrCallstack")
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}
}
func TestErrCallstackTimeout(t *testing.T) {
ctx, cancel := WithTimeout(Background(), 10*time.Millisecond)
defer cancel()
select {
case <-ctx.Done():
// Deadline exceeded errors will not have a callstack from the cancel
// function.
assert.NotContains(t, ctx.Err().Error(), "TestErrCallstackTimeout")
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}
}
func TestErrCallstackTimeoutCancel(t *testing.T) {
ctx, cancel := WithTimeout(Background(), 10*time.Millisecond)
var err error
select {
case <-ctx.Done():
err = ctx.Err()
case <-time.After(1 * time.Second):
assert.Fail(t, "context should be done")
}
// Calling cancel after deadline exceeded should not overwrite the original
// error.
cancel()
assert.Equal(t, err, ctx.Err())
}
func TestRace(t *testing.T) {
_, cancel := WithCancel(Background())
ctx, cancel := WithCancel(Background())
go cancel()
go func() { _ = ctx.Err() }()
cancel()
_ = ctx.Err()
}