mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
[bug]- Invalid Seek for Non-Seekable Readers (#3095)
* initial work * fix and add tests * uncomment * fix seek end * use buffer pool * revert timeout * make linter happy * More linting :()
This commit is contained in:
parent
4a8b213651
commit
ebfbd21707
9 changed files with 344 additions and 134 deletions
|
@ -20,6 +20,7 @@ func TestHandleARFile(t *testing.T) {
|
|||
|
||||
rdr, err := newFileReader(file)
|
||||
assert.NoError(t, err)
|
||||
defer rdr.Close()
|
||||
|
||||
handler := newARHandler()
|
||||
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
||||
|
|
|
@ -111,6 +111,7 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
|
|||
}
|
||||
return fmt.Errorf("error creating custom reader: %w", err)
|
||||
}
|
||||
defer rdr.Close()
|
||||
|
||||
return h.openArchive(ctx, depth+1, rdr, archiveChan)
|
||||
case archiver.Extractor:
|
||||
|
@ -194,6 +195,7 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context.
|
|||
}
|
||||
return fmt.Errorf("error creating custom reader: %w", err)
|
||||
}
|
||||
defer rdr.Close()
|
||||
|
||||
h.metrics.incFilesProcessed()
|
||||
h.metrics.observeFileSize(fileSize)
|
||||
|
|
|
@ -89,6 +89,8 @@ func TestArchiveHandler(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("error creating reusable reader: %s", err)
|
||||
}
|
||||
defer newReader.Close()
|
||||
|
||||
archiveChan, err := handler.HandleFile(logContext.Background(), newReader)
|
||||
if testCase.expectErr {
|
||||
assert.NoError(t, err)
|
||||
|
@ -119,6 +121,7 @@ func TestOpenInvalidArchive(t *testing.T) {
|
|||
|
||||
rdr, err := newFileReader(io.NopCloser(reader))
|
||||
assert.NoError(t, err)
|
||||
defer rdr.Close()
|
||||
|
||||
archiveChan := make(chan []byte)
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ func TestHandleNonArchiveFile(t *testing.T) {
|
|||
|
||||
rdr, err := newFileReader(file)
|
||||
assert.NoError(t, err)
|
||||
defer rdr.Close()
|
||||
|
||||
handler := newDefaultHandler(defaultHandlerType)
|
||||
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
||||
|
|
|
@ -59,7 +59,7 @@ func newMimeTypeReaderFromFileReader(r fileReader) mimeTypeReader {
|
|||
|
||||
// newMimeTypeReader creates a new mimeTypeReader from an io.Reader.
|
||||
// It uses a bufio.Reader to perform MIME type detection on the input reader
|
||||
// without consuming it, by peeking into the first 512 bytes of the input.
|
||||
// without consuming it, by peeking into the first 3072 bytes of the input.
|
||||
// This encapsulates both the original reader and the detected MIME type information.
|
||||
// This function is particularly useful for specialized archive handlers
|
||||
// that need to pass extracted content to the default handler without modifying the original reader.
|
||||
|
@ -84,10 +84,6 @@ func newFileReader(r io.Reader) (fileReader, error) {
|
|||
|
||||
fReader.BufferedReadSeeker = iobuf.NewBufferedReaderSeeker(r)
|
||||
|
||||
// Disable buffering after initial reads.
|
||||
// This optimization ensures we don't continue writing to the buffer after the initial reads.
|
||||
defer fReader.DisableBuffering()
|
||||
|
||||
mime, err := mimetype.DetectReader(fReader)
|
||||
if err != nil {
|
||||
return fReader, fmt.Errorf("unable to detect MIME type: %w", err)
|
||||
|
@ -281,6 +277,7 @@ func HandleFile(
|
|||
}
|
||||
return fmt.Errorf("error creating custom reader: %w", err)
|
||||
}
|
||||
defer rdr.Close()
|
||||
|
||||
mimeT := mimeType(rdr.mime.String())
|
||||
config := newFileHandlingConfig(options...)
|
||||
|
|
|
@ -27,21 +27,101 @@ func TestHandleFileCancelledContext(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandleFile(t *testing.T) {
|
||||
reporter := sources.ChanReporter{Ch: make(chan *sources.Chunk, 2)}
|
||||
reporter := sources.ChanReporter{Ch: make(chan *sources.Chunk, 513)}
|
||||
|
||||
// Only one chunk is sent on the channel.
|
||||
// TODO: Embed a zip without making an HTTP request.
|
||||
resp, err := http.Get("https://raw.githubusercontent.com/bill-rich/bad-secrets/master/aws-canary-creds.zip")
|
||||
assert.NoError(t, err)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
assert.Equal(t, 0, len(reporter.Ch))
|
||||
assert.NoError(t, HandleFile(context.Background(), resp.Body, &sources.Chunk{}, reporter))
|
||||
assert.Equal(t, 1, len(reporter.Ch))
|
||||
}
|
||||
|
||||
func TestHandleHTTPJson(t *testing.T) {
|
||||
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm_random_data.json")
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
chunkCh := make(chan *sources.Chunk, 1)
|
||||
go func() {
|
||||
defer close(chunkCh)
|
||||
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
wantCount := 513
|
||||
count := 0
|
||||
for range chunkCh {
|
||||
count++
|
||||
}
|
||||
assert.Equal(t, wantCount, count)
|
||||
}
|
||||
|
||||
func TestHandleHTTPJsonZip(t *testing.T) {
|
||||
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm.zip")
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
chunkCh := make(chan *sources.Chunk, 1)
|
||||
go func() {
|
||||
defer close(chunkCh)
|
||||
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
wantCount := 513
|
||||
count := 0
|
||||
for range chunkCh {
|
||||
count++
|
||||
}
|
||||
assert.Equal(t, wantCount, count)
|
||||
}
|
||||
|
||||
func BenchmarkHandleHTTPJsonZip(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
func() {
|
||||
resp, err := http.Get("https://raw.githubusercontent.com/ahrav/nothing-to-see-here/main/sm.zip")
|
||||
assert.NoError(b, err)
|
||||
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
chunkCh := make(chan *sources.Chunk, 1)
|
||||
|
||||
b.StartTimer()
|
||||
go func() {
|
||||
defer close(chunkCh)
|
||||
err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
|
||||
assert.NoError(b, err)
|
||||
}()
|
||||
|
||||
for range chunkCh {
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHandleFile(b *testing.B) {
|
||||
file, err := os.Open("testdata/test.tgz")
|
||||
assert.Nil(b, err)
|
||||
|
|
|
@ -20,6 +20,7 @@ func TestHandleRPMFile(t *testing.T) {
|
|||
|
||||
rdr, err := newFileReader(file)
|
||||
assert.NoError(t, err)
|
||||
defer rdr.Close()
|
||||
|
||||
handler := newRPMHandler()
|
||||
archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
|
||||
|
|
|
@ -1,123 +1,144 @@
|
|||
package iobuf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/buffers/buffer"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/buffers/pool"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp"
|
||||
)
|
||||
|
||||
const defaultBufferSize = 1 << 16 // 64KB
|
||||
|
||||
var defaultBufferPool *pool.Pool
|
||||
|
||||
func init() { defaultBufferPool = pool.NewBufferPool(defaultBufferSize) }
|
||||
|
||||
// BufferedReadSeeker provides a buffered reading interface with seeking capabilities.
|
||||
// It wraps an io.Reader and optionally an io.Seeker, allowing for efficient
|
||||
// reading and seeking operations, even on non-seekable underlying readers.
|
||||
//
|
||||
// For small amounts of data, it uses an in-memory buffer (bytes.Buffer) to store
|
||||
// read bytes. When the amount of data exceeds a specified threshold, it switches
|
||||
// to disk-based buffering using a temporary file. This approach balances memory
|
||||
// usage and performance, allowing efficient handling of both small and large data streams.
|
||||
//
|
||||
// The struct manages the transition between in-memory and disk-based buffering
|
||||
// transparently, providing a seamless reading and seeking experience regardless
|
||||
// of the underlying data size or the seekability of the original reader.
|
||||
//
|
||||
// If the underlying reader is seekable, direct seeking operations are performed
|
||||
// on it. For non-seekable readers, seeking is emulated using the buffer or
|
||||
// temporary file.
|
||||
type BufferedReadSeeker struct {
|
||||
reader io.Reader
|
||||
seeker io.Seeker // If the reader supports seeking, it's stored here for direct access
|
||||
|
||||
buffer *bytes.Buffer // Internal buffer to store read bytes for non-seekable readers
|
||||
bufPool *pool.Pool // Pool for storing buffers for reuse.
|
||||
buf *buffer.Buffer // Buffer for storing data under the threshold in memory.
|
||||
|
||||
bytesRead int64 // Total number of bytes read from the underlying reader
|
||||
index int64 // Current position in the virtual stream
|
||||
|
||||
// Flag to control buffering. This flag is used to indicate whether buffering is active.
|
||||
// Buffering is enabled during initial reads (e.g., for MIME type detection and format identification).
|
||||
// Once these operations are done, buffering should be disabled to prevent further writes to the buffer
|
||||
// and to optimize subsequent reads directly from the underlying reader. This helps avoid excessive
|
||||
// memory usage while still providing the necessary functionality for initial detection operations.
|
||||
activeBuffering bool
|
||||
threshold int64 // Threshold for switching to file buffering
|
||||
tempFile *os.File // Temporary file for disk-based buffering
|
||||
tempFileName string // Name of the temporary file
|
||||
diskBufferSize int64 // Size of data written to disk
|
||||
|
||||
// Fields to provide a quick way to determine the total size of the reader
|
||||
// without having to seek.
|
||||
totalSize int64 // Total size of the reader
|
||||
sizeKnown bool // Whether the total size of the reader is known
|
||||
}
|
||||
|
||||
// NewBufferedReaderSeeker creates and initializes a BufferedReadSeeker.
|
||||
// It takes an io.Reader and checks if it supports seeking.
|
||||
// If the reader supports seeking, it is stored in the seeker field.
|
||||
func NewBufferedReaderSeeker(r io.Reader) *BufferedReadSeeker {
|
||||
var (
|
||||
seeker io.Seeker
|
||||
buffer *bytes.Buffer
|
||||
activeBuffering = true
|
||||
)
|
||||
if s, ok := r.(io.Seeker); ok {
|
||||
seeker = s
|
||||
activeBuffering = false
|
||||
}
|
||||
const defaultThreshold = 1 << 24 // 16MB threshold for switching to file buffering
|
||||
|
||||
const mimeTypeBufferSize = 3072 // Approx buffer size for MIME type detection
|
||||
seeker, _ := r.(io.Seeker)
|
||||
|
||||
var buf *buffer.Buffer
|
||||
if seeker == nil {
|
||||
buffer = bytes.NewBuffer(make([]byte, 0, mimeTypeBufferSize))
|
||||
buf = defaultBufferPool.Get()
|
||||
}
|
||||
|
||||
return &BufferedReadSeeker{
|
||||
reader: r,
|
||||
seeker: seeker,
|
||||
buffer: buffer,
|
||||
bytesRead: 0,
|
||||
index: 0,
|
||||
activeBuffering: activeBuffering,
|
||||
reader: r,
|
||||
seeker: seeker,
|
||||
bufPool: defaultBufferPool,
|
||||
buf: buf,
|
||||
threshold: defaultThreshold,
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads len(out) bytes from the reader starting at the current index.
|
||||
// It handles both seekable and non-seekable underlying readers efficiently.
|
||||
func (br *BufferedReadSeeker) Read(out []byte) (int, error) {
|
||||
// For seekable readers, read directly from the underlying reader.
|
||||
if br.seeker != nil {
|
||||
// For seekable readers, read directly from the underlying reader.
|
||||
n, err := br.reader.Read(out)
|
||||
br.index += int64(n)
|
||||
br.bytesRead = max(br.bytesRead, br.index)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// For non-seekable readers, use buffered reading.
|
||||
outLen := int64(len(out))
|
||||
if outLen == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// If the current read position (br.index) is within the buffer's valid data range,
|
||||
// read from the buffer. This ensures previously read data (e.g., for mime type detection)
|
||||
// is included in subsequent reads, providing a consistent view of the reader's content.
|
||||
if br.index < int64(br.buffer.Len()) {
|
||||
n := copy(out, br.buffer.Bytes()[br.index:])
|
||||
br.index += int64(n)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if !br.activeBuffering {
|
||||
// If buffering is not active, read directly from the underlying reader.
|
||||
n, err := br.reader.Read(out)
|
||||
br.index += int64(n)
|
||||
br.bytesRead = max(br.bytesRead, br.index)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Ensure there are enough bytes in the buffer to read from.
|
||||
if outLen+br.index > int64(br.buffer.Len()) {
|
||||
bytesToRead := int(outLen + br.index - int64(br.buffer.Len()))
|
||||
readerBytes := make([]byte, bytesToRead)
|
||||
n, err := br.reader.Read(readerBytes)
|
||||
br.buffer.Write(readerBytes[:n])
|
||||
br.bytesRead += int64(n)
|
||||
|
||||
if err != nil {
|
||||
return n, err
|
||||
if n > 0 {
|
||||
br.bytesRead += int64(n)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Ensure the read does not exceed the buffer length.
|
||||
endIndex := br.index + outLen
|
||||
bufLen := int64(br.buffer.Len())
|
||||
if endIndex > bufLen {
|
||||
endIndex = bufLen
|
||||
var (
|
||||
totalBytesRead int
|
||||
err error
|
||||
)
|
||||
|
||||
// If the current read position is within the in-memory buffer.
|
||||
if br.index < int64(br.buf.Len()) {
|
||||
totalBytesRead = copy(out, br.buf.Bytes()[br.index:])
|
||||
br.index += int64(totalBytesRead)
|
||||
if totalBytesRead == len(out) {
|
||||
return totalBytesRead, nil
|
||||
}
|
||||
out = out[totalBytesRead:]
|
||||
}
|
||||
|
||||
if br.index >= bufLen {
|
||||
return 0, io.EOF
|
||||
// If we've exceeded the in-memory threshold and have a temp file.
|
||||
if br.tempFile != nil && br.index < br.diskBufferSize {
|
||||
if _, err := br.tempFile.Seek(br.index-int64(br.buf.Len()), io.SeekStart); err != nil {
|
||||
return totalBytesRead, err
|
||||
}
|
||||
m, err := br.tempFile.Read(out)
|
||||
totalBytesRead += m
|
||||
br.index += int64(m)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return totalBytesRead, err
|
||||
}
|
||||
if totalBytesRead == len(out) {
|
||||
return totalBytesRead, nil
|
||||
}
|
||||
out = out[totalBytesRead:]
|
||||
}
|
||||
|
||||
n := copy(out, br.buffer.Bytes()[br.index:endIndex])
|
||||
br.index += int64(n)
|
||||
return n, nil
|
||||
if len(out) == 0 {
|
||||
return totalBytesRead, nil
|
||||
}
|
||||
|
||||
// If we still need to read more data.
|
||||
var raderBytes int
|
||||
raderBytes, err = br.reader.Read(out)
|
||||
totalBytesRead += raderBytes
|
||||
br.index += int64(raderBytes)
|
||||
|
||||
if writeErr := br.writeData(out[:raderBytes]); writeErr != nil {
|
||||
return totalBytesRead, writeErr
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
br.totalSize = br.bytesRead
|
||||
br.sizeKnown = true
|
||||
}
|
||||
|
||||
return totalBytesRead, err
|
||||
}
|
||||
|
||||
// Seek sets the offset for the next Read or Write to offset.
|
||||
|
@ -125,45 +146,25 @@ func (br *BufferedReadSeeker) Read(out []byte) (int, error) {
|
|||
func (br *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
|
||||
if br.seeker != nil {
|
||||
// Use the underlying Seeker if available.
|
||||
newIndex, err := br.seeker.Seek(offset, whence)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error seeking in reader: %w", err)
|
||||
}
|
||||
if newIndex > br.bytesRead {
|
||||
br.bytesRead = newIndex
|
||||
}
|
||||
return newIndex, nil
|
||||
return br.seeker.Seek(offset, whence)
|
||||
}
|
||||
|
||||
// Manual seeking for non-seekable readers.
|
||||
newIndex := br.index
|
||||
|
||||
const bufferSize = 64 * 1024 // 64KB chunk size for reading
|
||||
|
||||
switch whence {
|
||||
case io.SeekStart:
|
||||
newIndex = offset
|
||||
case io.SeekCurrent:
|
||||
newIndex += offset
|
||||
case io.SeekEnd:
|
||||
// Read the entire reader to determine its length
|
||||
buffer := make([]byte, bufferSize)
|
||||
for {
|
||||
n, err := br.reader.Read(buffer)
|
||||
if n > 0 {
|
||||
if br.activeBuffering {
|
||||
br.buffer.Write(buffer[:n])
|
||||
}
|
||||
br.bytesRead += int64(n)
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
// If we already know the total size, we can use it directly.
|
||||
if !br.sizeKnown {
|
||||
if err := br.readToEnd(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
newIndex = min(br.bytesRead+offset, br.bytesRead)
|
||||
newIndex = br.totalSize + offset
|
||||
default:
|
||||
return 0, errors.New("invalid whence value")
|
||||
}
|
||||
|
@ -172,27 +173,161 @@ func (br *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
|
|||
return 0, errors.New("can not seek to before start of reader")
|
||||
}
|
||||
|
||||
// For non-seekable readers, we need to ensure we've read up to the new index.
|
||||
if br.seeker == nil && newIndex > br.bytesRead {
|
||||
if err := br.readUntil(newIndex); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
br.index = newIndex
|
||||
|
||||
// Update bytesRead only if we've moved beyond what we've read so far.
|
||||
if br.index > br.bytesRead {
|
||||
br.bytesRead = br.index
|
||||
}
|
||||
|
||||
return newIndex, nil
|
||||
}
|
||||
|
||||
func (br *BufferedReadSeeker) readToEnd() error {
|
||||
buf := br.bufPool.Get()
|
||||
defer br.bufPool.Put(buf)
|
||||
|
||||
for {
|
||||
n, err := io.CopyN(buf, br.reader, defaultBufferSize)
|
||||
if n > 0 {
|
||||
// Write the data from the buffer.
|
||||
if writeErr := br.writeData(buf.Bytes()[:n]); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
}
|
||||
// Reset the buffer for the next iteration.
|
||||
buf.Reset()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
br.totalSize = br.bytesRead
|
||||
br.sizeKnown = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (br *BufferedReadSeeker) writeData(data []byte) error {
|
||||
_, err := br.buf.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
br.bytesRead += int64(len(data))
|
||||
|
||||
// Check if we've reached or exceeded the threshold.
|
||||
if br.buf.Len() < int(br.threshold) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if br.tempFile == nil {
|
||||
if err := br.createTempFile(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Flush the buffer to disk.
|
||||
return br.flushBufferToDisk()
|
||||
}
|
||||
|
||||
func (br *BufferedReadSeeker) readUntil(index int64) error {
|
||||
buf := br.bufPool.Get()
|
||||
defer br.bufPool.Put(buf)
|
||||
|
||||
for br.bytesRead < index {
|
||||
remaining := index - br.bytesRead
|
||||
bufSize := int64(defaultBufferSize)
|
||||
if remaining < bufSize {
|
||||
bufSize = remaining
|
||||
}
|
||||
|
||||
n, err := io.CopyN(buf, br, bufSize)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (br *BufferedReadSeeker) createTempFile() error {
|
||||
tempFile, err := os.CreateTemp(os.TempDir(), cleantemp.MkFilename())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
br.tempFile = tempFile
|
||||
br.tempFileName = tempFile.Name()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (br *BufferedReadSeeker) flushBufferToDisk() error {
|
||||
if _, err := br.buf.WriteTo(br.tempFile); err != nil {
|
||||
return err
|
||||
}
|
||||
br.diskBufferSize = int64(br.buf.Len())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadAt reads len(out) bytes into out starting at offset off in the underlying input source.
|
||||
// It uses Seek and Read to implement random access reading.
|
||||
func (br *BufferedReadSeeker) ReadAt(out []byte, offset int64) (int, error) {
|
||||
startIndex, err := br.Seek(offset, io.SeekStart)
|
||||
if err != nil {
|
||||
if br.seeker != nil {
|
||||
// Use the underlying Seeker if available.
|
||||
_, err := br.Seek(offset, io.SeekStart)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return br.Read(out)
|
||||
}
|
||||
|
||||
// For non-seekable readers, use our buffering logic.
|
||||
currentIndex := br.index
|
||||
|
||||
if _, err := br.Seek(offset, io.SeekStart); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if startIndex != offset {
|
||||
return 0, io.EOF
|
||||
n, err := br.Read(out)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return br.Read(out)
|
||||
// Seek back to the original position.
|
||||
if _, err = br.Seek(currentIndex, io.SeekStart); err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// DisableBuffering stops the buffering process.
|
||||
// This is useful after initial reads (e.g., for MIME type detection and format identification)
|
||||
// to prevent further writes to the buffer, optimizing subsequent reads.
|
||||
func (br *BufferedReadSeeker) DisableBuffering() { br.activeBuffering = false }
|
||||
// Close closes the BufferedReadSeeker and releases any resources used.
|
||||
// It closes the temporary file if one was created and removes it from disk and
|
||||
// returns the buffer to the pool.
|
||||
func (br *BufferedReadSeeker) Close() error {
|
||||
if br.buf != nil {
|
||||
br.bufPool.Put(br.buf)
|
||||
}
|
||||
|
||||
if br.tempFile != nil {
|
||||
br.tempFile.Close()
|
||||
return os.Remove(br.tempFileName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -13,7 +13,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
reader io.Reader
|
||||
activeBuffering bool
|
||||
reads []int
|
||||
expectedReads []int
|
||||
expectedBytes [][]byte
|
||||
|
@ -25,7 +24,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "read from seekable reader",
|
||||
reader: strings.NewReader("test data"),
|
||||
activeBuffering: true,
|
||||
reads: []int{4},
|
||||
expectedReads: []int{4},
|
||||
expectedBytes: [][]byte{[]byte("test")},
|
||||
|
@ -35,7 +33,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "read from non-seekable reader with buffering",
|
||||
reader: bytes.NewBufferString("test data"),
|
||||
activeBuffering: true,
|
||||
reads: []int{4},
|
||||
expectedReads: []int{4},
|
||||
expectedBytes: [][]byte{[]byte("test")},
|
||||
|
@ -46,7 +43,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "read from non-seekable reader without buffering",
|
||||
reader: bytes.NewBufferString("test data"),
|
||||
activeBuffering: false,
|
||||
reads: []int{4},
|
||||
expectedReads: []int{4},
|
||||
expectedBytes: [][]byte{[]byte("test")},
|
||||
|
@ -56,7 +52,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "read beyond buffer",
|
||||
reader: strings.NewReader("test data"),
|
||||
activeBuffering: true,
|
||||
reads: []int{10},
|
||||
expectedReads: []int{9},
|
||||
expectedBytes: [][]byte{[]byte("test data")},
|
||||
|
@ -66,7 +61,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "read with empty reader",
|
||||
reader: strings.NewReader(""),
|
||||
activeBuffering: true,
|
||||
reads: []int{4},
|
||||
expectedReads: []int{0},
|
||||
expectedBytes: [][]byte{[]byte("")},
|
||||
|
@ -77,7 +71,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "read exact buffer size",
|
||||
reader: strings.NewReader("test"),
|
||||
activeBuffering: true,
|
||||
reads: []int{4},
|
||||
expectedReads: []int{4},
|
||||
expectedBytes: [][]byte{[]byte("test")},
|
||||
|
@ -87,7 +80,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "read less than buffer size",
|
||||
reader: strings.NewReader("te"),
|
||||
activeBuffering: true,
|
||||
reads: []int{4},
|
||||
expectedReads: []int{2},
|
||||
expectedBytes: [][]byte{[]byte("te")},
|
||||
|
@ -97,7 +89,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "read more than buffer size without buffering",
|
||||
reader: bytes.NewBufferString("test data"),
|
||||
activeBuffering: false,
|
||||
reads: []int{4},
|
||||
expectedReads: []int{4},
|
||||
expectedBytes: [][]byte{[]byte("test")},
|
||||
|
@ -107,7 +98,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "multiple reads with buffering",
|
||||
reader: bytes.NewBufferString("test data"),
|
||||
activeBuffering: true,
|
||||
reads: []int{4, 5},
|
||||
expectedReads: []int{4, 5},
|
||||
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
|
||||
|
@ -118,7 +108,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
{
|
||||
name: "multiple reads without buffering",
|
||||
reader: bytes.NewBufferString("test data"),
|
||||
activeBuffering: false,
|
||||
reads: []int{4, 5},
|
||||
expectedReads: []int{4, 5},
|
||||
expectedBytes: [][]byte{[]byte("test"), []byte(" data")},
|
||||
|
@ -132,7 +121,6 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
brs := NewBufferedReaderSeeker(tt.reader)
|
||||
brs.activeBuffering = tt.activeBuffering
|
||||
|
||||
for i, readSize := range tt.reads {
|
||||
buf := make([]byte, readSize)
|
||||
|
@ -151,10 +139,12 @@ func TestBufferedReaderSeekerRead(t *testing.T) {
|
|||
}
|
||||
|
||||
assert.Equal(t, tt.expectedBytesRead, brs.bytesRead)
|
||||
assert.Equal(t, tt.expectedIndex, brs.index)
|
||||
if brs.seeker == nil {
|
||||
assert.Equal(t, tt.expectedIndex, brs.index)
|
||||
}
|
||||
|
||||
if brs.buffer != nil && len(tt.expectedBuffer) > 0 {
|
||||
assert.Equal(t, tt.expectedBuffer, brs.buffer.Bytes())
|
||||
if brs.buf != nil && len(tt.expectedBuffer) > 0 {
|
||||
assert.Equal(t, tt.expectedBuffer, brs.buf.Bytes())
|
||||
} else {
|
||||
assert.Nil(t, tt.expectedBuffer)
|
||||
}
|
||||
|
@ -240,7 +230,7 @@ func TestBufferedReaderSeekerSeek(t *testing.T) {
|
|||
reader: bytes.NewBufferString("test data"),
|
||||
offset: 20,
|
||||
whence: io.SeekEnd,
|
||||
expectedPos: 9,
|
||||
expectedPos: 29,
|
||||
expectedErr: false,
|
||||
expectedRead: []byte{},
|
||||
},
|
||||
|
|
Loading…
Reference in a new issue