[feat] - Streamlined File Handling with BufferedReaderSeeker (#3041)

* Streaming file handling.

* cleanup

* update tests

* lint

* defer close on input io.ReadClosers

* fix seek bug

* fix hanging

* clarify errors

* update

* address comments

* revert

* update

* address

* add check to prevent seek without buffering

* revert

* revert

* update comment to make buffer usage more clear
ahrav 2024-07-17 13:52:18 -07:00 committed by GitHub
parent 77bef38793
commit f865482025
16 changed files with 709 additions and 248 deletions


@@ -20,7 +20,6 @@ func TestHandleARFile(t *testing.T) {
rdr, err := newFileReader(file)
assert.NoError(t, err)
defer rdr.Close()
handler := newARHandler()
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)


@@ -87,15 +87,14 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
return ErrMaxDepthReached
}
arReader := reader.BufferedFileReader
if reader.format == nil && depth > 0 {
return h.handleNonArchiveContent(ctx, arReader, archiveChan)
return h.handleNonArchiveContent(ctx, reader, archiveChan)
}
switch archive := reader.format.(type) {
case archiver.Decompressor:
// Decompress the archive and feed the decompressed data back into the archive handler to extract any nested archives.
compReader, err := archive.OpenReader(arReader)
compReader, err := archive.OpenReader(reader)
if err != nil {
return fmt.Errorf("error opening decompressor with format: %s %w", reader.format.Name(), err)
}
@@ -109,11 +108,10 @@
}
return fmt.Errorf("error creating custom reader: %w", err)
}
defer rdr.Close()
return h.openArchive(ctx, depth+1, rdr, archiveChan)
case archiver.Extractor:
err := archive.Extract(logContext.WithValue(ctx, depthKey, depth+1), arReader, nil, h.extractorHandler(archiveChan))
err := archive.Extract(logContext.WithValue(ctx, depthKey, depth+1), reader, nil, h.extractorHandler(archiveChan))
if err != nil {
return fmt.Errorf("error extracting archive with format: %s: %w", reader.format.Name(), err)
}
@@ -193,7 +191,6 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context.
}
return fmt.Errorf("error creating custom reader: %w", err)
}
defer rdr.Close()
h.metrics.incFilesProcessed()
h.metrics.observeFileSize(fileSize)


@@ -119,7 +119,6 @@ func TestOpenInvalidArchive(t *testing.T) {
rdr, err := newFileReader(io.NopCloser(reader))
assert.NoError(t, err)
defer rdr.Close()
archiveChan := make(chan []byte)


@@ -20,7 +20,6 @@ func TestHandleNonArchiveFile(t *testing.T) {
rdr, err := newFileReader(file)
assert.NoError(t, err)
defer rdr.Close()
handler := newDefaultHandler(defaultHandlerType)
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)


@@ -9,7 +9,7 @@ import (
"github.com/mholt/archiver/v4"
logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/readers"
"github.com/trufflesecurity/trufflehog/v3/pkg/iobuf"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)
@@ -29,63 +29,58 @@ import (
// promotes a more cohesive and maintainable codebase. It also embeds a BufferedFileReader to provide efficient
// random access to the file content.
type fileReader struct {
format archiver.Format
mimeType mimeType
*readers.BufferedFileReader
format archiver.Format
mimeType mimeType
isGenericArchive bool
*iobuf.BufferedReadSeeker
}
var ErrEmptyReader = errors.New("reader is empty")
func newFileReader(r io.ReadCloser) (fileReader, error) {
defer r.Close()
// newFileReader creates a fileReader from an io.Reader, wrapping it in a BufferedReadSeeker for MIME detection, archive identification, and replay.
func newFileReader(r io.Reader) (fileReader, error) {
var reader fileReader
var (
reader fileReader
rdr *readers.BufferedFileReader
err error
)
rdr, err = readers.NewBufferedFileReader(r)
bufReader := iobuf.NewBufferedReaderSeeker(r)
mime, err := mimetype.DetectReader(bufReader)
if err != nil {
return reader, fmt.Errorf("error creating random access reader: %w", err)
return reader, fmt.Errorf("unable to detect MIME type: %w", err)
}
reader.BufferedFileReader = rdr
reader.mimeType = mimeType(mime.String())
// Ensure the reader is closed if an error occurs after the reader is created.
// During non-error conditions, the caller is responsible for closing the reader.
defer func() {
if err != nil && rdr != nil {
_ = rdr.Close()
}
}()
// Check if the reader is empty.
if rdr.Size() == 0 {
return reader, ErrEmptyReader
// Reset the reader to the beginning because DetectReader consumes the reader.
if _, err := bufReader.Seek(0, io.SeekStart); err != nil {
return reader, fmt.Errorf("error resetting reader after MIME detection: %w", err)
}
format, arReader, err := archiver.Identify("", rdr)
format, _, err := archiver.Identify("", bufReader)
switch {
case err == nil: // Archive detected
case err == nil:
reader.isGenericArchive = true
reader.mimeType = mimeType(format.Name())
reader.format = format
case errors.Is(err, archiver.ErrNoMatch):
// Not an archive handled by archiver, try to detect MIME type.
// This will occur for un-supported archive types and non-archive files. (ex: .deb, .rpm, .txt)
mimeT, err := mimetype.DetectReader(arReader)
if err != nil {
return reader, fmt.Errorf("error detecting MIME type: %w", err)
}
reader.mimeType = mimeType(mimeT.String())
default: // Error identifying archive
// Not an archive handled by archiver.
// Continue with the default reader.
default:
return reader, fmt.Errorf("error identifying archive: %w", err)
}
if _, err = rdr.Seek(0, io.SeekStart); err != nil {
return reader, fmt.Errorf("error seeking to start of file: %w", err)
// Reset the reader to the beginning again to allow the handler to read from the start.
// This is necessary because Identify consumes the reader.
if _, err := bufReader.Seek(0, io.SeekStart); err != nil {
return reader, fmt.Errorf("error resetting reader after archive identification: %w", err)
}
// Disable buffering after initial reads.
// This optimization ensures we don't continue writing to the buffer after the initial reads.
bufReader.DisableBuffering()
reader.BufferedReadSeeker = bufReader
return reader, nil
}
@@ -168,7 +163,7 @@ func selectHandler(file fileReader) FileHandler {
// the function will skip processing the file and return nil.
func HandleFile(
ctx logContext.Context,
reader io.ReadCloser,
reader io.Reader,
chunkSkel *sources.Chunk,
reporter sources.ChunkReporter,
options ...func(*fileHandlingConfig),
@@ -185,7 +180,6 @@ func HandleFile(
}
return fmt.Errorf("error creating custom reader: %w", err)
}
defer rdr.Close()
config := newFileHandlingConfig(options...)
if config.skipArchives && rdr.isGenericArchive {
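HandleFile now takes a plain io.Reader instead of an io.ReadCloser, so the caller owns the reader's lifecycle and the internal defer rdr.Close() is gone. A minimal sketch of the updated call pattern, mirroring the TestHandleTar test added in this PR (the file path and channel size are illustrative):

```go
package main

import (
	"log"
	"os"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

func main() {
	// The caller opens and closes the input; HandleFile no longer closes it.
	file, err := os.Open("testdata/test.tar")
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()

	chunkCh := make(chan *sources.Chunk, 1)
	go func() {
		defer close(chunkCh)
		// Any io.Reader works; buffering and rewinding for type detection
		// happen internally via BufferedReadSeeker.
		if err := handlers.HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh}); err != nil {
			log.Println("handle file:", err)
		}
	}()
	for range chunkCh {
		// Consume chunks as they are produced.
	}
}
```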


@@ -1,6 +1,7 @@
package handlers
import (
"io"
"net/http"
"os"
"strings"
@@ -42,11 +43,13 @@ func TestHandleFile(t *testing.T) {
}
func BenchmarkHandleFile(b *testing.B) {
file, err := os.Open("testdata/test.tgz")
assert.Nil(b, err)
defer file.Close()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sourceChan := make(chan *sources.Chunk, 1)
file, err := os.Open("testdata/test.tgz")
assert.Nil(b, err)
b.StartTimer()
go func() {
defer close(sourceChan)
@@ -57,6 +60,9 @@ func BenchmarkHandleFile(b *testing.B) {
for range sourceChan {
}
b.StopTimer()
_, err = file.Seek(0, io.SeekStart)
assert.NoError(b, err)
}
}
@@ -201,6 +207,31 @@ func TestHandleFileAR(t *testing.T) {
assert.Equal(t, wantChunkCount, len(reporter.Ch))
}
func BenchmarkHandleAR(b *testing.B) {
file, err := os.Open("testdata/test.deb")
assert.Nil(b, err)
defer file.Close()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sourceChan := make(chan *sources.Chunk, 1)
b.StartTimer()
go func() {
defer close(sourceChan)
err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: sourceChan})
assert.NoError(b, err)
}()
for range sourceChan {
}
b.StopTimer()
_, err = file.Seek(0, io.SeekStart)
assert.NoError(b, err)
}
}
func TestHandleFileNonArchive(t *testing.T) {
wantChunkCount := 6
reporter := sources.ChanReporter{Ch: make(chan *sources.Chunk, wantChunkCount)}
@@ -220,7 +251,7 @@ func TestExtractTarContentWithEmptyFile(t *testing.T) {
file, err := os.Open("testdata/testdir.zip")
assert.Nil(t, err)
chunkCh := make(chan *sources.Chunk)
chunkCh := make(chan *sources.Chunk, 1)
go func() {
defer close(chunkCh)
err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
@@ -234,3 +265,48 @@
}
assert.Equal(t, wantCount, count)
}
func TestHandleTar(t *testing.T) {
file, err := os.Open("testdata/test.tar")
assert.Nil(t, err)
defer file.Close()
chunkCh := make(chan *sources.Chunk, 1)
go func() {
defer close(chunkCh)
err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
assert.NoError(t, err)
}()
wantCount := 1
count := 0
for range chunkCh {
count++
}
assert.Equal(t, wantCount, count)
}
func BenchmarkHandleTar(b *testing.B) {
file, err := os.Open("testdata/test.tar")
assert.Nil(b, err)
defer file.Close()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sourceChan := make(chan *sources.Chunk, 1)
b.StartTimer()
go func() {
defer close(sourceChan)
err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: sourceChan})
assert.NoError(b, err)
}()
for range sourceChan {
}
b.StopTimer()
_, err = file.Seek(0, io.SeekStart)
assert.NoError(b, err)
}
}


@@ -20,7 +20,6 @@ func TestHandleRPMFile(t *testing.T) {
rdr, err := newFileReader(file)
assert.NoError(t, err)
defer rdr.Close()
handler := newRPMHandler()
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)

pkg/handlers/testdata/test.tar (new vendored binary test file; contents not shown)


@@ -0,0 +1,198 @@
package iobuf
import (
"bytes"
"errors"
"fmt"
"io"
)
// BufferedReadSeeker provides a buffered reading interface with seeking capabilities.
// It wraps an io.Reader and optionally an io.Seeker, allowing for efficient
// reading and seeking operations, even on non-seekable underlying readers.
type BufferedReadSeeker struct {
reader io.Reader
seeker io.Seeker // If the reader supports seeking, it's stored here for direct access
buffer *bytes.Buffer // Internal buffer to store read bytes for non-seekable readers
bytesRead int64 // Total number of bytes read from the underlying reader
index int64 // Current position in the virtual stream
// Flag to control buffering. This flag is used to indicate whether buffering is active.
// Buffering is enabled during initial reads (e.g., for MIME type detection and format identification).
// Once these operations are done, buffering should be disabled to prevent further writes to the buffer
// and to optimize subsequent reads directly from the underlying reader. This helps avoid excessive
// memory usage while still providing the necessary functionality for initial detection operations.
activeBuffering bool
}
// NewBufferedReaderSeeker creates and initializes a BufferedReadSeeker.
// It takes an io.Reader and checks if it supports seeking.
// If the reader supports seeking, it is stored in the seeker field.
func NewBufferedReaderSeeker(r io.Reader) *BufferedReadSeeker {
var (
seeker io.Seeker
buffer *bytes.Buffer
activeBuffering = true
)
if s, ok := r.(io.Seeker); ok {
seeker = s
activeBuffering = false
}
const mimeTypeBufferSize = 3072 // Approx buffer size for MIME type detection
if seeker == nil {
buffer = bytes.NewBuffer(make([]byte, 0, mimeTypeBufferSize))
}
return &BufferedReadSeeker{
reader: r,
seeker: seeker,
buffer: buffer,
bytesRead: 0,
index: 0,
activeBuffering: activeBuffering,
}
}
// Read reads up to len(out) bytes from the reader starting at the current index.
// It handles both seekable and non-seekable underlying readers efficiently.
func (br *BufferedReadSeeker) Read(out []byte) (int, error) {
// For seekable readers, read directly from the underlying reader.
if br.seeker != nil {
n, err := br.reader.Read(out)
br.index += int64(n)
br.bytesRead = max(br.bytesRead, br.index)
return n, err
}
// For non-seekable readers, use buffered reading.
outLen := int64(len(out))
if outLen == 0 {
return 0, nil
}
// If the current read position (br.index) is within the buffer's valid data range,
// read from the buffer. This ensures previously read data (e.g., for mime type detection)
// is included in subsequent reads, providing a consistent view of the reader's content.
if br.index < int64(br.buffer.Len()) {
n := copy(out, br.buffer.Bytes()[br.index:])
br.index += int64(n)
return n, nil
}
if !br.activeBuffering {
// If buffering is not active, read directly from the underlying reader.
n, err := br.reader.Read(out)
br.index += int64(n)
br.bytesRead = max(br.bytesRead, br.index)
return n, err
}
// Ensure there are enough bytes in the buffer to read from.
if outLen+br.index > int64(br.buffer.Len()) {
bytesToRead := int(outLen + br.index - int64(br.buffer.Len()))
readerBytes := make([]byte, bytesToRead)
n, err := br.reader.Read(readerBytes)
br.buffer.Write(readerBytes[:n])
br.bytesRead += int64(n)
if err != nil {
return n, err
}
}
// Ensure the read does not exceed the buffer length.
endIndex := br.index + outLen
bufLen := int64(br.buffer.Len())
if endIndex > bufLen {
endIndex = bufLen
}
if br.index >= bufLen {
return 0, io.EOF
}
n := copy(out, br.buffer.Bytes()[br.index:endIndex])
br.index += int64(n)
return n, nil
}
// Seek sets the offset for the next Read to offset.
// It supports both seekable and non-seekable underlying readers.
func (br *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
if br.seeker != nil {
// Use the underlying Seeker if available.
newIndex, err := br.seeker.Seek(offset, whence)
if err != nil {
return 0, fmt.Errorf("error seeking in reader: %w", err)
}
if newIndex > br.bytesRead {
br.bytesRead = newIndex
}
return newIndex, nil
}
// Manual seeking for non-seekable readers.
newIndex := br.index
const bufferSize = 64 * 1024 // 64KB chunk size for reading
switch whence {
case io.SeekStart:
newIndex = offset
case io.SeekCurrent:
newIndex += offset
case io.SeekEnd:
// Read the entire reader to determine its length
buffer := make([]byte, bufferSize)
for {
n, err := br.reader.Read(buffer)
if n > 0 {
if br.activeBuffering {
br.buffer.Write(buffer[:n])
}
br.bytesRead += int64(n)
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return 0, err
}
}
newIndex = min(br.bytesRead+offset, br.bytesRead)
default:
return 0, errors.New("invalid whence value")
}
if newIndex < 0 {
return 0, errors.New("can not seek to before start of reader")
}
br.index = newIndex
return newIndex, nil
}
// ReadAt reads len(out) bytes into out starting at the given offset in the underlying input source.
// It uses Seek and Read to implement random access reading.
func (br *BufferedReadSeeker) ReadAt(out []byte, offset int64) (int, error) {
startIndex, err := br.Seek(offset, io.SeekStart)
if err != nil {
return 0, err
}
if startIndex != offset {
return 0, io.EOF
}
return br.Read(out)
}
// DisableBuffering stops the buffering process.
// This is useful after initial reads (e.g., for MIME type detection and format identification)
// to prevent further writes to the buffer, optimizing subsequent reads.
func (br *BufferedReadSeeker) DisableBuffering() { br.activeBuffering = false }
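Taken together, Read, Seek, and DisableBuffering support the detect-rewind-stream sequence that newFileReader relies on. A small usage sketch (the input string is illustrative; the io.NopCloser wrapper hides strings.Reader's Seek method so the buffered, non-seekable path is exercised):

```go
package main

import (
	"fmt"
	"io"
	"strings"

	"github.com/trufflesecurity/trufflehog/v3/pkg/iobuf"
)

func main() {
	// io.NopCloser hides the underlying Seek method, so the
	// BufferedReadSeeker must buffer reads to make them replayable.
	r := io.NopCloser(strings.NewReader("magic-bytes then the rest of the stream"))
	brs := iobuf.NewBufferedReaderSeeker(r)

	// 1. Detection pass: the first bytes are read and buffered.
	header := make([]byte, 11)
	n, _ := brs.Read(header)
	fmt.Printf("sniffed: %q\n", header[:n])

	// 2. Rewind: for non-seekable inputs this is satisfied from the buffer.
	if _, err := brs.Seek(0, io.SeekStart); err != nil {
		panic(err)
	}

	// 3. Detection done: stop growing the buffer. The buffered prefix is
	// replayed, then reads continue directly from the underlying reader.
	brs.DisableBuffering()

	rest, _ := io.ReadAll(brs)
	fmt.Printf("full stream: %q\n", rest)
}
```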


@@ -0,0 +1,362 @@
package iobuf
import (
"bytes"
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBufferedReaderSeekerRead(t *testing.T) {
tests := []struct {
name string
reader io.Reader
activeBuffering bool
reads []int
expectedReads []int
expectedBytes [][]byte
expectedBytesRead int64
expectedIndex int64
expectedBuffer []byte
expectedError error
}{
{
name: "read from seekable reader",
reader: strings.NewReader("test data"),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
expectedBytesRead: 4,
expectedIndex: 4,
},
{
name: "read from non-seekable reader with buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
expectedBytesRead: 4,
expectedIndex: 4,
expectedBuffer: []byte("test"),
},
{
name: "read from non-seekable reader without buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: false,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
expectedBytesRead: 4,
expectedIndex: 4,
},
{
name: "read beyond buffer",
reader: strings.NewReader("test data"),
activeBuffering: true,
reads: []int{10},
expectedReads: []int{9},
expectedBytes: [][]byte{[]byte("test data")},
expectedBytesRead: 9,
expectedIndex: 9,
},
{
name: "read with empty reader",
reader: strings.NewReader(""),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{0},
expectedBytes: [][]byte{[]byte("")},
expectedBytesRead: 0,
expectedIndex: 0,
expectedError: io.EOF,
},
{
name: "read exact buffer size",
reader: strings.NewReader("test"),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
expectedBytesRead: 4,
expectedIndex: 4,
},
{
name: "read less than buffer size",
reader: strings.NewReader("te"),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{2},
expectedBytes: [][]byte{[]byte("te")},
expectedBytesRead: 2,
expectedIndex: 2,
},
{
name: "read more than buffer size without buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: false,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
expectedBytesRead: 4,
expectedIndex: 4,
},
{
name: "multiple reads with buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: true,
reads: []int{4, 5},
expectedReads: []int{4, 5},
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
expectedBytesRead: 9,
expectedIndex: 9,
expectedBuffer: []byte("test data"),
},
{
name: "multiple reads without buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: false,
reads: []int{4, 5},
expectedReads: []int{4, 5},
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
expectedBytesRead: 9,
expectedIndex: 9,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
brs := NewBufferedReaderSeeker(tt.reader)
brs.activeBuffering = tt.activeBuffering
for i, readSize := range tt.reads {
buf := make([]byte, readSize)
n, err := brs.Read(buf)
assert.Equal(t, tt.expectedReads[i], n, "read %d: unexpected number of bytes read", i+1)
assert.Equal(t, tt.expectedBytes[i], buf[:n], "read %d: unexpected bytes", i+1)
if i == len(tt.reads)-1 {
if tt.expectedError != nil {
assert.ErrorIs(t, err, tt.expectedError)
} else {
assert.NoError(t, err)
}
}
}
assert.Equal(t, tt.expectedBytesRead, brs.bytesRead)
assert.Equal(t, tt.expectedIndex, brs.index)
if brs.buffer != nil && len(tt.expectedBuffer) > 0 {
assert.Equal(t, tt.expectedBuffer, brs.buffer.Bytes())
} else {
assert.Nil(t, tt.expectedBuffer)
}
})
}
}
func TestBufferedReaderSeekerSeek(t *testing.T) {
tests := []struct {
name string
reader io.Reader
offset int64
whence int
expectedPos int64
expectedErr bool
expectedRead []byte
}{
{
name: "seek on seekable reader with SeekStart",
reader: strings.NewReader("test data"),
offset: 4,
whence: io.SeekStart,
expectedPos: 4,
expectedErr: false,
expectedRead: []byte(" dat"),
},
{
name: "seek on seekable reader with SeekCurrent",
reader: strings.NewReader("test data"),
offset: 4,
whence: io.SeekCurrent,
expectedPos: 4,
expectedErr: false,
expectedRead: []byte(" dat"),
},
{
name: "seek on seekable reader with SeekEnd",
reader: strings.NewReader("test data"),
offset: -4,
whence: io.SeekEnd,
expectedPos: 5,
expectedErr: false,
expectedRead: []byte("data"),
},
{
name: "seek on non-seekable reader with SeekStart",
reader: bytes.NewBufferString("test data"),
offset: 4,
whence: io.SeekStart,
expectedPos: 4,
expectedErr: false,
expectedRead: []byte{},
},
{
name: "seek on non-seekable reader with SeekCurrent",
reader: bytes.NewBufferString("test data"),
offset: 4,
whence: io.SeekCurrent,
expectedPos: 4,
expectedErr: false,
expectedRead: []byte{},
},
{
name: "seek on non-seekable reader with SeekEnd",
reader: bytes.NewBufferString("test data"),
offset: -4,
whence: io.SeekEnd,
expectedPos: 5,
expectedErr: false,
expectedRead: []byte{},
},
{
name: "seek to negative position",
reader: strings.NewReader("test data"),
offset: -1,
whence: io.SeekStart,
expectedPos: 0,
expectedErr: true,
expectedRead: nil,
},
{
name: "seek beyond EOF on non-seekable reader",
reader: bytes.NewBufferString("test data"),
offset: 20,
whence: io.SeekEnd,
expectedPos: 9,
expectedErr: false,
expectedRead: []byte{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
brs := NewBufferedReaderSeeker(tt.reader)
pos, err := brs.Seek(tt.offset, tt.whence)
if tt.expectedErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.expectedPos, pos)
if len(tt.expectedRead) > 0 {
buf := make([]byte, len(tt.expectedRead))
nn, err := brs.Read(buf)
assert.NoError(t, err)
assert.Equal(t, len(tt.expectedRead), nn)
assert.Equal(t, tt.expectedRead, buf[:nn])
}
})
}
}
func TestBufferedReaderSeekerReadAt(t *testing.T) {
tests := []struct {
name string
reader io.Reader
offset int64
length int
expectedN int
expectErr bool
expectedOut []byte
}{
{
name: "read within buffer on seekable reader",
reader: strings.NewReader("test data"),
offset: 5,
length: 4,
expectedN: 4,
expectedOut: []byte("data"),
},
{
name: "read within buffer on non-seekable reader",
reader: bytes.NewBufferString("test data"),
offset: 5,
length: 4,
expectedN: 4,
expectedOut: []byte("data"),
},
{
name: "read beyond buffer",
reader: strings.NewReader("test data"),
offset: 9,
length: 1,
expectedN: 0,
expectErr: true,
expectedOut: []byte{},
},
{
name: "read at start",
reader: strings.NewReader("test data"),
offset: 0,
length: 4,
expectedN: 4,
expectedOut: []byte("test"),
},
{
name: "read with zero length",
reader: strings.NewReader("test data"),
offset: 0,
length: 0,
expectedN: 0,
expectedOut: []byte{},
},
{
name: "read negative offset",
reader: strings.NewReader("test data"),
offset: -1,
length: 4,
expectedN: 0,
expectErr: true,
expectedOut: []byte{},
},
{
name: "read beyond end on non-seekable reader",
reader: bytes.NewBufferString("test data"),
offset: 20,
length: 4,
expectedN: 0,
expectErr: true,
expectedOut: []byte{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
brs := NewBufferedReaderSeeker(tt.reader)
out := make([]byte, tt.length)
n, err := brs.ReadAt(out, tt.offset)
if tt.expectErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.expectedN, n)
assert.Equal(t, tt.expectedOut, out[:n])
})
}
}


@@ -1,84 +0,0 @@
package readers
import (
"fmt"
"io"
bufferedfilewriter "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffered_file_writer"
)
// Compile time check to ensure that BufferedFileReader implements io.ReaderAt.
var _ io.ReaderAt = (*BufferedFileReader)(nil)
// BufferedFileReader provides random access read, seek, and close capabilities on top of the BufferedFileWriter.
// It combines the functionality of BufferedFileWriter for buffered writing, io.ReadSeekCloser for
// random access reading and seeking.
type BufferedFileReader struct {
bufWriter *bufferedfilewriter.BufferedFileWriter
reader io.ReadSeekCloser
}
// NewBufferedFileReader initializes a BufferedFileReader from an io.Reader by using
// the BufferedFileWriter's functionality to read and store data, then setting up a io.ReadSeekCloser
// for random access reading and seeking.
// Close should be called when the BufferedFileReader is no longer needed.
// It returns the initialized BufferedFileReader and any error encountered during the process.
func NewBufferedFileReader(r io.Reader) (*BufferedFileReader, error) {
writer, err := bufferedfilewriter.NewFromReader(r)
if err != nil {
return nil, fmt.Errorf("error creating BufferedFileReader: %w", err)
}
// Ensure that the BufferedFileWriter is in read-only mode.
if err := writer.CloseForWriting(); err != nil {
return nil, err
}
rdr, err := writer.ReadSeekCloser()
if err != nil {
return nil, err
}
return &BufferedFileReader{writer, rdr}, nil
}
// Close the BufferedFileReader.
// It should be called when the BufferedFileReader is no longer needed.
// Note that closing the BufferedFileReader does not affect the underlying bytes.Reader,
// which can still be used for reading, seeking, and reading at specific positions.
// Close is a no-op for the bytes.Reader.
func (b *BufferedFileReader) Close() error { return b.reader.Close() }
// Read reads up to len(p) bytes into p from the underlying reader.
// It returns the number of bytes read and any error encountered.
// If the reader reaches the end of the available data, Read returns 0, io.EOF.
// It implements the io.Reader interface.
func (b *BufferedFileReader) Read(p []byte) (int, error) { return b.reader.Read(p) }
// Seek sets the offset for the next Read operation on the underlying reader.
// The offset is interpreted according to the whence parameter:
// - io.SeekStart means relative to the start of the file
// - io.SeekCurrent means relative to the current offset
// - io.SeekEnd means relative to the end of the file
//
// Seek returns the new offset and any error encountered.
// It implements the io.Seeker interface.
func (b *BufferedFileReader) Seek(offset int64, whence int) (int64, error) {
return b.reader.Seek(offset, whence)
}
// ReadAt reads len(p) bytes from the underlying io.ReadSeekCloser starting at byte offset off.
// It returns the number of bytes read and any error encountered.
// If the io.ReadSeekCloser reaches the end of the available data before len(p) bytes are read,
// ReadAt returns the number of bytes read and io.EOF.
// It implements the io.ReaderAt interface.
func (b *BufferedFileReader) ReadAt(p []byte, off int64) (n int, err error) {
_, err = b.reader.Seek(off, io.SeekStart)
if err != nil {
return 0, err
}
return b.reader.Read(p)
}
// Size returns the total size of the data stored in the BufferedFileReader.
func (b *BufferedFileReader) Size() int { return b.bufWriter.Len() }


@@ -1,98 +0,0 @@
package readers
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBufferedFileReader(t *testing.T) {
t.Parallel()
data := []byte("Hello, World!")
bufferReadSeekCloser, err := NewBufferedFileReader(bytes.NewReader(data))
assert.NoError(t, err)
// Test Read.
buffer := make([]byte, len(data))
n, err := bufferReadSeekCloser.Read(buffer)
assert.NoError(t, err)
assert.Equal(t, len(data), n)
assert.Equal(t, data, buffer)
// Test Seek.
offset := 7
seekPos, err := bufferReadSeekCloser.Seek(int64(offset), io.SeekStart)
assert.NoError(t, err)
assert.Equal(t, int64(offset), seekPos)
// Test ReadAt.
buffer = make([]byte, len(data)-offset)
n, err = bufferReadSeekCloser.ReadAt(buffer, int64(offset))
assert.NoError(t, err)
assert.Equal(t, len(data)-offset, n)
assert.Equal(t, data[offset:], buffer)
// Test Close.
err = bufferReadSeekCloser.Close()
assert.NoError(t, err)
}
func TestBufferedFileReaderClose(t *testing.T) {
t.Parallel()
data := []byte("Hello, World!")
bufferReadSeekCloser, err := NewBufferedFileReader(bytes.NewReader(data))
assert.NoError(t, err)
err = bufferReadSeekCloser.Close()
assert.NoError(t, err)
// Read should NOT return any data after closing the reader.
buffer := make([]byte, len(data))
n, err := bufferReadSeekCloser.Read(buffer)
assert.ErrorIs(t, err, io.EOF)
assert.Equal(t, 0, n)
}
func TestBufferedFileReaderReadFromFile(t *testing.T) {
t.Parallel()
// Create a large byte slice to simulate data exceeding the threshold.
largeData := make([]byte, 1024*1024) // 1 MB
for i := range largeData {
largeData[i] = byte(i % 256)
}
bufferReadSeekCloser, err := NewBufferedFileReader(bytes.NewReader(largeData))
assert.NoError(t, err)
defer bufferReadSeekCloser.Close()
// Test Read.
buffer := make([]byte, len(largeData))
n, err := bufferReadSeekCloser.Read(buffer)
assert.NoError(t, err)
assert.Equal(t, len(largeData), n)
assert.Equal(t, largeData, buffer)
// Test Seek.
offset := 512 * 1024 // 512 KB
seekPos, err := bufferReadSeekCloser.Seek(int64(offset), io.SeekStart)
assert.NoError(t, err)
assert.Equal(t, int64(offset), seekPos)
// Test ReadAt.
buffer = make([]byte, len(largeData)-offset)
n, err = bufferReadSeekCloser.ReadAt(buffer, int64(offset))
assert.NoError(t, err)
assert.Equal(t, len(largeData)-offset, n)
assert.Equal(t, largeData[offset:], buffer)
// Test Close.
err = bufferReadSeekCloser.Close()
assert.NoError(t, err)
}


@@ -92,7 +92,7 @@ func createReaderFn(config *chunkReaderConfig) ChunkReader {
}
func readInChunks(ctx context.Context, reader io.Reader, config *chunkReaderConfig) <-chan ChunkResult {
const channelSize = 1
const channelSize = 64
chunkReader := bufio.NewReaderSize(reader, config.chunkSize)
chunkResultChan := make(chan ChunkResult, channelSize)


@@ -166,6 +166,7 @@ func (s *Source) scanFile(ctx context.Context, path string, chunksChan chan *sou
if err != nil {
return fmt.Errorf("unable to open file: %w", err)
}
defer inputFile.Close()
logger.V(3).Info("scanning file")


@@ -1238,29 +1238,47 @@ func (s *Git) handleBinary(ctx context.Context, gitDir string, reporter sources.
}
cmd := exec.Command("git", "-C", gitDir, "cat-file", "blob", commitHash.String()+":"+path)
stdout, err := s.executeCatFileCmd(cmd)
if err != nil {
return err
}
done := make(chan error, 1)
// Read from stdout to prevent the pipe buffer from filling up and causing the command to hang.
// This allows us to stream the file contents to the handler.
go func() {
defer close(done)
done <- handlers.HandleFile(ctx, stdout, chunkSkel, reporter, handlers.WithSkipArchives(s.skipArchives))
}()
// Close to signal that we are done writing to the pipe, which allows the reading goroutine to finish.
if closeErr := stdout.Close(); closeErr != nil && !errors.Is(closeErr, os.ErrClosed) {
ctx.Logger().Error(fmt.Errorf("error closing stdout: %w", closeErr), "closing stdout failed")
}
if waitErr := cmd.Wait(); waitErr != nil {
return fmt.Errorf("error waiting for git cat-file: %w", waitErr)
}
// Wait for the command to finish and the handler to complete.
// Capture any error from the file handling process.
return <-done
}
func (s *Git) executeCatFileCmd(cmd *exec.Cmd) (io.ReadCloser, error) {
var stderr bytes.Buffer
cmd.Stderr = &stderr
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("error running git cat-file: %w\n%s", err, stderr.Bytes())
return nil, fmt.Errorf("error running git cat-file: %w\n%s", err, stderr.Bytes())
}
defer func() {
if err = cmd.Wait(); err != nil {
ctx.Logger().Error(fmt.Errorf(
"error waiting for command: command=%s, stderr=%s, commit=%s: %w",
cmd.String(), stderr.String(), commitHash.String(), err,
), "waiting for command failed")
}
}()
if err := cmd.Start(); err != nil {
return fmt.Errorf("error starting git cat-file: %w\n%s", err, stderr.Bytes())
return nil, fmt.Errorf("error starting git cat-file: %w\n%s", err, stderr.Bytes())
}
return handlers.HandleFile(fileCtx, stdout, chunkSkel, reporter, handlers.WithSkipArchives(s.skipArchives))
return stdout, nil
}
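The reworked handleBinary splits command setup (executeCatFileCmd) from consumption: a goroutine drains stdout through HandleFile while the parent closes the pipe and waits, so a full pipe buffer can never hang git cat-file. The general shape of that pattern, reduced to a sketch (runAndStream is an illustrative name, not the PR's API):

```go
package main

import (
	"fmt"
	"io"
	"os/exec"
)

// runAndStream starts cmd and hands its stdout to consume while the process
// is still running, then waits for both sides to finish.
func runAndStream(cmd *exec.Cmd, consume func(io.Reader) error) error {
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return fmt.Errorf("stdout pipe: %w", err)
	}
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("start: %w", err)
	}

	done := make(chan error, 1)
	go func() {
		// Draining stdout concurrently keeps the pipe buffer from filling
		// up and blocking the child process.
		done <- consume(stdout)
	}()

	consumeErr := <-done
	// Wait only after all reads have finished, per the os/exec contract.
	if err := cmd.Wait(); err != nil {
		return fmt.Errorf("wait: %w", err)
	}
	return consumeErr
}
```

The PR's version additionally closes stdout explicitly before waiting and logs (rather than returns) any close error; the sketch instead relies on the consumer reading to EOF.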
func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error {


@@ -355,6 +355,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
}
return nil
}
defer res.Body.Close()
email := "Unknown"
if obj.Owner != nil {