[bug]- Invalid Seek for Non-Seekable Readers (#3095)

* initial work

* fix and add tests

* uncomment

* fix seek end

* use buffer pool

* revert timeout

* make linter happy

* More linting :()
This commit is contained in:
ahrav 2024-07-24 19:08:56 -07:00 committed by GitHub
parent 4a8b213651
commit ebfbd21707
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 344 additions and 134 deletions

View file

@ -20,6 +20,7 @@ func TestHandleARFile(t *testing.T) {
rdr, err := newFileReader(file)
assert.NoError(t, err)
defer rdr.Close()
handler := newARHandler()
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)

View file

@ -111,6 +111,7 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
}
return fmt.Errorf("error creating custom reader: %w", err)
}
defer rdr.Close()
return h.openArchive(ctx, depth+1, rdr, archiveChan)
case archiver.Extractor:
@ -194,6 +195,7 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context.
}
return fmt.Errorf("error creating custom reader: %w", err)
}
defer rdr.Close()
h.metrics.incFilesProcessed()
h.metrics.observeFileSize(fileSize)

View file

@ -89,6 +89,8 @@ func TestArchiveHandler(t *testing.T) {
if err != nil {
t.Errorf("error creating reusable reader: %s", err)
}
defer newReader.Close()
archiveChan, err := handler.HandleFile(logContext.Background(), newReader)
if testCase.expectErr {
assert.NoError(t, err)
@ -119,6 +121,7 @@ func TestOpenInvalidArchive(t *testing.T) {
rdr, err := newFileReader(io.NopCloser(reader))
assert.NoError(t, err)
defer rdr.Close()
archiveChan := make(chan []byte)

View file

@ -20,6 +20,7 @@ func TestHandleNonArchiveFile(t *testing.T) {
rdr, err := newFileReader(file)
assert.NoError(t, err)
defer rdr.Close()
handler := newDefaultHandler(defaultHandlerType)
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)

View file

@ -59,7 +59,7 @@ func newMimeTypeReaderFromFileReader(r fileReader) mimeTypeReader {
// newMimeTypeReader creates a new mimeTypeReader from an io.Reader.
// It uses a bufio.Reader to perform MIME type detection on the input reader
// without consuming it, by peeking into the first 512 bytes of the input.
// without consuming it, by peeking into the first 3072 bytes of the input.
// This encapsulates both the original reader and the detected MIME type information.
// This function is particularly useful for specialized archive handlers
// that need to pass extracted content to the default handler without modifying the original reader.
@ -84,10 +84,6 @@ func newFileReader(r io.Reader) (fileReader, error) {
fReader.BufferedReadSeeker = iobuf.NewBufferedReaderSeeker(r)
// Disable buffering after initial reads.
// This optimization ensures we don't continue writing to the buffer after the initial reads.
defer fReader.DisableBuffering()
mime, err := mimetype.DetectReader(fReader)
if err != nil {
return fReader, fmt.Errorf("unable to detect MIME type: %w", err)
@ -281,6 +277,7 @@ func HandleFile(
}
return fmt.Errorf("error creating custom reader: %w", err)
}
defer rdr.Close()
mimeT := mimeType(rdr.mime.String())
config := newFileHandlingConfig(options...)

View file

@ -27,21 +27,101 @@ func TestHandleFileCancelledContext(t *testing.T) {
}
func TestHandleFile(t *testing.T) {
reporter := sources.ChanReporter{Ch: make(chan *sources.Chunk, 2)}
reporter := sources.ChanReporter{Ch: make(chan *sources.Chunk, 513)}
// Only one chunk is sent on the channel.
// TODO: Embed a zip without making an HTTP request.
resp, err := http.Get("https://raw.githubusercontent.com/bill-rich/bad-secrets/master/aws-canary-creds.zip")
assert.NoError(t, err)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
defer func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
}()
assert.Equal(t, 0, len(reporter.Ch))
assert.NoError(t, HandleFile(context.Background(), resp.Body, &sources.Chunk{}, reporter))
assert.Equal(t, 1, len(reporter.Ch))
}
func TestHandleHTTPJson(t *testing.T) {
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm_random_data.json")
assert.NoError(t, err)
defer func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
}()
chunkCh := make(chan *sources.Chunk, 1)
go func() {
defer close(chunkCh)
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
assert.NoError(t, err)
}()
wantCount := 513
count := 0
for range chunkCh {
count++
}
assert.Equal(t, wantCount, count)
}
func TestHandleHTTPJsonZip(t *testing.T) {
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm.zip")
assert.NoError(t, err)
defer func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
}()
chunkCh := make(chan *sources.Chunk, 1)
go func() {
defer close(chunkCh)
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
assert.NoError(t, err)
}()
wantCount := 513
count := 0
for range chunkCh {
count++
}
assert.Equal(t, wantCount, count)
}
func BenchmarkHandleHTTPJsonZip(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
func() {
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm.zip")
assert.NoError(b, err)
defer func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
}()
chunkCh := make(chan *sources.Chunk, 1)
b.StartTimer()
go func() {
defer close(chunkCh)
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
assert.NoError(b, err)
}()
for range chunkCh {
}
b.StopTimer()
}()
}
}
func BenchmarkHandleFile(b *testing.B) {
file, err := os.Open("testdata/test.tgz")
assert.Nil(b, err)

View file

@ -20,6 +20,7 @@ func TestHandleRPMFile(t *testing.T) {
rdr, err := newFileReader(file)
assert.NoError(t, err)
defer rdr.Close()
handler := newRPMHandler()
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)

View file

@ -1,123 +1,144 @@
package iobuf
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"github.com/trufflesecurity/trufflehog/v3/pkg/buffers/buffer"
"github.com/trufflesecurity/trufflehog/v3/pkg/buffers/pool"
"github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp"
)
const defaultBufferSize = 1 << 16 // 64KB
// defaultBufferPool is the package-wide buffer pool, constructed with
// defaultBufferSize. It is initialized in the declaration rather than in an
// init function so that it is ready before any other package-level
// initializer could refer to it.
var defaultBufferPool = pool.NewBufferPool(defaultBufferSize)
// BufferedReadSeeker provides a buffered reading interface with seeking capabilities.
// It wraps an io.Reader and optionally an io.Seeker, allowing for efficient
// reading and seeking operations, even on non-seekable underlying readers.
//
// For small amounts of data, it uses an in-memory buffer (bytes.Buffer) to store
// read bytes. When the amount of data exceeds a specified threshold, it switches
// to disk-based buffering using a temporary file. This approach balances memory
// usage and performance, allowing efficient handling of both small and large data streams.
//
// The struct manages the transition between in-memory and disk-based buffering
// transparently, providing a seamless reading and seeking experience regardless
// of the underlying data size or the seekability of the original reader.
//
// If the underlying reader is seekable, direct seeking operations are performed
// on it. For non-seekable readers, seeking is emulated using the buffer or
// temporary file.
type BufferedReadSeeker struct {
reader io.Reader
seeker io.Seeker // If the reader supports seeking, it's stored here for direct access
buffer *bytes.Buffer // Internal buffer to store read bytes for non-seekable readers
bufPool *pool.Pool // Pool for storing buffers for reuse.
buf *buffer.Buffer // Buffer for storing data under the threshold in memory.
bytesRead int64 // Total number of bytes read from the underlying reader
index int64 // Current position in the virtual stream
// Flag to control buffering. This flag is used to indicate whether buffering is active.
// Buffering is enabled during initial reads (e.g., for MIME type detection and format identification).
// Once these operations are done, buffering should be disabled to prevent further writes to the buffer
// and to optimize subsequent reads directly from the underlying reader. This helps avoid excessive
// memory usage while still providing the necessary functionality for initial detection operations.
activeBuffering bool
threshold int64 // Threshold for switching to file buffering
tempFile *os.File // Temporary file for disk-based buffering
tempFileName string // Name of the temporary file
diskBufferSize int64 // Size of data written to disk
// Fields to provide a quick way to determine the total size of the reader
// without having to seek.
totalSize int64 // Total size of the reader
sizeKnown bool // Whether the total size of the reader is known
}
// NewBufferedReaderSeeker creates and initializes a BufferedReadSeeker.
// It takes an io.Reader and checks if it supports seeking.
// If the reader supports seeking, it is stored in the seeker field.
func NewBufferedReaderSeeker(r io.Reader) *BufferedReadSeeker {
var (
seeker io.Seeker
buffer *bytes.Buffer
activeBuffering = true
)
if s, ok := r.(io.Seeker); ok {
seeker = s
activeBuffering = false
}
const defaultThreshold = 1 << 24 // 16MB threshold for switching to file buffering
const mimeTypeBufferSize = 3072 // Approx buffer size for MIME type detection
seeker, _ := r.(io.Seeker)
var buf *buffer.Buffer
if seeker == nil {
buffer = bytes.NewBuffer(make([]byte, 0, mimeTypeBufferSize))
buf = defaultBufferPool.Get()
}
return &BufferedReadSeeker{
reader: r,
seeker: seeker,
buffer: buffer,
bytesRead: 0,
index: 0,
activeBuffering: activeBuffering,
reader: r,
seeker: seeker,
bufPool: defaultBufferPool,
buf: buf,
threshold: defaultThreshold,
}
}
// Read reads len(out) bytes from the reader starting at the current index.
// It handles both seekable and non-seekable underlying readers efficiently.
func (br *BufferedReadSeeker) Read(out []byte) (int, error) {
// For seekable readers, read directly from the underlying reader.
if br.seeker != nil {
// For seekable readers, read directly from the underlying reader.
n, err := br.reader.Read(out)
br.index += int64(n)
br.bytesRead = max(br.bytesRead, br.index)
return n, err
}
// For non-seekable readers, use buffered reading.
outLen := int64(len(out))
if outLen == 0 {
return 0, nil
}
// If the current read position (br.index) is within the buffer's valid data range,
// read from the buffer. This ensures previously read data (e.g., for mime type detection)
// is included in subsequent reads, providing a consistent view of the reader's content.
if br.index < int64(br.buffer.Len()) {
n := copy(out, br.buffer.Bytes()[br.index:])
br.index += int64(n)
return n, nil
}
if !br.activeBuffering {
// If buffering is not active, read directly from the underlying reader.
n, err := br.reader.Read(out)
br.index += int64(n)
br.bytesRead = max(br.bytesRead, br.index)
return n, err
}
// Ensure there are enough bytes in the buffer to read from.
if outLen+br.index > int64(br.buffer.Len()) {
bytesToRead := int(outLen + br.index - int64(br.buffer.Len()))
readerBytes := make([]byte, bytesToRead)
n, err := br.reader.Read(readerBytes)
br.buffer.Write(readerBytes[:n])
br.bytesRead += int64(n)
if err != nil {
return n, err
if n > 0 {
br.bytesRead += int64(n)
}
return n, err
}
// Ensure the read does not exceed the buffer length.
endIndex := br.index + outLen
bufLen := int64(br.buffer.Len())
if endIndex > bufLen {
endIndex = bufLen
var (
totalBytesRead int
err error
)
// If the current read position is within the in-memory buffer.
if br.index < int64(br.buf.Len()) {
totalBytesRead = copy(out, br.buf.Bytes()[br.index:])
br.index += int64(totalBytesRead)
if totalBytesRead == len(out) {
return totalBytesRead, nil
}
out = out[totalBytesRead:]
}
if br.index >= bufLen {
return 0, io.EOF
// If we've exceeded the in-memory threshold and have a temp file.
if br.tempFile != nil && br.index < br.diskBufferSize {
if _, err := br.tempFile.Seek(br.index-int64(br.buf.Len()), io.SeekStart); err != nil {
return totalBytesRead, err
}
m, err := br.tempFile.Read(out)
totalBytesRead += m
br.index += int64(m)
if err != nil && !errors.Is(err, io.EOF) {
return totalBytesRead, err
}
if totalBytesRead == len(out) {
return totalBytesRead, nil
}
out = out[totalBytesRead:]
}
n := copy(out, br.buffer.Bytes()[br.index:endIndex])
br.index += int64(n)
return n, nil
if len(out) == 0 {
return totalBytesRead, nil
}
// If we still need to read more data.
var raderBytes int
raderBytes, err = br.reader.Read(out)
totalBytesRead += raderBytes
br.index += int64(raderBytes)
if writeErr := br.writeData(out[:raderBytes]); writeErr != nil {
return totalBytesRead, writeErr
}
if errors.Is(err, io.EOF) {
br.totalSize = br.bytesRead
br.sizeKnown = true
}
return totalBytesRead, err
}
// Seek sets the offset for the next Read or Write to offset.
@ -125,45 +146,25 @@ func (br *BufferedReadSeeker) Read(out []byte) (int, error) {
func (br *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
if br.seeker != nil {
// Use the underlying Seeker if available.
newIndex, err := br.seeker.Seek(offset, whence)
if err != nil {
return 0, fmt.Errorf("error seeking in reader: %w", err)
}
if newIndex > br.bytesRead {
br.bytesRead = newIndex
}
return newIndex, nil
return br.seeker.Seek(offset, whence)
}
// Manual seeking for non-seekable readers.
newIndex := br.index
const bufferSize = 64 * 1024 // 64KB chunk size for reading
switch whence {
case io.SeekStart:
newIndex = offset
case io.SeekCurrent:
newIndex += offset
case io.SeekEnd:
// Read the entire reader to determine its length
buffer := make([]byte, bufferSize)
for {
n, err := br.reader.Read(buffer)
if n > 0 {
if br.activeBuffering {
br.buffer.Write(buffer[:n])
}
br.bytesRead += int64(n)
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
// If we already know the total size, we can use it directly.
if !br.sizeKnown {
if err := br.readToEnd(); err != nil {
return 0, err
}
}
newIndex = min(br.bytesRead+offset, br.bytesRead)
newIndex = br.totalSize + offset
default:
return 0, errors.New("invalid whence value")
}
@ -172,27 +173,161 @@ func (br *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
return 0, errors.New("can not seek to before start of reader")
}
// For non-seekable readers, we need to ensure we've read up to the new index.
if br.seeker == nil && newIndex > br.bytesRead {
if err := br.readUntil(newIndex); err != nil {
return 0, err
}
}
br.index = newIndex
// Update bytesRead only if we've moved beyond what we've read so far.
if br.index > br.bytesRead {
br.bytesRead = br.index
}
return newIndex, nil
}
// readToEnd drains the underlying reader in defaultBufferSize pieces, passing
// each piece to writeData. When the reader reports io.EOF it records bytesRead
// as the total size and marks the size as known. Any other read error, or a
// writeData error, is returned immediately.
func (br *BufferedReadSeeker) readToEnd() error {
	scratch := br.bufPool.Get()
	defer br.bufPool.Put(scratch)

	for drained := false; !drained; {
		copied, copyErr := io.CopyN(scratch, br.reader, defaultBufferSize)
		if copied > 0 {
			// Hand over exactly what this iteration copied.
			if err := br.writeData(scratch.Bytes()[:copied]); err != nil {
				return err
			}
		}

		// Empty the scratch buffer before the next piece is copied in.
		scratch.Reset()

		switch {
		case errors.Is(copyErr, io.EOF):
			drained = true
		case copyErr != nil:
			return copyErr
		}
	}

	br.totalSize, br.sizeKnown = br.bytesRead, true
	return nil
}
// writeData appends data to the in-memory buffer and adds its length to
// bytesRead. Once the buffer holds at least threshold bytes, the temporary
// file is created if it does not exist yet and the buffer is flushed to it.
func (br *BufferedReadSeeker) writeData(data []byte) error {
	if _, err := br.buf.Write(data); err != nil {
		return err
	}
	br.bytesRead += int64(len(data))

	// Below the threshold everything stays in memory.
	underThreshold := br.buf.Len() < int(br.threshold)
	if underThreshold {
		return nil
	}

	// Spill to disk: create the temporary file on first use, then flush.
	if br.tempFile == nil {
		if createErr := br.createTempFile(); createErr != nil {
			return createErr
		}
	}
	return br.flushBufferToDisk()
}
// readUntil reads from br itself, in pieces of at most defaultBufferSize, until
// bytesRead reaches index or a read yields no bytes. The bytes are copied into
// a pooled scratch buffer that is reset after every piece. Errors other than
// io.EOF are returned; reaching the end early is not reported as an error.
func (br *BufferedReadSeeker) readUntil(index int64) error {
	scratch := br.bufPool.Get()
	defer br.bufPool.Put(scratch)

	for br.bytesRead < index {
		// Never request more than one buffer's worth per iteration.
		want := index - br.bytesRead
		if want > int64(defaultBufferSize) {
			want = int64(defaultBufferSize)
		}

		got, err := io.CopyN(scratch, br, want)
		switch {
		case err != nil && !errors.Is(err, io.EOF):
			return err
		case got == 0:
			// Nothing more to read; stop short of index.
			return nil
		}
		scratch.Reset()
	}
	return nil
}
// createTempFile creates a file in os.TempDir() using the pattern returned by
// cleantemp.MkFilename and stores both the handle and its name on br. On
// failure br is left unchanged and the error is returned.
func (br *BufferedReadSeeker) createTempFile() error {
	f, err := os.CreateTemp(os.TempDir(), cleantemp.MkFilename())
	if err == nil {
		br.tempFile, br.tempFileName = f, f.Name()
	}
	return err
}
// flushBufferToDisk writes the contents of the in-memory buffer to the
// temporary file and adds the number of bytes written to diskBufferSize.
// br.tempFile must already exist; writeData creates it before calling here.
func (br *BufferedReadSeeker) flushBufferToDisk() error {
	// Take the size from WriteTo's return value, not from br.buf.Len() after
	// the call: a WriterTo writes until no data is left, so the length
	// measured afterwards does not describe what went to disk.
	// NOTE(review): assumes buffer.Buffer follows bytes.Buffer semantics for
	// WriteTo — confirm against the buffer package.
	// NOTE(review): Read repositions br.tempFile with Seek; confirm the file
	// offset is at the end of previously flushed data before a later flush.
	written, err := br.buf.WriteTo(br.tempFile)
	br.diskBufferSize += written
	if err != nil {
		return err
	}

	return nil
}
// ReadAt reads len(out) bytes into out starting at offset off in the underlying input source.
// It uses Seek and Read to implement random access reading.
func (br *BufferedReadSeeker) ReadAt(out []byte, offset int64) (int, error) {
startIndex, err := br.Seek(offset, io.SeekStart)
if err != nil {
if br.seeker != nil {
// Use the underlying Seeker if available.
_, err := br.Seek(offset, io.SeekStart)
if err != nil {
return 0, err
}
return br.Read(out)
}
// For non-seekable readers, use our buffering logic.
currentIndex := br.index
if _, err := br.Seek(offset, io.SeekStart); err != nil {
return 0, err
}
if startIndex != offset {
return 0, io.EOF
n, err := br.Read(out)
if err != nil {
return n, err
}
return br.Read(out)
// Seek back to the original position.
if _, err = br.Seek(currentIndex, io.SeekStart); err != nil {
return n, err
}
return n, err
}
// DisableBuffering stops the buffering process.
// This is useful after initial reads (e.g., for MIME type detection and format identification)
// to prevent further writes to the buffer, optimizing subsequent reads.
func (br *BufferedReadSeeker) DisableBuffering() { br.activeBuffering = false }
// Close releases the resources held by the BufferedReadSeeker.
// It returns the in-memory buffer to the pool and, if a temporary file was
// created, closes it and removes it from disk. Errors from closing and from
// removing the file are both reported (joined); removal is attempted even when
// closing fails.
//
// NOTE(review): Close does not clear br.buf or br.tempFile, so calling it
// twice would return the same buffer to the pool twice — confirm callers close
// exactly once.
func (br *BufferedReadSeeker) Close() error {
	if br.buf != nil {
		br.bufPool.Put(br.buf)
	}

	if br.tempFile == nil {
		return nil
	}

	// Do not let a failed Close hide itself or skip the cleanup of the file.
	closeErr := br.tempFile.Close()
	removeErr := os.Remove(br.tempFileName)
	return errors.Join(closeErr, removeErr)
}

View file

@ -13,7 +13,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
tests := []struct {
name string
reader io.Reader
activeBuffering bool
reads []int
expectedReads []int
expectedBytes [][]byte
@ -25,7 +24,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "read from seekable reader",
reader: strings.NewReader("test data"),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
@ -35,7 +33,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "read from non-seekable reader with buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
@ -46,7 +43,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "read from non-seekable reader without buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: false,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
@ -56,7 +52,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "read beyond buffer",
reader: strings.NewReader("test data"),
activeBuffering: true,
reads: []int{10},
expectedReads: []int{9},
expectedBytes: [][]byte{[]byte("test data")},
@ -66,7 +61,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "read with empty reader",
reader: strings.NewReader(""),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{0},
expectedBytes: [][]byte{[]byte("")},
@ -77,7 +71,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "read exact buffer size",
reader: strings.NewReader("test"),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
@ -87,7 +80,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "read less than buffer size",
reader: strings.NewReader("te"),
activeBuffering: true,
reads: []int{4},
expectedReads: []int{2},
expectedBytes: [][]byte{[]byte("te")},
@ -97,7 +89,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "read more than buffer size without buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: false,
reads: []int{4},
expectedReads: []int{4},
expectedBytes: [][]byte{[]byte("test")},
@ -107,7 +98,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "multiple reads with buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: true,
reads: []int{4, 5},
expectedReads: []int{4, 5},
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
@ -118,7 +108,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
{
name: "multiple reads without buffering",
reader: bytes.NewBufferString("test data"),
activeBuffering: false,
reads: []int{4, 5},
expectedReads: []int{4, 5},
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
@ -132,7 +121,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
t.Parallel()
brs := NewBufferedReaderSeeker(tt.reader)
brs.activeBuffering = tt.activeBuffering
for i, readSize := range tt.reads {
buf := make([]byte, readSize)
@ -151,10 +139,12 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
}
assert.Equal(t, tt.expectedBytesRead, brs.bytesRead)
assert.Equal(t, tt.expectedIndex, brs.index)
if brs.seeker == nil {
assert.Equal(t, tt.expectedIndex, brs.index)
}
if brs.buffer != nil && len(tt.expectedBuffer) > 0 {
assert.Equal(t, tt.expectedBuffer, brs.buffer.Bytes())
if brs.buf != nil && len(tt.expectedBuffer) > 0 {
assert.Equal(t, tt.expectedBuffer, brs.buf.Bytes())
} else {
assert.Nil(t, tt.expectedBuffer)
}
@ -240,7 +230,7 @@ func TestBufferedReaderSeekerSeek(t *testing.T) {
reader: bytes.NewBufferString("test data"),
offset: 20,
whence: io.SeekEnd,
expectedPos: 9,
expectedPos: 29,
expectedErr: false,
expectedRead: []byte{},
},