diff --git a/pkg/handlers/archive.go b/pkg/handlers/archive.go index e8f571698..6e7ee4efd 100644 --- a/pkg/handlers/archive.go +++ b/pkg/handlers/archive.go @@ -91,6 +91,10 @@ func (a *Archive) FromFile(originalCtx logContext.Context, data io.Reader) chan // openArchive takes a reader and extracts the contents up to the maximum depth. func (a *Archive) openArchive(ctx logContext.Context, depth int, reader io.Reader, archiveChan chan []byte) error { + if common.IsDone(ctx) { + return ctx.Err() + } + if depth >= maxDepth { return fmt.Errorf(errMaxArchiveDepthReached) } @@ -183,6 +187,11 @@ func (a *Archive) extractorHandler(archiveChan chan []byte) func(context.Context return func(ctx context.Context, f archiver.File) error { lCtx := logContext.AddLogger(ctx) lCtx.Logger().V(5).Info("Handling extracted file.", "filename", f.Name()) + + if common.IsDone(ctx) { + return ctx.Err() + } + depth := 0 if ctxDepth, ok := ctx.Value(depthKey).(int); ok { depth = ctxDepth diff --git a/pkg/sources/chunker.go b/pkg/sources/chunker.go index 4fb0f321e..47e9e500d 100644 --- a/pkg/sources/chunker.go +++ b/pkg/sources/chunker.go @@ -144,15 +144,22 @@ func readInChunks(ctx context.Context, reader io.Reader, config *chunkReaderConf // If there is an error other than EOF, or if we have read some bytes, send the chunk. // io.ReadFull will only return io.EOF when n == 0. - if isErrAndNotEOF(err) { + switch { + case isErrAndNotEOF(err): ctx.Logger().Error(err, "error reading chunk") chunkRes.err = err - chunkResultChan <- chunkRes - } else if n > 0 { - chunkResultChan <- chunkRes + case n > 0: + chunkRes.err = nil + default: + return + } + + select { + case <-ctx.Done(): + return + case chunkResultChan <- chunkRes: } - // Return on any type of error. if err != nil { return } diff --git a/pkg/sources/chunker_test.go b/pkg/sources/chunker_test.go index aa350f7f0..a0017ba23 100644 --- a/pkg/sources/chunker_test.go +++ b/pkg/sources/chunker_test.go @@ -3,9 +3,12 @@ package sources import ( "bytes" "io" + "math/rand" + "runtime" "strings" "testing" "testing/iotest" + "time" "github.com/stretchr/testify/assert" diskbufferreader "github.com/trufflesecurity/disk-buffer-reader" @@ -217,3 +220,35 @@ func TestFlakyChunkReader(t *testing.T) { assert.NoError(t, chunk.Error()) assert.Equal(t, a+b, string(chunk.Bytes())) } + +func TestReadInChunksWithCancellation(t *testing.T) { + largeData := strings.Repeat("large test data ", 1024*1024) // Large data string. + + for i := 0; i < 10; i++ { + initialGoroutines := runtime.NumGoroutine() + + for j := 0; j < 5; j++ { // Call readInChunks multiple times + ctx, cancel := context.WithCancel(context.Background()) + + reader := strings.NewReader(largeData) + chunkReader := NewChunkReader() + + chunkChan := chunkReader(ctx, reader) + + if rand.Intn(2) == 0 { // Randomly decide to cancel the context + cancel() + } else { + for range chunkChan { + } + } + } + + // Allow for goroutine finalization. + time.Sleep(time.Millisecond * 100) + + // Check for goroutine leaks. + if runtime.NumGoroutine() > initialGoroutines { + t.Error("Potential goroutine leak detected") + } + } +}