Move mimetype detection into newFileReader

This commit is contained in:
Ahrav Dutta 2024-07-21 15:33:01 -07:00
parent 7d2ea92cb2
commit 63229d52e3
7 changed files with 98 additions and 69 deletions

View file

@ -83,9 +83,9 @@ func (h *arHandler) processARFiles(ctx logContext.Context, reader *deb.Ar, archi
fileSize := arEntry.Size
fileCtx := logContext.WithValues(ctx, "filename", arEntry.Name, "size", fileSize)
rdr, err := newSizedMimeTypeReader(arEntry.Data, fileSize)
if err != nil {
return fmt.Errorf("error creating mime-type reader: %w", err)
rdr, err := newSizedReader(arEntry.Data, fileSize)
if err := handleReaderError(fileCtx, err); err != nil {
return err
}
if err := h.handleNonArchiveContent(fileCtx, rdr, archiveChan); err != nil {

View file

@ -89,7 +89,7 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
if reader.format == nil {
if depth > 0 {
mtr, err := newSizedMimetypeReaderFromFileReader(reader)
mtr, err := newSizedReaderFromFileReader(reader)
if err != nil {
return fmt.Errorf("error reading MIME type: %w", err)
}
@ -108,12 +108,8 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
defer compReader.Close()
rdr, err := newFileReader(compReader)
if err != nil {
if errors.Is(err, ErrEmptyReader) {
ctx.Logger().V(5).Info("empty reader, skipping file")
return nil
}
return fmt.Errorf("error creating custom reader: %w", err)
if err := handleReaderError(ctx, err); err != nil {
return err
}
return h.openArchive(ctx, depth+1, rdr, archiveChan)
@ -162,12 +158,6 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context.
return nil
}
if common.SkipFile(file.Name()) || common.IsBinary(file.Name()) {
lCtx.Logger().V(5).Info("skipping file")
h.metrics.incFilesSkipped()
return nil
}
f, err := file.Open()
if err != nil {
return fmt.Errorf("error opening file %s: %w", file.Name(), err)
@ -191,12 +181,8 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context.
}()
rdr, err := newFileReader(f)
if err != nil {
if errors.Is(err, ErrEmptyReader) {
lCtx.Logger().V(5).Info("empty reader, skipping file")
return nil
}
return fmt.Errorf("error creating custom reader: %w", err)
if err := handleReaderError(lCtx, err); err != nil {
return err
}
h.metrics.incFilesProcessed()

View file

@ -44,7 +44,7 @@ func (h *defaultHandler) HandleFile(ctx logContext.Context, input fileReader) (c
h.metrics.incFilesProcessed()
}()
mtr, err := newSizedMimetypeReaderFromFileReader(input)
mtr, err := newSizedReaderFromFileReader(input)
if err != nil {
ctx.Logger().Error(err, "error reading MIME type")
return
@ -76,15 +76,7 @@ func (h *defaultHandler) measureLatencyAndHandleErrors(start time.Time, err erro
// on the type, particularly for binary files. It manages reading file chunks and writing them to the archive channel,
// effectively collecting the final bytes for further processing. This function is a key component in ensuring that all
// file content, regardless of being an archive or not, is handled appropriately.
func (h *defaultHandler) handleNonArchiveContent(ctx logContext.Context, reader sizedMimeTypeReader, archiveChan chan []byte) error {
mimeExt := reader.mimeExt
if common.SkipFile(mimeExt) || common.IsBinary(mimeExt) {
ctx.Logger().V(5).Info("skipping file", "ext", mimeExt)
h.metrics.incFilesSkipped()
return nil
}
func (h *defaultHandler) handleNonArchiveContent(ctx logContext.Context, reader sizedReader, archiveChan chan []byte) error {
chunkReader := sources.NewChunkReader(sources.WithFileSize(int(reader.size)))
for data := range chunkReader(ctx, reader) {
if err := data.Error(); err != nil {

View file

@ -9,6 +9,7 @@ import (
"github.com/gabriel-vasile/mimetype"
"github.com/mholt/archiver/v4"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/iobuf"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
@ -37,21 +38,20 @@ type fileReader struct {
*iobuf.BufferedReadSeeker
}
var ErrEmptyReader = errors.New("reader is empty")
var (
errEmptyReader = errors.New("reader is empty")
errUnsupportedMIME = errors.New("unsupported MIME type")
)
// sizedMimeTypeReader wraps an io.Reader with MIME type information and the size of the content.
// This type is used to pass content through the processing pipeline
// while carrying its detected MIME type and size, avoiding redundant type detection and size calculation.
type sizedMimeTypeReader struct {
mimeExt string // Extension derived from the MIME type (e.g., ".zip", ".tar", etc.)
mimeType mimeType // MIME type (e.g., "application/zip", "application/x-tar", etc.)
size int64
// sizedReader wraps an io.Reader with the size of the content.
type sizedReader struct {
size int64
io.Reader
}
// newSizedMimeTypeReaderFromFileReader creates a new sizedMimeTypeReader from a fileReader.
// This function extracts the MIME type and size from the fileReader, and returns a new sizedMimeTypeReader.
func newSizedMimetypeReaderFromFileReader(r fileReader) (sizedMimeTypeReader, error) {
// newSizedReaderFromFileReader creates a new sizedReader from a fileReader.
// This function extracts the size from the fileReader, and returns a new sizedReader.
func newSizedReaderFromFileReader(r fileReader) (sizedReader, error) {
originalBufferingState := r.IsBufferingEnabled()
if !originalBufferingState {
r.EnableBuffering()
@ -64,36 +64,67 @@ func newSizedMimetypeReaderFromFileReader(r fileReader) (sizedMimeTypeReader, er
size, err := r.Size()
if err != nil {
return sizedMimeTypeReader{}, fmt.Errorf("error getting file size: %w", err)
return sizedReader{}, fmt.Errorf("error getting file size: %w", err)
}
return sizedMimeTypeReader{
mimeExt: r.mime.Extension(),
mimeType: mimeType(r.mime.String()),
size: size,
Reader: r.BufferedReadSeeker,
}, nil
return sizedReader{size: size, Reader: r.BufferedReadSeeker}, nil
}
// newSizedMimeTypeReader creates a new sizedMimeTypeReader from an io.Reader.
// It uses a bufio.Reader to perform MIME type detection on the input reader
// without consuming it, by peeking into the first 512 bytes of the input.
// This encapsulates both the original reader and the detected MIME type information.
// newSizedReader creates a new sizedReader from an io.Reader.
// This function is particularly useful for specialized archive handlers
// that need to pass extracted content to the default handler without modifying the original reader.
func newSizedMimeTypeReader(r io.Reader, size int64) (sizedMimeTypeReader, error) {
func newSizedReader(r io.Reader, size int64) (sizedReader, error) {
if r == nil {
return sizedReader{}, errors.New("reader is nil")
}
if size == 0 {
return sizedReader{}, errEmptyReader
}
bufReader, _, err := determineMIMEType(r)
if err != nil {
return sizedReader{}, err
}
return sizedReader{size: size, Reader: bufReader}, nil
}
// determineMIMEType detects the MIME type of r from its leading bytes.
// It wraps r in a bufio.Reader and peeks at up to defaultMinBufferSize bytes. Peeking does not
// advance the returned reader, but the bytes are pulled from r, so callers that still need the
// peeked content should read from the returned reader.
//
// It returns errEmptyReader when no bytes are available, errUnsupportedMIME (together with the
// detected type) when common.SkipFile or common.IsBinary rejects the type's extension, and a
// wrapped error when peeking fails for any reason other than io.EOF.
func determineMIMEType(r io.Reader) (io.Reader, *mimetype.MIME, error) {
	const defaultMinBufferSize = 3072
	bufReader := bufio.NewReaderSize(r, defaultMinBufferSize)
	// Peek at up to defaultMinBufferSize bytes. Many file formats store their magic numbers within
	// the first 512 bytes, so MIME type detection may still succeed when fewer bytes are read.
	buffer, err := bufReader.Peek(defaultMinBufferSize)
	if err != nil && !errors.Is(err, io.EOF) {
		return nil, nil, fmt.Errorf("unable to read file for MIME type detection: %w", err)
	}
	if len(buffer) == 0 {
		return nil, nil, errEmptyReader
	}

	mime := mimetype.Detect(buffer)
	// Filter on the extension derived from the MIME type (e.g. ".zip") rather than the MIME string
	// (e.g. "application/zip"): the skip/binary checks are given file names or extensions elsewhere
	// in this package.
	ext := mime.Extension()
	if common.SkipFile(ext) || common.IsBinary(ext) {
		return nil, mime, errUnsupportedMIME
	}

	return bufReader, mime, nil
}
// handleReaderError decides whether an error from reader construction should abort processing.
// A nil error, errEmptyReader, and errUnsupportedMIME all yield nil (the latter two are logged
// at verbosity 5 as skipped files); any other error is wrapped and returned.
func handleReaderError(ctx logContext.Context, err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, errEmptyReader):
		ctx.Logger().V(5).Info("empty reader, skipping file")
		return nil
	case errors.Is(err, errUnsupportedMIME):
		ctx.Logger().V(5).Info("skipping file")
		return nil
	default:
		return fmt.Errorf("error creating reader: %w", err)
	}
}
// newFileReader creates a fileReader from an io.Reader, optionally using BufferedFileWriter for certain formats.
@ -106,9 +137,9 @@ func newFileReader(r io.Reader) (fileReader, error) {
// This optimization ensures we don't continue writing to the buffer after the initial reads.
defer fReader.DisableBuffering()
mime, err := mimetype.DetectReader(fReader)
_, mime, err := determineMIMEType(fReader)
if err != nil {
return fReader, fmt.Errorf("unable to detect MIME type: %w", err)
return fReader, err
}
fReader.mime = mime
@ -295,12 +326,12 @@ func HandleFile(
}
rdr, err := newFileReader(reader)
if err != nil {
if errors.Is(err, ErrEmptyReader) {
ctx.Logger().V(5).Info("empty reader, skipping file")
return nil
}
return fmt.Errorf("error creating custom reader: %w", err)
if errors.Is(err, errEmptyReader) {
ctx.Logger().V(5).Info("empty reader, skipping file")
return nil
} else if errors.Is(err, errUnsupportedMIME) {
ctx.Logger().V(5).Info("skipping file")
return nil
}
mimeT := mimeType(rdr.mime.String())

View file

@ -269,6 +269,26 @@ func TestExtractTarContentWithEmptyFile(t *testing.T) {
assert.Equal(t, wantCount, count)
}
// TestExtractEmptyFile checks that HandleFile returns no error and reports no chunks
// when given testdata/empty.txt.
func TestExtractEmptyFile(t *testing.T) {
	emptyFile, openErr := os.Open("testdata/empty.txt")
	assert.Nil(t, openErr)
	defer emptyFile.Close()

	chunks := make(chan *sources.Chunk, 1)
	go func() {
		// Closing the channel lets the counting loop below finish once HandleFile returns.
		defer close(chunks)
		handleErr := HandleFile(logContext.Background(), emptyFile, &sources.Chunk{}, sources.ChanReporter{Ch: chunks})
		assert.NoError(t, handleErr)
	}()

	wantCount := 0
	var gotCount int
	for range chunks {
		gotCount++
	}
	assert.Equal(t, wantCount, gotCount)
}
func TestHandleTar(t *testing.T) {
file, err := os.Open("testdata/test.tar")
assert.Nil(t, err)
@ -353,7 +373,7 @@ func TestNewSizedMimetypeReaderFromFileReader(t *testing.T) {
BufferedReadSeeker: brs,
}
result, err := newSizedMimetypeReaderFromFileReader(fr)
result, err := newSizedReaderFromFileReader(fr)
assert.NoError(t, err)
assert.Equal(t, tt.expectedSize, result.size)

View file

@ -90,9 +90,9 @@ func (h *rpmHandler) processRPMFiles(ctx logContext.Context, reader rpmutils.Pay
fileSize := fileInfo.Size()
fileCtx := logContext.WithValues(ctx, "filename", fileInfo.Name, "size", fileSize)
rdr, err := newSizedMimeTypeReader(reader, fileSize)
if err != nil {
return fmt.Errorf("error creating mime-type reader: %w", err)
rdr, err := newSizedReader(reader, fileSize)
if err := handleReaderError(fileCtx, err); err != nil {
return err
}
if err := h.handleNonArchiveContent(fileCtx, rdr, archiveChan); err != nil {

0
pkg/handlers/testdata/empty.txt vendored Normal file
View file