diff --git a/pkg/handlers/ar.go b/pkg/handlers/ar.go index 1b8a5a72a..8abbfb288 100644 --- a/pkg/handlers/ar.go +++ b/pkg/handlers/ar.go @@ -83,9 +83,9 @@ func (h *arHandler) processARFiles(ctx logContext.Context, reader *deb.Ar, archi fileSize := arEntry.Size fileCtx := logContext.WithValues(ctx, "filename", arEntry.Name, "size", fileSize) - rdr, err := newSizedMimeTypeReader(arEntry.Data, fileSize) - if err != nil { - return fmt.Errorf("error creating mime-type reader: %w", err) + rdr, err := newSizedReader(arEntry.Data, fileSize) + if err := handleReaderError(fileCtx, err); err != nil { + return err } if err := h.handleNonArchiveContent(fileCtx, rdr, archiveChan); err != nil { diff --git a/pkg/handlers/archive.go b/pkg/handlers/archive.go index 91c49a7d5..0e84f4138 100644 --- a/pkg/handlers/archive.go +++ b/pkg/handlers/archive.go @@ -89,7 +89,7 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f if reader.format == nil { if depth > 0 { - mtr, err := newSizedMimetypeReaderFromFileReader(reader) + mtr, err := newSizedReaderFromFileReader(reader) if err != nil { return fmt.Errorf("error reading MIME type: %w", err) } @@ -108,12 +108,8 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f defer compReader.Close() rdr, err := newFileReader(compReader) - if err != nil { - if errors.Is(err, ErrEmptyReader) { - ctx.Logger().V(5).Info("empty reader, skipping file") - return nil - } - return fmt.Errorf("error creating custom reader: %w", err) + if err := handleReaderError(ctx, err); err != nil { + return err } return h.openArchive(ctx, depth+1, rdr, archiveChan) @@ -162,12 +158,6 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context. return nil } - if common.SkipFile(file.Name()) || common.IsBinary(file.Name()) { - lCtx.Logger().V(5).Info("skipping file") - h.metrics.incFilesSkipped() - return nil - } - f, err := file.Open() if err != nil { return fmt.Errorf("error opening file %s: %w", file.Name(), err) @@ -191,12 +181,8 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context. }() rdr, err := newFileReader(f) - if err != nil { - if errors.Is(err, ErrEmptyReader) { - lCtx.Logger().V(5).Info("empty reader, skipping file") - return nil - } - return fmt.Errorf("error creating custom reader: %w", err) + if err := handleReaderError(lCtx, err); err != nil { + return err } h.metrics.incFilesProcessed() diff --git a/pkg/handlers/default.go b/pkg/handlers/default.go index c30fd2919..b518aecec 100644 --- a/pkg/handlers/default.go +++ b/pkg/handlers/default.go @@ -44,7 +44,7 @@ func (h *defaultHandler) HandleFile(ctx logContext.Context, input fileReader) (c h.metrics.incFilesProcessed() }() - mtr, err := newSizedMimetypeReaderFromFileReader(input) + mtr, err := newSizedReaderFromFileReader(input) if err != nil { ctx.Logger().Error(err, "error reading MIME type") return @@ -76,15 +76,7 @@ func (h *defaultHandler) measureLatencyAndHandleErrors(start time.Time, err erro // on the type, particularly for binary files. It manages reading file chunks and writing them to the archive channel, // effectively collecting the final bytes for further processing. This function is a key component in ensuring that all // file content, regardless of being an archive or not, is handled appropriately. -func (h *defaultHandler) handleNonArchiveContent(ctx logContext.Context, reader sizedMimeTypeReader, archiveChan chan []byte) error { - mimeExt := reader.mimeExt - - if common.SkipFile(mimeExt) || common.IsBinary(mimeExt) { - ctx.Logger().V(5).Info("skipping file", "ext", mimeExt) - h.metrics.incFilesSkipped() - return nil - } - +func (h *defaultHandler) handleNonArchiveContent(ctx logContext.Context, reader sizedReader, archiveChan chan []byte) error { chunkReader := sources.NewChunkReader(sources.WithFileSize(int(reader.size))) for data := range chunkReader(ctx, reader) { if err := data.Error(); err != nil { diff --git a/pkg/handlers/handlers.go b/pkg/handlers/handlers.go index 2d68c5738..c85c47d83 100644 --- a/pkg/handlers/handlers.go +++ b/pkg/handlers/handlers.go @@ -9,6 +9,7 @@ import ( "github.com/gabriel-vasile/mimetype" "github.com/mholt/archiver/v4" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/iobuf" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" @@ -37,21 +38,20 @@ type fileReader struct { *iobuf.BufferedReadSeeker } -var ErrEmptyReader = errors.New("reader is empty") +var ( + errEmptyReader = errors.New("reader is empty") + errUnsupportedMIME = errors.New("unsupported MIME type") +) -// sizedMimeTypeReader wraps an io.Reader with MIME type information and the size of the content. -// This type is used to pass content through the processing pipeline -// while carrying its detected MIME type and size, avoiding redundant type detection and size calculation. -type sizedMimeTypeReader struct { - mimeExt string // Extension derived from the MIME type (e.g., ".zip", ".tar", etc.) - mimeType mimeType // MIME type (e.g., "application/zip", "application/x-tar", etc.) - size int64 +// sizedReader wraps an io.Reader with the size of the content. +type sizedReader struct { + size int64 io.Reader } -// newSizedMimeTypeReaderFromFileReader creates a new sizedMimeTypeReader from a fileReader. -// This function extracts the MIME type and size from the fileReader, and returns a new sizedMimeTypeReader. -func newSizedMimetypeReaderFromFileReader(r fileReader) (sizedMimeTypeReader, error) { +// newSizedReaderFromFileReader creates a new sizedReader from a fileReader. +// This function extracts the size from the fileReader, and returns a new sizedReader. +func newSizedReaderFromFileReader(r fileReader) (sizedReader, error) { originalBufferingState := r.IsBufferingEnabled() if !originalBufferingState { r.EnableBuffering() @@ -64,36 +64,67 @@ func newSizedMimetypeReaderFromFileReader(r fileReader) (sizedMimeTypeReader, er size, err := r.Size() if err != nil { - return sizedMimeTypeReader{}, fmt.Errorf("error getting file size: %w", err) + return sizedReader{}, fmt.Errorf("error getting file size: %w", err) } - return sizedMimeTypeReader{ - mimeExt: r.mime.Extension(), - mimeType: mimeType(r.mime.String()), - size: size, - Reader: r.BufferedReadSeeker, - }, nil + return sizedReader{size: size, Reader: r.BufferedReadSeeker}, nil } -// newSizedMimeTypeReader creates a new sizedMimeTypeReader from an io.Reader. -// It uses a bufio.Reader to perform MIME type detection on the input reader -// without consuming it, by peeking into the first 512 bytes of the input. -// This encapsulates both the original reader and the detected MIME type information. +// newSizedReader creates a new sizedReader from an io.Reader. // This function is particularly useful for specialized archive handlers // that need to pass extracted content to the default handler without modifying the original reader. -func newSizedMimeTypeReader(r io.Reader, size int64) (sizedMimeTypeReader, error) { +func newSizedReader(r io.Reader, size int64) (sizedReader, error) { + if r == nil { + return sizedReader{}, errors.New("reader is nil") + } + if size == 0 { + return sizedReader{}, errEmptyReader + } + + bufReader, _, err := determineMIMEType(r) + if err != nil { + return sizedReader{}, err + } + + return sizedReader{size: size, Reader: bufReader}, nil +} + +func determineMIMEType(r io.Reader) (io.Reader, *mimetype.MIME, error) { const defaultMinBufferSize = 3072 bufReader := bufio.NewReaderSize(r, defaultMinBufferSize) // A buffer of 512 bytes is used since many file formats store their magic numbers within the first 512 bytes. // If fewer bytes are read, MIME type detection may still succeed. buffer, err := bufReader.Peek(defaultMinBufferSize) if err != nil && !errors.Is(err, io.EOF) { - return sizedMimeTypeReader{}, fmt.Errorf("unable to read file for MIME type detection: %w", err) + return nil, nil, fmt.Errorf("unable to read file for MIME type detection: %w", err) + } + + if len(buffer) == 0 { + return nil, nil, errEmptyReader } mime := mimetype.Detect(buffer) + if common.SkipFile(mime.String()) || common.IsBinary(mime.String()) { + return nil, mime, errUnsupportedMIME + } - return sizedMimeTypeReader{mimeExt: mime.Extension(), mimeType: mimeType(mime.String()), size: size, Reader: bufReader}, nil + return bufReader, mime, nil +} + +func handleReaderError(ctx logContext.Context, err error) error { + if err == nil { + return nil + } + + if errors.Is(err, errEmptyReader) { + ctx.Logger().V(5).Info("empty reader, skipping file") + return nil + } else if errors.Is(err, errUnsupportedMIME) { + ctx.Logger().V(5).Info("skipping file") + return nil + } + + return fmt.Errorf("error creating reader: %w", err) } // newFileReader creates a fileReader from an io.Reader, optionally using BufferedFileWriter for certain formats. @@ -106,9 +137,9 @@ func newFileReader(r io.Reader) (fileReader, error) { // This optimization ensures we don't continue writing to the buffer after the initial reads. defer fReader.DisableBuffering() - mime, err := mimetype.DetectReader(fReader) + _, mime, err := determineMIMEType(fReader) if err != nil { - return fReader, fmt.Errorf("unable to detect MIME type: %w", err) + return fReader, err } fReader.mime = mime @@ -295,12 +326,12 @@ func HandleFile( } rdr, err := newFileReader(reader) - if err != nil { - if errors.Is(err, ErrEmptyReader) { - ctx.Logger().V(5).Info("empty reader, skipping file") - return nil - } - return fmt.Errorf("error creating custom reader: %w", err) + if errors.Is(err, errEmptyReader) { + ctx.Logger().V(5).Info("empty reader, skipping file") + return nil + } else if errors.Is(err, errUnsupportedMIME) { + ctx.Logger().V(5).Info("skipping file") + return nil } mimeT := mimeType(rdr.mime.String()) diff --git a/pkg/handlers/handlers_test.go b/pkg/handlers/handlers_test.go index 05ea4bad9..65545b70b 100644 --- a/pkg/handlers/handlers_test.go +++ b/pkg/handlers/handlers_test.go @@ -269,6 +269,26 @@ func TestExtractTarContentWithEmptyFile(t *testing.T) { assert.Equal(t, wantCount, count) } +func TestExtractEmptyFile(t *testing.T) { + file, err := os.Open("testdata/empty.txt") + assert.Nil(t, err) + defer file.Close() + + chunkCh := make(chan *sources.Chunk, 1) + go func() { + defer close(chunkCh) + err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh}) + assert.NoError(t, err) + }() + + wantCount := 0 + count := 0 + for range chunkCh { + count++ + } + assert.Equal(t, wantCount, count) +} + func TestHandleTar(t *testing.T) { file, err := os.Open("testdata/test.tar") assert.Nil(t, err) @@ -353,7 +373,7 @@ func TestNewSizedMimetypeReaderFromFileReader(t *testing.T) { BufferedReadSeeker: brs, } - result, err := newSizedMimetypeReaderFromFileReader(fr) + result, err := newSizedReaderFromFileReader(fr) assert.NoError(t, err) assert.Equal(t, tt.expectedSize, result.size) diff --git a/pkg/handlers/rpm.go b/pkg/handlers/rpm.go index 2d911d29d..54a28c882 100644 --- a/pkg/handlers/rpm.go +++ b/pkg/handlers/rpm.go @@ -90,9 +90,9 @@ func (h *rpmHandler) processRPMFiles(ctx logContext.Context, reader rpmutils.Pay fileSize := fileInfo.Size() fileCtx := logContext.WithValues(ctx, "filename", fileInfo.Name, "size", fileSize) - rdr, err := newSizedMimeTypeReader(reader, fileSize) - if err != nil { - return fmt.Errorf("error creating mime-type reader: %w", err) + rdr, err := newSizedReader(reader, fileSize) + if err := handleReaderError(fileCtx, err); err != nil { + return err } if err := h.handleNonArchiveContent(fileCtx, rdr, archiveChan); err != nil { diff --git a/pkg/handlers/testdata/empty.txt b/pkg/handlers/testdata/empty.txt new file mode 100644 index 000000000..e69de29bb