mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 15:14:38 +00:00
168 lines
4.4 KiB
Go
168 lines
4.4 KiB
Go
package context
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"runtime/debug"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-logr/logr"
|
|
)
|
|
|
|
var (
|
|
// defaultLogger can be set via SetDefaultLogger.
|
|
defaultLogger logr.Logger = logr.Discard()
|
|
)
|
|
|
|
// Context wraps context.Context and includes an additional Logger() method.
|
|
type Context interface {
|
|
context.Context
|
|
Logger() logr.Logger
|
|
}
|
|
|
|
// CancelFunc aliases context.CancelFunc, so values of either type are
// interchangeable without conversion.
type CancelFunc = context.CancelFunc
|
|
|
|
// logCtx implements Context.
|
|
type logCtx struct {
|
|
// Embed context.Context to get all methods for free.
|
|
context.Context
|
|
log logr.Logger
|
|
err *error
|
|
errLock *sync.Mutex
|
|
}
|
|
|
|
// Logger returns a structured logger.
|
|
func (l logCtx) Logger() logr.Logger {
|
|
return l.log
|
|
}
|
|
|
|
func (l logCtx) Err() error {
|
|
if l.err != nil && *l.err != nil {
|
|
return *l.err
|
|
}
|
|
return l.Context.Err()
|
|
}
|
|
|
|
// Background returns context.Background with a default logger.
|
|
func Background() Context {
|
|
return logCtx{
|
|
log: defaultLogger,
|
|
Context: context.Background(),
|
|
}
|
|
}
|
|
|
|
// TODO returns context.TODO with a default logger.
|
|
func TODO() Context {
|
|
return logCtx{
|
|
log: defaultLogger,
|
|
Context: context.TODO(),
|
|
}
|
|
}
|
|
|
|
// WithCancel returns context.WithCancel with the log object propagated.
|
|
func WithCancel(parent Context) (Context, context.CancelFunc) {
|
|
ctx, cancel := context.WithCancel(parent)
|
|
lCtx := logCtx{
|
|
log: parent.Logger(),
|
|
Context: ctx,
|
|
}
|
|
return captureCancelCallstack(lCtx, cancel)
|
|
}
|
|
|
|
// WithDeadline returns context.WithDeadline with the log object propagated and
|
|
// the deadline added to the structured log values.
|
|
func WithDeadline(parent Context, d time.Time) (Context, context.CancelFunc) {
|
|
ctx, cancel := context.WithDeadline(parent, d)
|
|
lCtx := logCtx{
|
|
log: parent.Logger().WithValues("deadline", d),
|
|
Context: ctx,
|
|
}
|
|
return captureCancelCallstack(lCtx, cancel)
|
|
}
|
|
|
|
// WithTimeout returns context.WithTimeout with the log object propagated and
|
|
// the timeout added to the structured log values.
|
|
func WithTimeout(parent Context, timeout time.Duration) (Context, context.CancelFunc) {
|
|
ctx, cancel := context.WithTimeout(parent, timeout)
|
|
lCtx := logCtx{
|
|
log: parent.Logger().WithValues("timeout", timeout),
|
|
Context: ctx,
|
|
}
|
|
return captureCancelCallstack(lCtx, cancel)
|
|
}
|
|
|
|
// WithValue returns context.WithValue with the log object propagated and
|
|
// the value added to the structured log values (if the key is a string).
|
|
func WithValue(parent Context, key, val any) Context {
|
|
logger := parent.Logger()
|
|
if k, ok := key.(string); ok {
|
|
logger = logger.WithValues(k, val)
|
|
}
|
|
return logCtx{
|
|
log: logger,
|
|
Context: context.WithValue(parent, key, val),
|
|
}
|
|
}
|
|
|
|
// WithValues returns context.WithValue with the log object propagated and
|
|
// the values added to the structured log values (if the key is a string).
|
|
func WithValues(parent Context, keyAndVals ...any) Context {
|
|
ctx := parent
|
|
for i := 0; i < len(keyAndVals)-1; i += 2 {
|
|
ctx = WithValue(ctx, keyAndVals[i], keyAndVals[i+1])
|
|
}
|
|
return ctx
|
|
}
|
|
|
|
// WithLogger converts a context.Context into a Context by adding a logger.
|
|
func WithLogger(parent context.Context, logger logr.Logger) Context {
|
|
return logCtx{
|
|
log: logger,
|
|
Context: parent,
|
|
}
|
|
}
|
|
|
|
// AddLogger converts a context.Context into a Context. If the underlying type
|
|
// is already a Context, that will be returned, otherwise a default logger will
|
|
// be added.
|
|
func AddLogger(parent context.Context) Context {
|
|
if loggerCtx, ok := parent.(Context); ok {
|
|
return loggerCtx
|
|
}
|
|
return WithLogger(parent, defaultLogger)
|
|
}
|
|
|
|
// SetupDefaultLogger sets the package-level global default logger that will be
|
|
// used for Background and TODO contexts.
|
|
func SetDefaultLogger(l logr.Logger) {
|
|
defaultLogger = l
|
|
}
|
|
|
|
// captureCancelCallstack is a helper function to capture the callstack where
|
|
// the cancel function was first called.
|
|
func captureCancelCallstack(ctx logCtx, f context.CancelFunc) (Context, context.CancelFunc) {
|
|
if ctx.err == nil {
|
|
var err error
|
|
ctx.err = &err
|
|
ctx.errLock = &sync.Mutex{}
|
|
}
|
|
return ctx, func() {
|
|
ctx.errLock.Lock()
|
|
defer ctx.errLock.Unlock()
|
|
// We must check Err() before calling f() since f() sets the error.
|
|
// If there's already an error, do nothing special.
|
|
if ctx.Err() != nil {
|
|
f()
|
|
return
|
|
}
|
|
f()
|
|
// Set the error with the stacktrace if the err pointer is non-nil.
|
|
*ctx.err = fmt.Errorf(
|
|
"%w (canceled at %v\n%s)",
|
|
ctx.Err(), time.Now(), string(debug.Stack()),
|
|
)
|
|
}
|
|
}
|