chore: propagate log context to handlers (#2191)

Richard Gomez authored on 2023-12-10 13:30:11 -05:00 (committed by GitHub)
parent 6c5fc2f212
commit d1a2d9e832
4 changed files with 39 additions and 42 deletions
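
This change replaces the standard library's context.Context with the project's logger-carrying context type (imported as logContext) across the handler call chain, so handlers log through the context they already hold instead of re-wrapping it with logContext.AddLogger at each call site. A rough usage sketch, assuming the project's pkg/context import path and using only calls that appear in this diff:

    package main

    import (
        "time"

        logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
    )

    func main() {
        // Root context with an attached logger, as the updated tests build it.
        ctx := logContext.Background()

        // Derived contexts keep the logger; no re-wrapping is needed downstream.
        ctx, cancel := logContext.WithTimeout(ctx, 5*time.Second)
        defer cancel()

        // The logger travels with the context.
        ctx.Logger().Info("processing archive", "depth", 0)
    }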

File 1 of 4

@@ -66,10 +66,10 @@ func SetArchiveMaxTimeout(timeout time.Duration) {
 }
 
 // FromFile extracts the files from an archive.
-func (a *Archive) FromFile(originalCtx context.Context, data io.Reader) chan []byte {
+func (a *Archive) FromFile(originalCtx logContext.Context, data io.Reader) chan []byte {
     archiveChan := make(chan []byte, defaultBufferSize)
     go func() {
-        ctx, cancel := context.WithTimeout(originalCtx, maxTimeout)
+        ctx, cancel := logContext.WithTimeout(originalCtx, maxTimeout)
         logger := logContext.AddLogger(ctx).Logger()
         defer cancel()
         defer close(archiveChan)
@@ -92,7 +92,7 @@ type decompressorInfo struct {
 }
 
 // openArchive takes a reader and extracts the contents up to the maximum depth.
-func (a *Archive) openArchive(ctx context.Context, depth int, reader io.Reader, archiveChan chan []byte) error {
+func (a *Archive) openArchive(ctx logContext.Context, depth int, reader io.Reader, archiveChan chan []byte) error {
     if depth >= maxDepth {
         return fmt.Errorf(errMaxArchiveDepthReached)
     }
@@ -112,19 +112,18 @@ func (a *Archive) openArchive(ctx context.Context, depth int, reader io.Reader,
         return a.handleDecompressor(ctx, info)
     case archiver.Extractor:
-        return archive.Extract(context.WithValue(ctx, depthKey, depth+1), reader, nil, a.extractorHandler(archiveChan))
+        return archive.Extract(logContext.WithValue(ctx, depthKey, depth+1), reader, nil, a.extractorHandler(archiveChan))
     default:
         return fmt.Errorf("unknown archive type: %s", format.Name())
     }
 }
 
-func (a *Archive) handleNonArchiveContent(ctx context.Context, reader io.Reader, archiveChan chan []byte) error {
-    aCtx := logContext.AddLogger(ctx)
+func (a *Archive) handleNonArchiveContent(ctx logContext.Context, reader io.Reader, archiveChan chan []byte) error {
     chunkReader := sources.NewChunkReader()
-    chunkResChan := chunkReader(aCtx, reader)
+    chunkResChan := chunkReader(ctx, reader)
     for data := range chunkResChan {
         if err := data.Error(); err != nil {
-            aCtx.Logger().Error(err, "error reading chunk")
+            ctx.Logger().Error(err, "error reading chunk")
             continue
         }
         if err := common.CancellableWrite(ctx, archiveChan, data.Bytes()); err != nil {
@@ -134,7 +133,7 @@ func (a *Archive) handleNonArchiveContent(ctx context.Context, reader io.Reader,
     return nil
 }
 
-func (a *Archive) handleDecompressor(ctx context.Context, info decompressorInfo) error {
+func (a *Archive) handleDecompressor(ctx logContext.Context, info decompressorInfo) error {
     compReader, err := info.archiver.OpenReader(info.reader)
     if err != nil {
         return err
@@ -147,7 +146,7 @@ func (a *Archive) handleDecompressor(ctx context.Context, info decompressorInfo)
 }
 
 // IsFiletype returns true if the provided reader is an archive.
-func (a *Archive) IsFiletype(_ context.Context, reader io.Reader) (io.Reader, bool) {
+func (a *Archive) IsFiletype(_ logContext.Context, reader io.Reader) (io.Reader, bool) {
     format, readerB, err := archiver.Identify("", reader)
     if err != nil {
         return readerB, false
@@ -165,8 +164,8 @@ func (a *Archive) IsFiletype(_ context.Context, reader io.Reader) (io.Reader, bo
 // extractorHandler is applied to each file in an archiver.Extractor file.
 func (a *Archive) extractorHandler(archiveChan chan []byte) func(context.Context, archiver.File) error {
     return func(ctx context.Context, f archiver.File) error {
-        logger := logContext.AddLogger(ctx).Logger()
-        logger.V(5).Info("Handling extracted file.", "filename", f.Name())
+        lCtx := logContext.AddLogger(ctx)
+        lCtx.Logger().V(5).Info("Handling extracted file.", "filename", f.Name())
         depth := 0
         if ctxDepth, ok := ctx.Value(depthKey).(int); ok {
             depth = ctxDepth
@@ -177,16 +176,16 @@ func (a *Archive) extractorHandler(archiveChan chan []byte) func(context.Context
             return err
         }
         if common.SkipFile(f.Name()) {
-            logger.V(5).Info("skipping file", "filename", f.Name())
+            lCtx.Logger().V(5).Info("skipping file", "filename", f.Name())
             return nil
         }
 
-        fileBytes, err := a.ReadToMax(ctx, fReader)
+        fileBytes, err := a.ReadToMax(lCtx, fReader)
         if err != nil {
             return err
         }
 
-        err = a.openArchive(ctx, depth, bytes.NewReader(fileBytes), archiveChan)
+        err = a.openArchive(lCtx, depth, bytes.NewReader(fileBytes), archiveChan)
         if err != nil {
             return err
         }
@@ -195,12 +194,11 @@ func (a *Archive) extractorHandler(archiveChan chan []byte) func(context.Context
 }
 
 // ReadToMax reads up to the max size.
-func (a *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte, err error) {
+func (a *Archive) ReadToMax(ctx logContext.Context, reader io.Reader) (data []byte, err error) {
     // Archiver v4 is in alpha and using an experimental version of
     // rardecode. There is a bug somewhere with rar decoder format 29
     // that can lead to a panic. An issue is open in rardecode repo
     // https://github.com/nwaples/rardecode/issues/30.
-    logger := logContext.AddLogger(ctx).Logger()
     defer func() {
         if r := recover(); r != nil {
             // Return an error from ReadToMax.
@@ -209,7 +207,7 @@ func (a *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte,
             } else {
                 err = fmt.Errorf("panic occurred: %v", r)
             }
-            logger.Error(err, "Panic occurred when reading archive")
+            ctx.Logger().Error(err, "Panic occurred when reading archive")
         }
     }()
@@ -231,7 +229,7 @@ func (a *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte,
     }
 
     if fileContent.Len() == maxSize {
-        logger.V(2).Info("Max archive size reached.")
+        ctx.Logger().V(2).Info("Max archive size reached.")
     }
 
     return fileContent.Bytes(), nil
@@ -483,7 +481,7 @@ type tempEnv struct {
 // createTempEnv creates a temporary file and a temporary directory for extracting archives.
 // The caller is responsible for removing these temporary resources
 // (both the file and directory) when they are no longer needed.
-func (a *Archive) createTempEnv(ctx context.Context, file io.Reader) (tempEnv, error) {
+func (a *Archive) createTempEnv(ctx logContext.Context, file io.Reader) (tempEnv, error) {
     tempFile, err := os.CreateTemp("", "tmp")
    if err != nil {
        return tempEnv{}, fmt.Errorf("unable to create temporary file: %w", err)
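
Note on the ReadToMax hunks above: the function now logs through ctx.Logger() directly, and its panic guard is otherwise unchanged. That guard uses a named return value so the deferred recover can turn a panic from the experimental rar decoder into an ordinary error. A minimal self-contained sketch of the same pattern (illustrative names, not code from this repository):

    package main

    import (
        "errors"
        "fmt"
    )

    // readGuarded converts a panic in the read path into a returned error,
    // mirroring the recover logic in ReadToMax. Illustrative only.
    func readGuarded(read func() []byte) (data []byte, err error) {
        defer func() {
            if r := recover(); r != nil {
                // Preserve real errors; wrap anything else.
                if e, ok := r.(error); ok {
                    err = e
                } else {
                    err = fmt.Errorf("panic occurred: %v", r)
                }
            }
        }()
        return read(), nil
    }

    func main() {
        _, err := readGuarded(func() []byte { panic(errors.New("rar decoder bug")) })
        fmt.Println(err) // prints: rar decoder bug
    }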

File 2 of 4

@@ -86,7 +86,7 @@ func TestArchiveHandler(t *testing.T) {
             if err != nil {
                 t.Errorf("error creating reusable reader: %s", err)
             }
-            archiveChan := archive.FromFile(context.Background(), newReader)
+            archiveChan := archive.FromFile(logContext.Background(), newReader)
 
             count := 0
             re := regexp.MustCompile(testCase.matchString)
@@ -110,7 +110,7 @@ func TestHandleFile(t *testing.T) {
     reporter := sources.ChanReporter{Ch: make(chan *sources.Chunk, 2)}
 
     // Context cancels the operation.
-    canceledCtx, cancel := context.WithCancel(context.Background())
+    canceledCtx, cancel := logContext.WithCancel(logContext.Background())
     cancel()
     assert.False(t, HandleFile(canceledCtx, strings.NewReader("file"), &sources.Chunk{}, reporter))
@@ -125,7 +125,7 @@ func TestHandleFile(t *testing.T) {
     assert.NoError(t, err)
     assert.Equal(t, 0, len(reporter.Ch))
-    assert.True(t, HandleFile(context.Background(), reader, &sources.Chunk{}, reporter))
+    assert.True(t, HandleFile(logContext.Background(), reader, &sources.Chunk{}, reporter))
     assert.Equal(t, 1, len(reporter.Ch))
 }
@@ -157,7 +157,7 @@ func TestReadToMax(t *testing.T) {
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
             reader := bytes.NewReader(tt.input)
-            output, err := a.ReadToMax(context.Background(), reader)
+            output, err := a.ReadToMax(logContext.Background(), reader)
             assert.Nil(t, err)
             assert.Equal(t, tt.expected, output)
@@ -173,7 +173,7 @@ func BenchmarkReadToMax(b *testing.B) {
     b.ResetTimer()
     for i := 0; i < b.N; i++ {
         b.StartTimer()
-        _, _ = a.ReadToMax(context.Background(), reader)
+        _, _ = a.ReadToMax(logContext.Background(), reader)
         b.StopTimer()
 
         _, _ = reader.Seek(0, 0) // Reset the reader position.
@@ -204,7 +204,7 @@ func TestExtractTarContent(t *testing.T) {
     assert.Nil(t, err)
     defer file.Close()
 
-    ctx := context.Background()
+    ctx := logContext.Background()
     chunkCh := make(chan *sources.Chunk)
     go func() {

File 3 of 4

@@ -1,7 +1,6 @@
 package handlers
 
 import (
-    "context"
     "io"
 
     diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
@@ -26,8 +25,8 @@ type SpecializedHandler interface {
 }
 
 type Handler interface {
-    FromFile(context.Context, io.Reader) chan []byte
-    IsFiletype(context.Context, io.Reader) (io.Reader, bool)
+    FromFile(logContext.Context, io.Reader) chan []byte
+    IsFiletype(logContext.Context, io.Reader) (io.Reader, bool)
     New()
 }
@@ -37,8 +36,7 @@ type Handler interface {
 // packages them in the provided chunk skeleton, and reports them to the chunk reporter.
 // The function returns true if processing was successful and false otherwise.
 // Context is used for cancellation, and the caller is responsible for canceling it if needed.
-func HandleFile(ctx context.Context, file io.Reader, chunkSkel *sources.Chunk, reporter sources.ChunkReporter) bool {
-    aCtx := logContext.AddLogger(ctx)
+func HandleFile(ctx logContext.Context, file io.Reader, chunkSkel *sources.Chunk, reporter sources.ChunkReporter) bool {
     for _, h := range DefaultHandlers() {
         h.New()
@@ -47,11 +45,11 @@ func HandleFile(ctx context.Context, file io.Reader, chunkSkel *sources.Chunk, r
         // an io.MultiReader, which is used by the SpecializedHandler.
         reReader, err := diskbufferreader.New(file)
         if err != nil {
-            aCtx.Logger().Error(err, "error creating reusable reader")
+            ctx.Logger().Error(err, "error creating reusable reader")
             return false
         }
 
-        if success := processHandler(aCtx, h, reReader, chunkSkel, reporter); success {
+        if success := processHandler(ctx, h, reReader, chunkSkel, reporter); success {
             return true
         }
     }
@@ -85,7 +83,7 @@ func processHandler(ctx logContext.Context, h Handler, reReader *diskbufferreade
     return handleChunks(ctx, h.FromFile(ctx, reReader), chunkSkel, reporter)
 }
 
-func handleChunks(ctx context.Context, handlerChan chan []byte, chunkSkel *sources.Chunk, reporter sources.ChunkReporter) bool {
+func handleChunks(ctx logContext.Context, handlerChan chan []byte, chunkSkel *sources.Chunk, reporter sources.ChunkReporter) bool {
     for {
         select {
         case data, open := <-handlerChan:
@@ -94,7 +92,7 @@ func handleChunks(ctx context.Context, handlerChan chan []byte, chunkSkel *sourc
             }
             chunk := *chunkSkel
             chunk.Data = data
-            if err := reporter.ChunkOk(logContext.AddLogger(ctx), chunk); err != nil {
+            if err := reporter.ChunkOk(ctx, chunk); err != nil {
                 return false
             }
         case <-ctx.Done():
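
The handleChunks loop above changes only its signature; the select pattern it uses, draining a producer channel until it closes or the context is canceled, is standard. A minimal sketch with the stdlib context (illustrative, not the repository's code):

    package main

    import (
        "context"
        "fmt"
    )

    // forward drains in until it closes, bailing out early on cancellation,
    // the same shape as the handleChunks loop.
    func forward(ctx context.Context, in <-chan []byte, out chan<- []byte) bool {
        for {
            select {
            case data, open := <-in:
                if !open {
                    return true // producer finished cleanly
                }
                out <- data
            case <-ctx.Done():
                return false // canceled before the producer finished
            }
        }
    }

    func main() {
        in := make(chan []byte, 1)
        out := make(chan []byte, 1)
        in <- []byte("chunk")
        close(in)
        fmt.Println(forward(context.Background(), in, out)) // prints: true
    }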

File 4 of 4

@@ -994,10 +994,11 @@ func getSafeRemoteURL(repo *git.Repository, preferred string) string {
 }
 
 func handleBinary(ctx context.Context, gitDir string, reporter sources.ChunkReporter, chunkSkel *sources.Chunk, commitHash plumbing.Hash, path string) error {
-    ctx.Logger().V(5).Info("handling binary file", "path", path)
+    fileCtx := context.WithValues(ctx, "commit", commitHash.String(), "path", path)
+    fileCtx.Logger().V(5).Info("handling binary file")
+
     if common.SkipFile(path) {
-        ctx.Logger().V(5).Info("skipping binary file", "path", path)
+        fileCtx.Logger().V(5).Info("skipping binary file")
         return nil
     }
@@ -1043,7 +1044,7 @@ func handleBinary(ctx context.Context, gitDir string, reporter sources.ChunkRepo
     }
 
     if fileContent.Len() == maxSize {
-        ctx.Logger().V(2).Info("Max archive size reached.", "path", path)
+        fileCtx.Logger().V(2).Info("Max archive size reached.")
     }
 
     reader, err := diskbufferreader.New(&fileContent)
@@ -1052,25 +1053,25 @@ func handleBinary(ctx context.Context, gitDir string, reporter sources.ChunkRepo
     }
     defer reader.Close()
 
-    if handlers.HandleFile(ctx, reader, chunkSkel, reporter) {
+    if handlers.HandleFile(fileCtx, reader, chunkSkel, reporter) {
         return nil
     }
 
-    ctx.Logger().V(1).Info("binary file not handled, chunking raw", "path", path)
+    fileCtx.Logger().V(1).Info("binary file not handled, chunking raw")
     if err := reader.Reset(); err != nil {
         return err
     }
     reader.Stop()
 
     chunkReader := sources.NewChunkReader()
-    chunkResChan := chunkReader(ctx, reader)
+    chunkResChan := chunkReader(fileCtx, reader)
     for data := range chunkResChan {
         chunk := *chunkSkel
         chunk.Data = data.Bytes()
         if err := data.Error(); err != nil {
             return err
         }
-        if err := reporter.ChunkOk(ctx, chunk); err != nil {
+        if err := reporter.ChunkOk(fileCtx, chunk); err != nil {
             return err
         }
     }
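
In handleBinary, context.WithValues binds the commit hash and path to the context once (here `context` is evidently the project's logger-aware context package, which is why ctx.Logger() works on it), so the individual log calls no longer repeat the "path", path arguments. The effect is the same as binding fields to a logr logger up front; a sketch using go-logr directly, on the assumption that the context logger wraps logr (field values here are illustrative):

    package main

    import (
        "log"
        "os"

        "github.com/go-logr/stdr"
    )

    func main() {
        // Raise verbosity so the V(1) line below is emitted.
        stdr.SetVerbosity(5)
        logger := stdr.New(log.New(os.Stderr, "", log.LstdFlags))

        // Bind fields once, like context.WithValues(ctx, "commit", ..., "path", ...).
        fileLogger := logger.WithValues("commit", "d1a2d9e", "path", "testdata/blob.bin")

        // Every subsequent call carries commit and path automatically.
        fileLogger.V(1).Info("binary file not handled, chunking raw")
    }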