mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
[feat] - Add ReadFrom method to BufferedFileWriter (#2759)
* Update write method in contentWriter interface * fix lint * Add a buffered file reader * update comments * update comment * add compile type checks * fix * fix test * inline if * Add ReadFrom method to the BufferedFileWriter * update test * fix test * update benchmark
This commit is contained in:
parent
46d4ae1334
commit
7e47b96631
2 changed files with 179 additions and 0 deletions
|
@ -4,6 +4,7 @@ package bufferedfilewriter
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
@ -176,6 +177,48 @@ func (w *BufferedFileWriter) Write(data []byte) (int, error) {
|
|||
return n, nil
|
||||
}
|
||||
|
||||
// ReadFrom reads data from the provided reader and writes it to the buffer or file, depending on the size.
|
||||
// This method satisfies the io.ReaderFrom interface, allowing it to be used with standard library functions
|
||||
// like io.Copy and io.CopyBuffer.
|
||||
//
|
||||
// By implementing this method, BufferedFileWriter can leverage optimized data transfer mechanisms provided
|
||||
// by the standard library. For example, when using io.Copy with a BufferedFileWriter, the copy operation
|
||||
// will be delegated to the ReadFrom method, avoiding the potentially non-optimized default approach.
|
||||
//
|
||||
// This is particularly useful when creating a new BufferedFileWriter from an io.Reader using the NewFromReader
|
||||
// function. By leveraging the ReadFrom method, data can be efficiently transferred from the reader to
|
||||
// the BufferedFileWriter.
|
||||
func (w *BufferedFileWriter) ReadFrom(reader io.Reader) (int64, error) {
|
||||
if w.state != writeOnly {
|
||||
return 0, fmt.Errorf("BufferedFileWriter must be in write-only mode to write")
|
||||
}
|
||||
|
||||
var totalBytesRead int64
|
||||
const bufferSize = 1 << 16 // 64KB
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return totalBytesRead, err
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
written, err := w.Write(buf[:n])
|
||||
if err != nil {
|
||||
return totalBytesRead, err
|
||||
}
|
||||
totalBytesRead += int64(written)
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return totalBytesRead, nil
|
||||
}
|
||||
|
||||
// CloseForWriting flushes any remaining data in the buffer to the file, closes the file if created,
|
||||
// and transitions the BufferedFileWriter to read-only mode.
|
||||
func (w *BufferedFileWriter) CloseForWriting() error {
|
||||
|
|
|
@ -620,3 +620,139 @@ func TestBufferWriterCloseForWritingWithFile(t *testing.T) {
|
|||
assert.Same(t, buf, bufFromPool, "Buffer should be returned to the pool")
|
||||
bufPool.Put(bufFromPool)
|
||||
}
|
||||
|
||||
func TestBufferedFileWriter_ReadFrom(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedOutput string
|
||||
expectedSize int64
|
||||
}{
|
||||
{
|
||||
name: "Empty input",
|
||||
input: "",
|
||||
expectedOutput: "",
|
||||
expectedSize: 0,
|
||||
},
|
||||
{
|
||||
name: "Small input",
|
||||
input: "Hello, World!",
|
||||
expectedOutput: "Hello, World!",
|
||||
expectedSize: 13,
|
||||
},
|
||||
{
|
||||
name: "Large input",
|
||||
input: string(make([]byte, 1<<20)), // 1MB input
|
||||
expectedOutput: string(make([]byte, 1<<20)),
|
||||
expectedSize: 1 << 20,
|
||||
},
|
||||
{
|
||||
name: "Input greater than threshold",
|
||||
input: string(make([]byte, defaultThreshold+1)),
|
||||
expectedOutput: string(make([]byte, defaultThreshold+1)),
|
||||
expectedSize: defaultThreshold + 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
writer := New()
|
||||
reader := bytes.NewReader([]byte(tc.input))
|
||||
size, err := writer.ReadFrom(reader)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = writer.CloseForWriting()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expectedSize, size)
|
||||
if size == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
rc, err := writer.ReadCloser()
|
||||
assert.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result bytes.Buffer
|
||||
|
||||
_, err = io.Copy(&result, rc)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedOutput, result.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// simpleReader serves the bytes of a string through the plain io.Reader interface only.
// It intentionally has no WriteTo method, so it does not satisfy io.WriterTo.
type simpleReader struct {
	data   []byte // content handed out by Read
	offset int    // index in data of the next byte to return
}

// newSimpleReader returns a simpleReader positioned at the beginning of s.
func newSimpleReader(s string) *simpleReader {
	return &simpleReader{data: []byte(s)}
}

// Read implements io.Reader. It copies as many unread bytes as fit into p and advances the
// read position by that amount. Once every byte has been consumed it returns 0 and io.EOF.
func (r *simpleReader) Read(p []byte) (int, error) {
	if len(r.data) <= r.offset {
		// Everything has already been handed out.
		return 0, io.EOF
	}

	copied := copy(p, r.data[r.offset:])
	r.offset += copied
	return copied, nil
}
|
||||
|
||||
func TestNewFromReaderThresholdExceededSimpleReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a large data buffer that exceeds the threshold.
|
||||
largeData := strings.Repeat("a", 1024*1024) // 1 MB
|
||||
|
||||
// Create a BufferedFileWriter with a smaller threshold.
|
||||
threshold := uint64(1024) // 1 KB
|
||||
bufWriter, err := NewFromReader(newSimpleReader(largeData), WithThreshold(threshold))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = bufWriter.CloseForWriting()
|
||||
assert.NoError(t, err)
|
||||
|
||||
rdr, err := bufWriter.ReadCloser()
|
||||
assert.NoError(t, err)
|
||||
defer rdr.Close()
|
||||
|
||||
// Verify that the data was written to a file.
|
||||
assert.NotEmpty(t, bufWriter.filename)
|
||||
assert.NotNil(t, bufWriter.file)
|
||||
|
||||
// Read the data from the BufferedFileWriter.
|
||||
readData, err := io.ReadAll(rdr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, largeData, string(readData))
|
||||
|
||||
// Verify the size of the data written.
|
||||
assert.Equal(t, uint64(len(largeData)), bufWriter.size)
|
||||
}
|
||||
|
||||
func BenchmarkNewFromReader(b *testing.B) {
|
||||
largeData := strings.Repeat("a", 1024*1024) // 1 MB
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
reader := newSimpleReader(largeData)
|
||||
|
||||
b.StartTimer()
|
||||
bufWriter, err := NewFromReader(reader)
|
||||
assert.NoError(b, err)
|
||||
b.StopTimer()
|
||||
|
||||
err = bufWriter.CloseForWriting()
|
||||
assert.NoError(b, err)
|
||||
|
||||
rdr, err := bufWriter.ReadCloser()
|
||||
assert.NoError(b, err)
|
||||
rdr.Close()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue