mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
[feat] - buffered file reader (#2731)
* Update write method in contentWriter interface * fix lint * Add a buffered file reader * update comments * update comment * add compile type checks * fix * fix test * inline if * magic * update comment
This commit is contained in:
parent
13bd783d2d
commit
46d4ae1334
5 changed files with 322 additions and 7 deletions
85
pkg/readers/bufferedfilereader.go
Normal file
85
pkg/readers/bufferedfilereader.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package readers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
bufferedfilewriter "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffered_file_writer"
|
||||
)
|
||||
|
||||
// Compile time check to ensure that bufferedFileReader implements io.ReaderAt.
|
||||
var _ io.ReaderAt = (*bufferedFileReader)(nil)
|
||||
|
||||
// bufferedFileReader provides random access read, seek, and close capabilities on top of the BufferedFileWriter.
|
||||
// It combines the functionality of BufferedFileWriter for buffered writing, io.ReadSeekCloser for
|
||||
// random access reading and seeking.
|
||||
type bufferedFileReader struct {
|
||||
bufWriter *bufferedfilewriter.BufferedFileWriter
|
||||
reader io.ReadSeekCloser
|
||||
}
|
||||
|
||||
// NewBufferedFileReader initializes a bufferedFileReader from an io.Reader by using
|
||||
// the BufferedFileWriter's functionality to read and store data, then setting up a io.ReadSeekCloser
|
||||
// for random access reading and seeking.
|
||||
// Close should be called when the bufferedFileReader is no longer needed.
|
||||
// It returns the initialized bufferedFileReader and any error encountered during the process.
|
||||
func NewBufferedFileReader(r io.Reader) (*bufferedFileReader, error) {
|
||||
writer, err := bufferedfilewriter.NewFromReader(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating bufferedFileReader: %w", err)
|
||||
}
|
||||
|
||||
// Ensure that the BufferedFileWriter is in read-only mode.
|
||||
if err := writer.CloseForWriting(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rdr, err := writer.ReadSeekCloser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &bufferedFileReader{writer, rdr}, nil
|
||||
}
|
||||
|
||||
// Close the bufferedFileReader.
|
||||
// It should be called when the bufferedFileReader is no longer needed.
|
||||
// Note that closing the bufferedFileReader does not affect the underlying bytes.Reader,
|
||||
// which can still be used for reading, seeking, and reading at specific positions.
|
||||
// Close is a no-op for the bytes.Reader.
|
||||
func (b *bufferedFileReader) Close() error {
|
||||
return b.reader.Close()
|
||||
}
|
||||
|
||||
// Read reads up to len(p) bytes into p from the underlying reader.
|
||||
// It returns the number of bytes read and any error encountered.
|
||||
// If the reader reaches the end of the available data, Read returns 0, io.EOF.
|
||||
// It implements the io.Reader interface.
|
||||
func (b *bufferedFileReader) Read(p []byte) (int, error) {
|
||||
return b.reader.Read(p)
|
||||
}
|
||||
|
||||
// Seek sets the offset for the next Read operation on the underlying reader.
|
||||
// The offset is interpreted according to the whence parameter:
|
||||
// - io.SeekStart means relative to the start of the file
|
||||
// - io.SeekCurrent means relative to the current offset
|
||||
// - io.SeekEnd means relative to the end of the file
|
||||
//
|
||||
// Seek returns the new offset and any error encountered.
|
||||
// It implements the io.Seeker interface.
|
||||
func (b *bufferedFileReader) Seek(offset int64, whence int) (int64, error) {
|
||||
return b.reader.Seek(offset, whence)
|
||||
}
|
||||
|
||||
// ReadAt reads len(p) bytes from the underlying io.ReadSeekCloser starting at byte offset off.
|
||||
// It returns the number of bytes read and any error encountered.
|
||||
// If the io.ReadSeekCloser reaches the end of the available data before len(p) bytes are read,
|
||||
// ReadAt returns the number of bytes read and io.EOF.
|
||||
// It implements the io.ReaderAt interface.
|
||||
func (b *bufferedFileReader) ReadAt(p []byte, off int64) (n int, err error) {
|
||||
_, err = b.reader.Seek(off, io.SeekStart)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return b.reader.Read(p)
|
||||
}
|
113
pkg/readers/bufferedfilereader_test.go
Normal file
113
pkg/readers/bufferedfilereader_test.go
Normal file
|
@ -0,0 +1,113 @@
|
|||
package readers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBufferedFileReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := []byte("Hello, World!")
|
||||
|
||||
bufferReadSeekCloser, err := NewBufferedFileReader(bytes.NewReader(data))
|
||||
assert.NoError(t, err)
|
||||
defer bufferReadSeekCloser.Close()
|
||||
|
||||
// Test Read.
|
||||
buffer := make([]byte, len(data))
|
||||
n, err := bufferReadSeekCloser.Read(buffer)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data), n)
|
||||
assert.Equal(t, data, buffer)
|
||||
|
||||
// Test Seek.
|
||||
offset := 7
|
||||
seekPos, err := bufferReadSeekCloser.Seek(int64(offset), io.SeekStart)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(offset), seekPos)
|
||||
|
||||
// Test ReadAt.
|
||||
buffer = make([]byte, len(data)-offset)
|
||||
n, err = bufferReadSeekCloser.ReadAt(buffer, int64(offset))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-offset, n)
|
||||
assert.Equal(t, data[offset:], buffer)
|
||||
|
||||
// Test Close.
|
||||
err = bufferReadSeekCloser.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBufferedFileReaderClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := []byte("Hello, World!")
|
||||
|
||||
bufferReadSeekCloser, err := NewBufferedFileReader(bytes.NewReader(data))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = bufferReadSeekCloser.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read after closing.
|
||||
buffer := make([]byte, len(data))
|
||||
n, err := bufferReadSeekCloser.Read(buffer)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data), n)
|
||||
assert.Equal(t, data, buffer)
|
||||
|
||||
// Seek after closing.
|
||||
offset := 7
|
||||
seekPos, err := bufferReadSeekCloser.Seek(int64(offset), io.SeekStart)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(offset), seekPos)
|
||||
|
||||
// ReadAt after closing.
|
||||
buffer = make([]byte, len(data)-offset)
|
||||
n, err = bufferReadSeekCloser.ReadAt(buffer, int64(offset))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(data)-offset, n)
|
||||
assert.Equal(t, data[offset:], buffer)
|
||||
}
|
||||
|
||||
func TestBufferedFileReaderReadFromFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a large byte slice to simulate data exceeding the threshold.
|
||||
largeData := make([]byte, 1024*1024) // 1 MB
|
||||
for i := range largeData {
|
||||
largeData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
bufferReadSeekCloser, err := NewBufferedFileReader(bytes.NewReader(largeData))
|
||||
assert.NoError(t, err)
|
||||
defer bufferReadSeekCloser.Close()
|
||||
|
||||
// Test Read.
|
||||
buffer := make([]byte, len(largeData))
|
||||
n, err := bufferReadSeekCloser.Read(buffer)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(largeData), n)
|
||||
assert.Equal(t, largeData, buffer)
|
||||
|
||||
// Test Seek.
|
||||
offset := 512 * 1024 // 512 KB
|
||||
seekPos, err := bufferReadSeekCloser.Seek(int64(offset), io.SeekStart)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(offset), seekPos)
|
||||
|
||||
// Test ReadAt.
|
||||
buffer = make([]byte, len(largeData)-offset)
|
||||
n, err = bufferReadSeekCloser.ReadAt(buffer, int64(offset))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(largeData)-offset, n)
|
||||
assert.Equal(t, largeData[offset:], buffer)
|
||||
|
||||
// Test Close.
|
||||
err = bufferReadSeekCloser.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
|
@ -4,6 +4,7 @@ package buffer
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
@ -140,6 +141,9 @@ func (b *Buffer) Write(data []byte) (int, error) {
|
|||
return b.Buffer.Write(data)
|
||||
}
|
||||
|
||||
// Compile time check to make sure readCloser implements io.ReadSeekCloser.
|
||||
var _ io.ReadSeekCloser = (*readCloser)(nil)
|
||||
|
||||
// readCloser is a custom implementation of io.ReadSeekCloser. It wraps a bytes.Reader
|
||||
// for reading data from an in-memory buffer and includes an onClose callback.
|
||||
// The onClose callback is used to return the buffer to the pool, ensuring buffer re-usability.
|
||||
|
|
|
@ -80,6 +80,16 @@ func New(opts ...Option) *BufferedFileWriter {
|
|||
return w
|
||||
}
|
||||
|
||||
// NewFromReader creates a new instance of BufferedFileWriter and writes the content from the provided reader to the writer.
|
||||
func NewFromReader(r io.Reader, opts ...Option) (*BufferedFileWriter, error) {
|
||||
writer := New(opts...)
|
||||
if _, err := io.Copy(writer, r); err != nil {
|
||||
return nil, fmt.Errorf("error writing to buffered file writer: %w", err)
|
||||
}
|
||||
|
||||
return writer, nil
|
||||
}
|
||||
|
||||
// Len returns the number of bytes written to the buffer or file.
|
||||
func (w *BufferedFileWriter) Len() int { return int(w.size) }
|
||||
|
||||
|
@ -187,14 +197,19 @@ func (w *BufferedFileWriter) CloseForWriting() error {
|
|||
return w.file.Close()
|
||||
}
|
||||
|
||||
// ReadCloser returns an io.ReadCloser to read the written content. It provides a reader
|
||||
// based on the current storage medium of the data (in-memory buffer or file).
|
||||
// If the total content size exceeds the predefined threshold, it is stored in a temporary file and a file
|
||||
// reader is returned. For in-memory data, it returns a custom reader that handles returning
|
||||
// ReadCloser returns an io.ReadCloser to read the written content.
|
||||
// If the content is stored in a file, it opens the file and returns a file reader.
|
||||
// If the content is stored in memory, it returns a custom reader that handles returning the buffer to the pool.
|
||||
// The caller should call Close() on the returned io.Reader when done to ensure resources are properly released.
|
||||
// This method can only be used when the BufferedFileWriter is in read-only mode.
|
||||
func (w *BufferedFileWriter) ReadCloser() (io.ReadCloser, error) { return w.ReadSeekCloser() }
|
||||
|
||||
// ReadSeekCloser returns an io.ReadSeekCloser to read the written content.
|
||||
// If the content is stored in a file, it opens the file and returns a file reader.
|
||||
// If the content is stored in memory, it returns a custom reader that allows seeking and handles returning
|
||||
// the buffer to the pool.
|
||||
// The caller should call Close() on the returned io.Reader when done to ensure files are cleaned up.
|
||||
// It can only be used when the BufferedFileWriter is in read-only mode.
|
||||
func (w *BufferedFileWriter) ReadCloser() (io.ReadCloser, error) {
|
||||
// This method can only be used when the BufferedFileWriter is in read-only mode.
|
||||
func (w *BufferedFileWriter) ReadSeekCloser() (io.ReadSeekCloser, error) {
|
||||
if w.state != readOnly {
|
||||
return nil, fmt.Errorf("BufferedFileWriter must be in read-only mode to read")
|
||||
}
|
||||
|
|
|
@ -2,7 +2,11 @@ package bufferedfilewriter
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -494,6 +498,100 @@ func BenchmarkBufferedFileWriterWriteSmall(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
// errorReader is an io.Reader stub whose Read always fails; it is used to
// simulate read errors.
type errorReader struct{}

// Read never touches the buffer: it reports zero bytes and a fixed error.
func (errorReader) Read(_ []byte) (int, error) {
	return 0, fmt.Errorf("error reading")
}
|
||||
|
||||
func TestNewFromReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
reader io.Reader
|
||||
wantErr bool
|
||||
wantData string
|
||||
}{
|
||||
{
|
||||
name: "Success case",
|
||||
reader: strings.NewReader("hello world"),
|
||||
wantData: "hello world",
|
||||
},
|
||||
{
|
||||
name: "Empty reader",
|
||||
reader: strings.NewReader(""),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Error reader",
|
||||
reader: errorReader{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
bufWriter, err := NewFromReader(tc.reader)
|
||||
if err != nil && tc.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, bufWriter)
|
||||
|
||||
err = bufWriter.CloseForWriting()
|
||||
assert.NoError(t, err)
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
rdr, err := bufWriter.ReadCloser()
|
||||
if err != nil && tc.wantErr {
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer rdr.Close()
|
||||
|
||||
_, err = b.ReadFrom(rdr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.wantData, b.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFromReaderThresholdExceeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a large data buffer that exceeds the threshold.
|
||||
largeData := make([]byte, 1024*1024) // 1 MB
|
||||
_, err := rand.Read(largeData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a BufferedFileWriter with a smaller threshold.
|
||||
threshold := uint64(1024) // 1 KB
|
||||
bufWriter, err := NewFromReader(bytes.NewReader(largeData), WithThreshold(threshold))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = bufWriter.CloseForWriting()
|
||||
assert.NoError(t, err)
|
||||
|
||||
rdr, err := bufWriter.ReadCloser()
|
||||
assert.NoError(t, err)
|
||||
defer rdr.Close()
|
||||
|
||||
// Verify that the data was written to a file.
|
||||
assert.NotEmpty(t, bufWriter.filename)
|
||||
assert.NotNil(t, bufWriter.file)
|
||||
|
||||
// Read the data from the BufferedFileWriter.
|
||||
readData, err := io.ReadAll(rdr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, largeData, readData)
|
||||
|
||||
// Verify the size of the data written.
|
||||
assert.Equal(t, uint64(len(largeData)), bufWriter.size)
|
||||
}
|
||||
|
||||
func TestBufferWriterCloseForWritingWithFile(t *testing.T) {
|
||||
bufPool := buffer.NewBufferPool()
|
||||
|
||||
|
|
Loading…
Reference in a new issue