mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
[bug]- Invalid Seek for Non-Seekable Readers (#3095)
* inital work * fix and add tests * uncomment * fix seek end * use buffer pool * revert timeout * make linter happy * More linting :()
This commit is contained in:
parent
4a8b213651
commit
ebfbd21707
9 changed files with 344 additions and 134 deletions
|
@ -20,6 +20,7 @@ func TestHandleARFile(t *testing.T) {
|
||||||
|
|
||||||
rdr, err := newFileReader(file)
|
rdr, err := newFileReader(file)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
defer rdr.Close()
|
||||||
|
|
||||||
handler := newARHandler()
|
handler := newARHandler()
|
||||||
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
||||||
|
|
|
@ -111,6 +111,7 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
|
||||||
}
|
}
|
||||||
return fmt.Errorf("error creating custom reader: %w", err)
|
return fmt.Errorf("error creating custom reader: %w", err)
|
||||||
}
|
}
|
||||||
|
defer rdr.Close()
|
||||||
|
|
||||||
return h.openArchive(ctx, depth+1, rdr, archiveChan)
|
return h.openArchive(ctx, depth+1, rdr, archiveChan)
|
||||||
case archiver.Extractor:
|
case archiver.Extractor:
|
||||||
|
@ -194,6 +195,7 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context.
|
||||||
}
|
}
|
||||||
return fmt.Errorf("error creating custom reader: %w", err)
|
return fmt.Errorf("error creating custom reader: %w", err)
|
||||||
}
|
}
|
||||||
|
defer rdr.Close()
|
||||||
|
|
||||||
h.metrics.incFilesProcessed()
|
h.metrics.incFilesProcessed()
|
||||||
h.metrics.observeFileSize(fileSize)
|
h.metrics.observeFileSize(fileSize)
|
||||||
|
|
|
@ -89,6 +89,8 @@ func TestArchiveHandler(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error creating reusable reader: %s", err)
|
t.Errorf("error creating reusable reader: %s", err)
|
||||||
}
|
}
|
||||||
|
defer newReader.Close()
|
||||||
|
|
||||||
archiveChan, err := handler.HandleFile(logContext.Background(), newReader)
|
archiveChan, err := handler.HandleFile(logContext.Background(), newReader)
|
||||||
if testCase.expectErr {
|
if testCase.expectErr {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -119,6 +121,7 @@ func TestOpenInvalidArchive(t *testing.T) {
|
||||||
|
|
||||||
rdr, err := newFileReader(io.NopCloser(reader))
|
rdr, err := newFileReader(io.NopCloser(reader))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
defer rdr.Close()
|
||||||
|
|
||||||
archiveChan := make(chan []byte)
|
archiveChan := make(chan []byte)
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ func TestHandleNonArchiveFile(t *testing.T) {
|
||||||
|
|
||||||
rdr, err := newFileReader(file)
|
rdr, err := newFileReader(file)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
defer rdr.Close()
|
||||||
|
|
||||||
handler := newDefaultHandler(defaultHandlerType)
|
handler := newDefaultHandler(defaultHandlerType)
|
||||||
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
||||||
|
|
|
@ -59,7 +59,7 @@ func newMimeTypeReaderFromFileReader(r fileReader) mimeTypeReader {
|
||||||
|
|
||||||
// newMimeTypeReader creates a new mimeTypeReader from an io.Reader.
|
// newMimeTypeReader creates a new mimeTypeReader from an io.Reader.
|
||||||
// It uses a bufio.Reader to perform MIME type detection on the input reader
|
// It uses a bufio.Reader to perform MIME type detection on the input reader
|
||||||
// without consuming it, by peeking into the first 512 bytes of the input.
|
// without consuming it, by peeking into the first 3072 bytes of the input.
|
||||||
// This encapsulates both the original reader and the detected MIME type information.
|
// This encapsulates both the original reader and the detected MIME type information.
|
||||||
// This function is particularly useful for specialized archive handlers
|
// This function is particularly useful for specialized archive handlers
|
||||||
// that need to pass extracted content to the default handler without modifying the original reader.
|
// that need to pass extracted content to the default handler without modifying the original reader.
|
||||||
|
@ -84,10 +84,6 @@ func newFileReader(r io.Reader) (fileReader, error) {
|
||||||
|
|
||||||
fReader.BufferedReadSeeker = iobuf.NewBufferedReaderSeeker(r)
|
fReader.BufferedReadSeeker = iobuf.NewBufferedReaderSeeker(r)
|
||||||
|
|
||||||
// Disable buffering after initial reads.
|
|
||||||
// This optimization ensures we don't continue writing to the buffer after the initial reads.
|
|
||||||
defer fReader.DisableBuffering()
|
|
||||||
|
|
||||||
mime, err := mimetype.DetectReader(fReader)
|
mime, err := mimetype.DetectReader(fReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fReader, fmt.Errorf("unable to detect MIME type: %w", err)
|
return fReader, fmt.Errorf("unable to detect MIME type: %w", err)
|
||||||
|
@ -281,6 +277,7 @@ func HandleFile(
|
||||||
}
|
}
|
||||||
return fmt.Errorf("error creating custom reader: %w", err)
|
return fmt.Errorf("error creating custom reader: %w", err)
|
||||||
}
|
}
|
||||||
|
defer rdr.Close()
|
||||||
|
|
||||||
mimeT := mimeType(rdr.mime.String())
|
mimeT := mimeType(rdr.mime.String())
|
||||||
config := newFileHandlingConfig(options...)
|
config := newFileHandlingConfig(options...)
|
||||||
|
|
|
@ -27,21 +27,101 @@ func TestHandleFileCancelledContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleFile(t *testing.T) {
|
func TestHandleFile(t *testing.T) {
|
||||||
reporter := sources.ChanReporter{Ch: make(chan *sources.Chunk, 2)}
|
reporter := sources.ChanReporter{Ch: make(chan *sources.Chunk, 513)}
|
||||||
|
|
||||||
// Only one chunk is sent on the channel.
|
// Only one chunk is sent on the channel.
|
||||||
// TODO: Embed a zip without making an HTTP request.
|
// TODO: Embed a zip without making an HTTP request.
|
||||||
resp, err := http.Get("https://raw.githubusercontent.com/bill-rich/bad-secrets/master/aws-canary-creds.zip")
|
resp, err := http.Get("https://raw.githubusercontent.com/bill-rich/bad-secrets/master/aws-canary-creds.zip")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
if resp != nil && resp.Body != nil {
|
defer func() {
|
||||||
defer resp.Body.Close()
|
if resp != nil && resp.Body != nil {
|
||||||
}
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
assert.Equal(t, 0, len(reporter.Ch))
|
assert.Equal(t, 0, len(reporter.Ch))
|
||||||
assert.NoError(t, HandleFile(context.Background(), resp.Body, &sources.Chunk{}, reporter))
|
assert.NoError(t, HandleFile(context.Background(), resp.Body, &sources.Chunk{}, reporter))
|
||||||
assert.Equal(t, 1, len(reporter.Ch))
|
assert.Equal(t, 1, len(reporter.Ch))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleHTTPJson(t *testing.T) {
|
||||||
|
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm_random_data.json")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
chunkCh := make(chan *sources.Chunk, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(chunkCh)
|
||||||
|
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
wantCount := 513
|
||||||
|
count := 0
|
||||||
|
for range chunkCh {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
assert.Equal(t, wantCount, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleHTTPJsonZip(t *testing.T) {
|
||||||
|
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm.zip")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
chunkCh := make(chan *sources.Chunk, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(chunkCh)
|
||||||
|
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
wantCount := 513
|
||||||
|
count := 0
|
||||||
|
for range chunkCh {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
assert.Equal(t, wantCount, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHandleHTTPJsonZip(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
func() {
|
||||||
|
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm.zip")
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
chunkCh := make(chan *sources.Chunk, 1)
|
||||||
|
|
||||||
|
b.StartTimer()
|
||||||
|
go func() {
|
||||||
|
defer close(chunkCh)
|
||||||
|
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
|
||||||
|
assert.NoError(b, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for range chunkCh {
|
||||||
|
}
|
||||||
|
|
||||||
|
b.StopTimer()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkHandleFile(b *testing.B) {
|
func BenchmarkHandleFile(b *testing.B) {
|
||||||
file, err := os.Open("testdata/test.tgz")
|
file, err := os.Open("testdata/test.tgz")
|
||||||
assert.Nil(b, err)
|
assert.Nil(b, err)
|
||||||
|
|
|
@ -20,6 +20,7 @@ func TestHandleRPMFile(t *testing.T) {
|
||||||
|
|
||||||
rdr, err := newFileReader(file)
|
rdr, err := newFileReader(file)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
defer rdr.Close()
|
||||||
|
|
||||||
handler := newRPMHandler()
|
handler := newRPMHandler()
|
||||||
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
||||||
|
|
|
@ -1,123 +1,144 @@
|
||||||
package iobuf
|
package iobuf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/trufflesecurity/trufflehog/v3/pkg/buffers/buffer"
|
||||||
|
"github.com/trufflesecurity/trufflehog/v3/pkg/buffers/pool"
|
||||||
|
"github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const defaultBufferSize = 1 << 16 // 64KB
|
||||||
|
|
||||||
|
var defaultBufferPool *pool.Pool
|
||||||
|
|
||||||
|
func init() { defaultBufferPool = pool.NewBufferPool(defaultBufferSize) }
|
||||||
|
|
||||||
// BufferedReadSeeker provides a buffered reading interface with seeking capabilities.
|
// BufferedReadSeeker provides a buffered reading interface with seeking capabilities.
|
||||||
// It wraps an io.Reader and optionally an io.Seeker, allowing for efficient
|
// It wraps an io.Reader and optionally an io.Seeker, allowing for efficient
|
||||||
// reading and seeking operations, even on non-seekable underlying readers.
|
// reading and seeking operations, even on non-seekable underlying readers.
|
||||||
|
//
|
||||||
|
// For small amounts of data, it uses an in-memory buffer (bytes.Buffer) to store
|
||||||
|
// read bytes. When the amount of data exceeds a specified threshold, it switches
|
||||||
|
// to disk-based buffering using a temporary file. This approach balances memory
|
||||||
|
// usage and performance, allowing efficient handling of both small and large data streams.
|
||||||
|
//
|
||||||
|
// The struct manages the transition between in-memory and disk-based buffering
|
||||||
|
// transparently, providing a seamless reading and seeking experience regardless
|
||||||
|
// of the underlying data size or the seekability of the original reader.
|
||||||
|
//
|
||||||
|
// If the underlying reader is seekable, direct seeking operations are performed
|
||||||
|
// on it. For non-seekable readers, seeking is emulated using the buffer or
|
||||||
|
// temporary file.
|
||||||
type BufferedReadSeeker struct {
|
type BufferedReadSeeker struct {
|
||||||
reader io.Reader
|
reader io.Reader
|
||||||
seeker io.Seeker // If the reader supports seeking, it's stored here for direct access
|
seeker io.Seeker // If the reader supports seeking, it's stored here for direct access
|
||||||
|
|
||||||
buffer *bytes.Buffer // Internal buffer to store read bytes for non-seekable readers
|
bufPool *pool.Pool // Pool for storing buffers for reuse.
|
||||||
|
buf *buffer.Buffer // Buffer for storing data under the threshold in memory.
|
||||||
|
|
||||||
bytesRead int64 // Total number of bytes read from the underlying reader
|
bytesRead int64 // Total number of bytes read from the underlying reader
|
||||||
index int64 // Current position in the virtual stream
|
index int64 // Current position in the virtual stream
|
||||||
|
|
||||||
// Flag to control buffering. This flag is used to indicate whether buffering is active.
|
threshold int64 // Threshold for switching to file buffering
|
||||||
// Buffering is enabled during initial reads (e.g., for MIME type detection and format identification).
|
tempFile *os.File // Temporary file for disk-based buffering
|
||||||
// Once these operations are done, buffering should be disabled to prevent further writes to the buffer
|
tempFileName string // Name of the temporary file
|
||||||
// and to optimize subsequent reads directly from the underlying reader. This helps avoid excessive
|
diskBufferSize int64 // Size of data written to disk
|
||||||
// memory usage while still providing the necessary functionality for initial detection operations.
|
|
||||||
activeBuffering bool
|
// Fields to provide a quick way to determine the total size of the reader
|
||||||
|
// without having to seek.
|
||||||
|
totalSize int64 // Total size of the reader
|
||||||
|
sizeKnown bool // Whether the total size of the reader is known
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBufferedReaderSeeker creates and initializes a BufferedReadSeeker.
|
// NewBufferedReaderSeeker creates and initializes a BufferedReadSeeker.
|
||||||
// It takes an io.Reader and checks if it supports seeking.
|
// It takes an io.Reader and checks if it supports seeking.
|
||||||
// If the reader supports seeking, it is stored in the seeker field.
|
// If the reader supports seeking, it is stored in the seeker field.
|
||||||
func NewBufferedReaderSeeker(r io.Reader) *BufferedReadSeeker {
|
func NewBufferedReaderSeeker(r io.Reader) *BufferedReadSeeker {
|
||||||
var (
|
const defaultThreshold = 1 << 24 // 16MB threshold for switching to file buffering
|
||||||
seeker io.Seeker
|
|
||||||
buffer *bytes.Buffer
|
|
||||||
activeBuffering = true
|
|
||||||
)
|
|
||||||
if s, ok := r.(io.Seeker); ok {
|
|
||||||
seeker = s
|
|
||||||
activeBuffering = false
|
|
||||||
}
|
|
||||||
|
|
||||||
const mimeTypeBufferSize = 3072 // Approx buffer size for MIME type detection
|
seeker, _ := r.(io.Seeker)
|
||||||
|
|
||||||
|
var buf *buffer.Buffer
|
||||||
if seeker == nil {
|
if seeker == nil {
|
||||||
buffer = bytes.NewBuffer(make([]byte, 0, mimeTypeBufferSize))
|
buf = defaultBufferPool.Get()
|
||||||
}
|
}
|
||||||
|
|
||||||
return &BufferedReadSeeker{
|
return &BufferedReadSeeker{
|
||||||
reader: r,
|
reader: r,
|
||||||
seeker: seeker,
|
seeker: seeker,
|
||||||
buffer: buffer,
|
bufPool: defaultBufferPool,
|
||||||
bytesRead: 0,
|
buf: buf,
|
||||||
index: 0,
|
threshold: defaultThreshold,
|
||||||
activeBuffering: activeBuffering,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads len(out) bytes from the reader starting at the current index.
|
// Read reads len(out) bytes from the reader starting at the current index.
|
||||||
// It handles both seekable and non-seekable underlying readers efficiently.
|
// It handles both seekable and non-seekable underlying readers efficiently.
|
||||||
func (br *BufferedReadSeeker) Read(out []byte) (int, error) {
|
func (br *BufferedReadSeeker) Read(out []byte) (int, error) {
|
||||||
// For seekable readers, read directly from the underlying reader.
|
|
||||||
if br.seeker != nil {
|
if br.seeker != nil {
|
||||||
|
// For seekable readers, read directly from the underlying reader.
|
||||||
n, err := br.reader.Read(out)
|
n, err := br.reader.Read(out)
|
||||||
br.index += int64(n)
|
if n > 0 {
|
||||||
br.bytesRead = max(br.bytesRead, br.index)
|
br.bytesRead += int64(n)
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// For non-seekable readers, use buffered reading.
|
|
||||||
outLen := int64(len(out))
|
|
||||||
if outLen == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the current read position (br.index) is within the buffer's valid data range,
|
|
||||||
// read from the buffer. This ensures previously read data (e.g., for mime type detection)
|
|
||||||
// is included in subsequent reads, providing a consistent view of the reader's content.
|
|
||||||
if br.index < int64(br.buffer.Len()) {
|
|
||||||
n := copy(out, br.buffer.Bytes()[br.index:])
|
|
||||||
br.index += int64(n)
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !br.activeBuffering {
|
|
||||||
// If buffering is not active, read directly from the underlying reader.
|
|
||||||
n, err := br.reader.Read(out)
|
|
||||||
br.index += int64(n)
|
|
||||||
br.bytesRead = max(br.bytesRead, br.index)
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure there are enough bytes in the buffer to read from.
|
|
||||||
if outLen+br.index > int64(br.buffer.Len()) {
|
|
||||||
bytesToRead := int(outLen + br.index - int64(br.buffer.Len()))
|
|
||||||
readerBytes := make([]byte, bytesToRead)
|
|
||||||
n, err := br.reader.Read(readerBytes)
|
|
||||||
br.buffer.Write(readerBytes[:n])
|
|
||||||
br.bytesRead += int64(n)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
}
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the read does not exceed the buffer length.
|
var (
|
||||||
endIndex := br.index + outLen
|
totalBytesRead int
|
||||||
bufLen := int64(br.buffer.Len())
|
err error
|
||||||
if endIndex > bufLen {
|
)
|
||||||
endIndex = bufLen
|
|
||||||
|
// If the current read position is within the in-memory buffer.
|
||||||
|
if br.index < int64(br.buf.Len()) {
|
||||||
|
totalBytesRead = copy(out, br.buf.Bytes()[br.index:])
|
||||||
|
br.index += int64(totalBytesRead)
|
||||||
|
if totalBytesRead == len(out) {
|
||||||
|
return totalBytesRead, nil
|
||||||
|
}
|
||||||
|
out = out[totalBytesRead:]
|
||||||
}
|
}
|
||||||
|
|
||||||
if br.index >= bufLen {
|
// If we've exceeded the in-memory threshold and have a temp file.
|
||||||
return 0, io.EOF
|
if br.tempFile != nil && br.index < br.diskBufferSize {
|
||||||
|
if _, err := br.tempFile.Seek(br.index-int64(br.buf.Len()), io.SeekStart); err != nil {
|
||||||
|
return totalBytesRead, err
|
||||||
|
}
|
||||||
|
m, err := br.tempFile.Read(out)
|
||||||
|
totalBytesRead += m
|
||||||
|
br.index += int64(m)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return totalBytesRead, err
|
||||||
|
}
|
||||||
|
if totalBytesRead == len(out) {
|
||||||
|
return totalBytesRead, nil
|
||||||
|
}
|
||||||
|
out = out[totalBytesRead:]
|
||||||
}
|
}
|
||||||
|
|
||||||
n := copy(out, br.buffer.Bytes()[br.index:endIndex])
|
if len(out) == 0 {
|
||||||
br.index += int64(n)
|
return totalBytesRead, nil
|
||||||
return n, nil
|
}
|
||||||
|
|
||||||
|
// If we still need to read more data.
|
||||||
|
var raderBytes int
|
||||||
|
raderBytes, err = br.reader.Read(out)
|
||||||
|
totalBytesRead += raderBytes
|
||||||
|
br.index += int64(raderBytes)
|
||||||
|
|
||||||
|
if writeErr := br.writeData(out[:raderBytes]); writeErr != nil {
|
||||||
|
return totalBytesRead, writeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
br.totalSize = br.bytesRead
|
||||||
|
br.sizeKnown = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return totalBytesRead, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Seek sets the offset for the next Read or Write to offset.
|
// Seek sets the offset for the next Read or Write to offset.
|
||||||
|
@ -125,45 +146,25 @@ func (br *BufferedReadSeeker) Read(out []byte) (int, error) {
|
||||||
func (br *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
|
func (br *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
|
||||||
if br.seeker != nil {
|
if br.seeker != nil {
|
||||||
// Use the underlying Seeker if available.
|
// Use the underlying Seeker if available.
|
||||||
newIndex, err := br.seeker.Seek(offset, whence)
|
return br.seeker.Seek(offset, whence)
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error seeking in reader: %w", err)
|
|
||||||
}
|
|
||||||
if newIndex > br.bytesRead {
|
|
||||||
br.bytesRead = newIndex
|
|
||||||
}
|
|
||||||
return newIndex, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manual seeking for non-seekable readers.
|
// Manual seeking for non-seekable readers.
|
||||||
newIndex := br.index
|
newIndex := br.index
|
||||||
|
|
||||||
const bufferSize = 64 * 1024 // 64KB chunk size for reading
|
|
||||||
|
|
||||||
switch whence {
|
switch whence {
|
||||||
case io.SeekStart:
|
case io.SeekStart:
|
||||||
newIndex = offset
|
newIndex = offset
|
||||||
case io.SeekCurrent:
|
case io.SeekCurrent:
|
||||||
newIndex += offset
|
newIndex += offset
|
||||||
case io.SeekEnd:
|
case io.SeekEnd:
|
||||||
// Read the entire reader to determine its length
|
// If we already know the total size, we can use it directly.
|
||||||
buffer := make([]byte, bufferSize)
|
if !br.sizeKnown {
|
||||||
for {
|
if err := br.readToEnd(); err != nil {
|
||||||
n, err := br.reader.Read(buffer)
|
|
||||||
if n > 0 {
|
|
||||||
if br.activeBuffering {
|
|
||||||
br.buffer.Write(buffer[:n])
|
|
||||||
}
|
|
||||||
br.bytesRead += int64(n)
|
|
||||||
}
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
newIndex = min(br.bytesRead+offset, br.bytesRead)
|
newIndex = br.totalSize + offset
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("invalid whence value")
|
return 0, errors.New("invalid whence value")
|
||||||
}
|
}
|
||||||
|
@ -172,27 +173,161 @@ func (br *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
|
||||||
return 0, errors.New("can not seek to before start of reader")
|
return 0, errors.New("can not seek to before start of reader")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For non-seekable readers, we need to ensure we've read up to the new index.
|
||||||
|
if br.seeker == nil && newIndex > br.bytesRead {
|
||||||
|
if err := br.readUntil(newIndex); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
br.index = newIndex
|
br.index = newIndex
|
||||||
|
|
||||||
|
// Update bytesRead only if we've moved beyond what we've read so far.
|
||||||
|
if br.index > br.bytesRead {
|
||||||
|
br.bytesRead = br.index
|
||||||
|
}
|
||||||
|
|
||||||
return newIndex, nil
|
return newIndex, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (br *BufferedReadSeeker) readToEnd() error {
|
||||||
|
buf := br.bufPool.Get()
|
||||||
|
defer br.bufPool.Put(buf)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := io.CopyN(buf, br.reader, defaultBufferSize)
|
||||||
|
if n > 0 {
|
||||||
|
// Write the data from the buffer.
|
||||||
|
if writeErr := br.writeData(buf.Bytes()[:n]); writeErr != nil {
|
||||||
|
return writeErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Reset the buffer for the next iteration.
|
||||||
|
buf.Reset()
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
br.totalSize = br.bytesRead
|
||||||
|
br.sizeKnown = true
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *BufferedReadSeeker) writeData(data []byte) error {
|
||||||
|
_, err := br.buf.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
br.bytesRead += int64(len(data))
|
||||||
|
|
||||||
|
// Check if we've reached or exceeded the threshold.
|
||||||
|
if br.buf.Len() < int(br.threshold) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.tempFile == nil {
|
||||||
|
if err := br.createTempFile(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush the buffer to disk.
|
||||||
|
return br.flushBufferToDisk()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *BufferedReadSeeker) readUntil(index int64) error {
|
||||||
|
buf := br.bufPool.Get()
|
||||||
|
defer br.bufPool.Put(buf)
|
||||||
|
|
||||||
|
for br.bytesRead < index {
|
||||||
|
remaining := index - br.bytesRead
|
||||||
|
bufSize := int64(defaultBufferSize)
|
||||||
|
if remaining < bufSize {
|
||||||
|
bufSize = remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := io.CopyN(buf, br, bufSize)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *BufferedReadSeeker) createTempFile() error {
|
||||||
|
tempFile, err := os.CreateTemp(os.TempDir(), cleantemp.MkFilename())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
br.tempFile = tempFile
|
||||||
|
br.tempFileName = tempFile.Name()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *BufferedReadSeeker) flushBufferToDisk() error {
|
||||||
|
if _, err := br.buf.WriteTo(br.tempFile); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
br.diskBufferSize = int64(br.buf.Len())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ReadAt reads len(out) bytes into out starting at offset off in the underlying input source.
|
// ReadAt reads len(out) bytes into out starting at offset off in the underlying input source.
|
||||||
// It uses Seek and Read to implement random access reading.
|
// It uses Seek and Read to implement random access reading.
|
||||||
func (br *BufferedReadSeeker) ReadAt(out []byte, offset int64) (int, error) {
|
func (br *BufferedReadSeeker) ReadAt(out []byte, offset int64) (int, error) {
|
||||||
startIndex, err := br.Seek(offset, io.SeekStart)
|
if br.seeker != nil {
|
||||||
if err != nil {
|
// Use the underlying Seeker if available.
|
||||||
|
_, err := br.Seek(offset, io.SeekStart)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return br.Read(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-seekable readers, use our buffering logic.
|
||||||
|
currentIndex := br.index
|
||||||
|
|
||||||
|
if _, err := br.Seek(offset, io.SeekStart); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if startIndex != offset {
|
n, err := br.Read(out)
|
||||||
return 0, io.EOF
|
if err != nil {
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return br.Read(out)
|
// Seek back to the original position.
|
||||||
|
if _, err = br.Seek(currentIndex, io.SeekStart); err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisableBuffering stops the buffering process.
|
// Close closes the BufferedReadSeeker and releases any resources used.
|
||||||
// This is useful after initial reads (e.g., for MIME type detection and format identification)
|
// It closes the temporary file if one was created and removes it from disk and
|
||||||
// to prevent further writes to the buffer, optimizing subsequent reads.
|
// returns the buffer to the pool.
|
||||||
func (br *BufferedReadSeeker) DisableBuffering() { br.activeBuffering = false }
|
func (br *BufferedReadSeeker) Close() error {
|
||||||
|
if br.buf != nil {
|
||||||
|
br.bufPool.Put(br.buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.tempFile != nil {
|
||||||
|
br.tempFile.Close()
|
||||||
|
return os.Remove(br.tempFileName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -13,7 +13,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
reader io.Reader
|
reader io.Reader
|
||||||
activeBuffering bool
|
|
||||||
reads []int
|
reads []int
|
||||||
expectedReads []int
|
expectedReads []int
|
||||||
expectedBytes [][]byte
|
expectedBytes [][]byte
|
||||||
|
@ -25,7 +24,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "read from seekable reader",
|
name: "read from seekable reader",
|
||||||
reader: strings.NewReader("test data"),
|
reader: strings.NewReader("test data"),
|
||||||
activeBuffering: true,
|
|
||||||
reads: []int{4},
|
reads: []int{4},
|
||||||
expectedReads: []int{4},
|
expectedReads: []int{4},
|
||||||
expectedBytes: [][]byte{[]byte("test")},
|
expectedBytes: [][]byte{[]byte("test")},
|
||||||
|
@ -35,7 +33,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "read from non-seekable reader with buffering",
|
name: "read from non-seekable reader with buffering",
|
||||||
reader: bytes.NewBufferString("test data"),
|
reader: bytes.NewBufferString("test data"),
|
||||||
activeBuffering: true,
|
|
||||||
reads: []int{4},
|
reads: []int{4},
|
||||||
expectedReads: []int{4},
|
expectedReads: []int{4},
|
||||||
expectedBytes: [][]byte{[]byte("test")},
|
expectedBytes: [][]byte{[]byte("test")},
|
||||||
|
@ -46,7 +43,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "read from non-seekable reader without buffering",
|
name: "read from non-seekable reader without buffering",
|
||||||
reader: bytes.NewBufferString("test data"),
|
reader: bytes.NewBufferString("test data"),
|
||||||
activeBuffering: false,
|
|
||||||
reads: []int{4},
|
reads: []int{4},
|
||||||
expectedReads: []int{4},
|
expectedReads: []int{4},
|
||||||
expectedBytes: [][]byte{[]byte("test")},
|
expectedBytes: [][]byte{[]byte("test")},
|
||||||
|
@ -56,7 +52,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "read beyond buffer",
|
name: "read beyond buffer",
|
||||||
reader: strings.NewReader("test data"),
|
reader: strings.NewReader("test data"),
|
||||||
activeBuffering: true,
|
|
||||||
reads: []int{10},
|
reads: []int{10},
|
||||||
expectedReads: []int{9},
|
expectedReads: []int{9},
|
||||||
expectedBytes: [][]byte{[]byte("test data")},
|
expectedBytes: [][]byte{[]byte("test data")},
|
||||||
|
@ -66,7 +61,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "read with empty reader",
|
name: "read with empty reader",
|
||||||
reader: strings.NewReader(""),
|
reader: strings.NewReader(""),
|
||||||
activeBuffering: true,
|
|
||||||
reads: []int{4},
|
reads: []int{4},
|
||||||
expectedReads: []int{0},
|
expectedReads: []int{0},
|
||||||
expectedBytes: [][]byte{[]byte("")},
|
expectedBytes: [][]byte{[]byte("")},
|
||||||
|
@ -77,7 +71,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "read exact buffer size",
|
name: "read exact buffer size",
|
||||||
reader: strings.NewReader("test"),
|
reader: strings.NewReader("test"),
|
||||||
activeBuffering: true,
|
|
||||||
reads: []int{4},
|
reads: []int{4},
|
||||||
expectedReads: []int{4},
|
expectedReads: []int{4},
|
||||||
expectedBytes: [][]byte{[]byte("test")},
|
expectedBytes: [][]byte{[]byte("test")},
|
||||||
|
@ -87,7 +80,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "read less than buffer size",
|
name: "read less than buffer size",
|
||||||
reader: strings.NewReader("te"),
|
reader: strings.NewReader("te"),
|
||||||
activeBuffering: true,
|
|
||||||
reads: []int{4},
|
reads: []int{4},
|
||||||
expectedReads: []int{2},
|
expectedReads: []int{2},
|
||||||
expectedBytes: [][]byte{[]byte("te")},
|
expectedBytes: [][]byte{[]byte("te")},
|
||||||
|
@ -97,7 +89,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "read more than buffer size without buffering",
|
name: "read more than buffer size without buffering",
|
||||||
reader: bytes.NewBufferString("test data"),
|
reader: bytes.NewBufferString("test data"),
|
||||||
activeBuffering: false,
|
|
||||||
reads: []int{4},
|
reads: []int{4},
|
||||||
expectedReads: []int{4},
|
expectedReads: []int{4},
|
||||||
expectedBytes: [][]byte{[]byte("test")},
|
expectedBytes: [][]byte{[]byte("test")},
|
||||||
|
@ -107,7 +98,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "multiple reads with buffering",
|
name: "multiple reads with buffering",
|
||||||
reader: bytes.NewBufferString("test data"),
|
reader: bytes.NewBufferString("test data"),
|
||||||
activeBuffering: true,
|
|
||||||
reads: []int{4, 5},
|
reads: []int{4, 5},
|
||||||
expectedReads: []int{4, 5},
|
expectedReads: []int{4, 5},
|
||||||
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
|
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
|
||||||
|
@ -118,7 +108,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "multiple reads without buffering",
|
name: "multiple reads without buffering",
|
||||||
reader: bytes.NewBufferString("test data"),
|
reader: bytes.NewBufferString("test data"),
|
||||||
activeBuffering: false,
|
|
||||||
reads: []int{4, 5},
|
reads: []int{4, 5},
|
||||||
expectedReads: []int{4, 5},
|
expectedReads: []int{4, 5},
|
||||||
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
|
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
|
||||||
|
@ -132,7 +121,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
brs := NewBufferedReaderSeeker(tt.reader)
|
brs := NewBufferedReaderSeeker(tt.reader)
|
||||||
brs.activeBuffering = tt.activeBuffering
|
|
||||||
|
|
||||||
for i, readSize := range tt.reads {
|
for i, readSize := range tt.reads {
|
||||||
buf := make([]byte, readSize)
|
buf := make([]byte, readSize)
|
||||||
|
@ -151,10 +139,12 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedBytesRead, brs.bytesRead)
|
assert.Equal(t, tt.expectedBytesRead, brs.bytesRead)
|
||||||
assert.Equal(t, tt.expectedIndex, brs.index)
|
if brs.seeker == nil {
|
||||||
|
assert.Equal(t, tt.expectedIndex, brs.index)
|
||||||
|
}
|
||||||
|
|
||||||
if brs.buffer != nil && len(tt.expectedBuffer) > 0 {
|
if brs.buf != nil && len(tt.expectedBuffer) > 0 {
|
||||||
assert.Equal(t, tt.expectedBuffer, brs.buffer.Bytes())
|
assert.Equal(t, tt.expectedBuffer, brs.buf.Bytes())
|
||||||
} else {
|
} else {
|
||||||
assert.Nil(t, tt.expectedBuffer)
|
assert.Nil(t, tt.expectedBuffer)
|
||||||
}
|
}
|
||||||
|
@ -240,7 +230,7 @@ func TestBufferedReaderSeekerSeek(t *testing.T) {
|
||||||
reader: bytes.NewBufferString("test data"),
|
reader: bytes.NewBufferString("test data"),
|
||||||
offset: 20,
|
offset: 20,
|
||||||
whence: io.SeekEnd,
|
whence: io.SeekEnd,
|
||||||
expectedPos: 9,
|
expectedPos: 29,
|
||||||
expectedErr: false,
|
expectedErr: false,
|
||||||
expectedRead: []byte{},
|
expectedRead: []byte{},
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in a new issue