mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
Remove capturing the cancel callstack in the context package (#1595)
* Fix race condition in context package * Remove capturing the cancel callstack
This commit is contained in:
parent
0ad46381d9
commit
160fd830dd
2 changed files with 7 additions and 86 deletions
|
@ -2,10 +2,7 @@ package context
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-logr/logr"
|
||||
|
@ -37,9 +34,7 @@ type CancelFunc = context.CancelFunc
|
|||
type logCtx struct {
|
||||
// Embed context.Context to get all methods for free.
|
||||
context.Context
|
||||
log logr.Logger
|
||||
err *error
|
||||
errLock *sync.Mutex
|
||||
log logr.Logger
|
||||
}
|
||||
|
||||
// Logger returns a structured logger.
|
||||
|
@ -47,13 +42,6 @@ func (l logCtx) Logger() logr.Logger {
|
|||
return l.log
|
||||
}
|
||||
|
||||
func (l logCtx) Err() error {
|
||||
if l.err != nil && *l.err != nil {
|
||||
return *l.err
|
||||
}
|
||||
return l.Context.Err()
|
||||
}
|
||||
|
||||
// Background returns context.Background with a default logger.
|
||||
func Background() Context {
|
||||
return logCtx{
|
||||
|
@ -77,7 +65,7 @@ func WithCancel(parent Context) (Context, context.CancelFunc) {
|
|||
log: parent.Logger(),
|
||||
Context: ctx,
|
||||
}
|
||||
return captureCancelCallstack(lCtx, cancel)
|
||||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// WithDeadline returns context.WithDeadline with the log object propagated and
|
||||
|
@ -88,7 +76,7 @@ func WithDeadline(parent Context, d time.Time) (Context, context.CancelFunc) {
|
|||
log: parent.Logger().WithValues("deadline", d),
|
||||
Context: ctx,
|
||||
}
|
||||
return captureCancelCallstack(lCtx, cancel)
|
||||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// WithTimeout returns context.WithTimeout with the log object propagated and
|
||||
|
@ -99,7 +87,7 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, context.Cancel
|
|||
log: parent.Logger().WithValues("timeout", timeout),
|
||||
Context: ctx,
|
||||
}
|
||||
return captureCancelCallstack(lCtx, cancel)
|
||||
return lCtx, cancel
|
||||
}
|
||||
|
||||
// WithValue returns context.WithValue with the log object propagated and
|
||||
|
@ -150,29 +138,3 @@ func AddLogger(parent context.Context) Context {
|
|||
func SetDefaultLogger(l logr.Logger) {
|
||||
defaultLogger = l
|
||||
}
|
||||
|
||||
// captureCancelCallstack is a helper function to capture the callstack where
|
||||
// the cancel function was first called.
|
||||
func captureCancelCallstack(ctx logCtx, f context.CancelFunc) (Context, context.CancelFunc) {
|
||||
if ctx.err == nil {
|
||||
var err error
|
||||
ctx.err = &err
|
||||
ctx.errLock = &sync.Mutex{}
|
||||
}
|
||||
return ctx, func() {
|
||||
ctx.errLock.Lock()
|
||||
defer ctx.errLock.Unlock()
|
||||
// We must check Err() before calling f() since f() sets the error.
|
||||
// If there's already an error, do nothing special.
|
||||
if ctx.Err() != nil {
|
||||
f()
|
||||
return
|
||||
}
|
||||
f()
|
||||
// Set the error with the stacktrace if the err pointer is non-nil.
|
||||
*ctx.err = fmt.Errorf(
|
||||
"%w (canceled at %v\n%s)",
|
||||
ctx.Err(), time.Now(), string(debug.Stack()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -170,51 +170,10 @@ func TestDefaultLogger(t *testing.T) {
|
|||
ctx.Logger().Info("this shouldn't panic")
|
||||
}
|
||||
|
||||
func TestErrCallstack(t *testing.T) {
|
||||
c, cancel := WithCancel(Background())
|
||||
ctx := c.(logCtx)
|
||||
cancel()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
assert.Contains(t, ctx.Err().Error(), "TestErrCallstack")
|
||||
case <-time.After(1 * time.Second):
|
||||
assert.Fail(t, "context should be done")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrCallstackTimeout(t *testing.T) {
|
||||
ctx, cancel := WithTimeout(Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Deadline exceeded errors will not have a callstack from the cancel
|
||||
// function.
|
||||
assert.NotContains(t, ctx.Err().Error(), "TestErrCallstackTimeout")
|
||||
case <-time.After(1 * time.Second):
|
||||
assert.Fail(t, "context should be done")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrCallstackTimeoutCancel(t *testing.T) {
|
||||
ctx, cancel := WithTimeout(Background(), 10*time.Millisecond)
|
||||
|
||||
var err error
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
case <-time.After(1 * time.Second):
|
||||
assert.Fail(t, "context should be done")
|
||||
}
|
||||
|
||||
// Calling cancel after deadline exceeded should not overwrite the original
|
||||
// error.
|
||||
cancel()
|
||||
assert.Equal(t, err, ctx.Err())
|
||||
}
|
||||
|
||||
func TestRace(t *testing.T) {
|
||||
_, cancel := WithCancel(Background())
|
||||
ctx, cancel := WithCancel(Background())
|
||||
go cancel()
|
||||
go func() { _ = ctx.Err() }()
|
||||
cancel()
|
||||
_ = ctx.Err()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue