[chore] - Use custom context for archive handler of specialized archives (#1629)

* Use custom context for archive handler of specialized archives.

* Fix arg.

* Fix test.

* Use re-reader.

* Update error and comments.

* Add better error handling.

* Update.
ahrav 2023-08-16 13:52:55 -07:00 committed by GitHub
parent 62d359eba4
commit e0db575d4a
3 changed files with 43 additions and 21 deletions
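
In short, the change threads the project's logging-aware context (pkg/context) through the specialized-archive handlers in place of a plain context.Context. A minimal, hypothetical sketch of the pattern, using only the logContext.AddLogger and Logger() calls that appear in the diff below (the package and function names here are illustrative, not part of the commit):

package handlers_sketch // hypothetical; not part of this commit

import (
	"context"

	logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

// withArchiveLogger wraps a plain context.Context into the logging-aware
// context that the archive handlers now accept. Callers can then log via
// aCtx.Logger().Error(err, "..."), as the diff below does.
func withArchiveLogger(ctx context.Context) logContext.Context {
	return logContext.AddLogger(ctx)
}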


@@ -246,8 +246,9 @@ func ensureToolsForMimeType(mimeType string) error {
 // and processes it based on its extension, such as handling Debian (.deb) and RPM (.rpm) packages.
 // It returns an io.Reader that can be used to read the processed content of the file,
 // and an error if any issues occurred during processing.
+// If the file is specialized, the returned boolean is true with no error.
 // The caller is responsible for closing the returned reader.
-func (a *Archive) HandleSpecialized(ctx context.Context, reader io.Reader) (io.Reader, bool, error) {
+func (a *Archive) HandleSpecialized(ctx logContext.Context, reader io.Reader) (io.Reader, bool, error) {
 	mimeType, reader, err := determineMimeType(reader)
 	if err != nil {
 		return nil, false, err
@@ -279,7 +280,7 @@ func (a *Archive) HandleSpecialized(ctx context.Context, reader io.Reader) (io.R
 // It handles the extraction process by using the 'ar' command and manages temporary
 // files and directories for the operation.
 // The caller is responsible for closing the returned reader.
-func (a *Archive) extractDebContent(ctx context.Context, file io.Reader) (io.ReadCloser, error) {
+func (a *Archive) extractDebContent(ctx logContext.Context, file io.Reader) (io.ReadCloser, error) {
 	if a.currentDepth >= maxDepth {
 		return nil, fmt.Errorf("max archive depth reached")
 	}
@@ -297,7 +298,7 @@ func (a *Archive) extractDebContent(ctx context.Context, file io.Reader) (io.Rea
 		return nil, err
 	}
-	handler := func(ctx context.Context, env tempEnv, file string) (string, error) {
+	handler := func(ctx logContext.Context, env tempEnv, file string) (string, error) {
 		if strings.HasPrefix(file, "data.tar.") {
 			return file, nil
 		}
@@ -317,7 +318,7 @@ func (a *Archive) extractDebContent(ctx context.Context, file io.Reader) (io.Rea
 // It handles the extraction process by using the 'rpm2cpio' and 'cpio' commands and manages temporary
 // files and directories for the operation.
 // The caller is responsible for closing the returned reader.
-func (a *Archive) extractRpmContent(ctx context.Context, file io.Reader) (io.ReadCloser, error) {
+func (a *Archive) extractRpmContent(ctx logContext.Context, file io.Reader) (io.ReadCloser, error) {
 	if a.currentDepth >= maxDepth {
 		return nil, fmt.Errorf("max archive depth reached")
 	}
@@ -336,7 +337,7 @@ func (a *Archive) extractRpmContent(ctx context.Context, file io.Reader) (io.Rea
 		return nil, err
 	}
-	handler := func(ctx context.Context, env tempEnv, file string) (string, error) {
+	handler := func(ctx logContext.Context, env tempEnv, file string) (string, error) {
 		if strings.HasSuffix(file, ".tar.gz") {
 			return file, nil
 		}
@@ -351,7 +352,7 @@ func (a *Archive) extractRpmContent(ctx context.Context, file io.Reader) (io.Rea
 	return openDataArchive(tmpEnv.extractPath, dataArchiveName)
 }
-func (a *Archive) handleNestedFileMIME(ctx context.Context, tempEnv tempEnv, fileName string) (string, error) {
+func (a *Archive) handleNestedFileMIME(ctx logContext.Context, tempEnv tempEnv, fileName string) (string, error) {
 	nestedFile, err := os.Open(filepath.Join(tempEnv.extractPath, fileName))
 	if err != nil {
 		return "", err
@@ -360,7 +361,7 @@ func (a *Archive) handleNestedFileMIME(ctx context.Context, tempEnv tempEnv, fil
 	mimeType, reader, err := determineMimeType(nestedFile)
 	if err != nil {
-		return "", err
+		return "", fmt.Errorf("unable to determine MIME type of nested filename: %s, %w", nestedFile.Name(), err)
 	}
 	switch mimeType {
@@ -373,7 +374,7 @@ func (a *Archive) handleNestedFileMIME(ctx context.Context, tempEnv tempEnv, fil
 	}
 	if err != nil {
-		return "", err
+		return "", fmt.Errorf("unable to extract file with MIME type %s: %w", mimeType, err)
 	}
 	return fileName, nil
@@ -405,7 +406,7 @@ func determineMimeType(reader io.Reader) (string, io.Reader, error) {
 // of the data archive it finds. This centralizes the logic for handling specialized files such as .deb and .rpm
 // by using the appropriate handling function passed as an argument. This design allows for flexibility and reuse
 // of this function across various extraction processes in the package.
-func (a *Archive) handleExtractedFiles(ctx context.Context, env tempEnv, handleFile func(context.Context, tempEnv, string) (string, error)) (string, error) {
+func (a *Archive) handleExtractedFiles(ctx logContext.Context, env tempEnv, handleFile func(logContext.Context, tempEnv, string) (string, error)) (string, error) {
 	extractedFiles, err := os.ReadDir(env.extractPath)
 	if err != nil {
 		return "", fmt.Errorf("unable to read extracted directory: %w", err)
@@ -462,7 +463,7 @@ func executeCommand(cmd *exec.Cmd) error {
 	var stderr bytes.Buffer
 	cmd.Stderr = &stderr
 	if err := cmd.Run(); err != nil {
-		return fmt.Errorf("unable to execute command: %w; error: %s", err, stderr.String())
+		return fmt.Errorf("unable to execute command %v: %w; error: %s", cmd.String(), err, stderr.String())
 	}
 	return nil
 }


@@ -12,6 +12,7 @@ import (
 	diskbufferreader "github.com/bill-rich/disk-buffer-reader"
 	"github.com/stretchr/testify/assert"
+	logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
 )
@@ -131,7 +132,7 @@ func TestExtractDebContent(t *testing.T) {
 	assert.Nil(t, err)
 	defer file.Close()
-	ctx := context.Background()
+	ctx := logContext.AddLogger(context.Background())
 	a := &Archive{}
 	reader, err := a.extractDebContent(ctx, file)
@@ -149,7 +150,7 @@ func TestExtractRPMContent(t *testing.T) {
 	assert.Nil(t, err)
 	defer file.Close()
-	ctx := context.Background()
+	ctx := logContext.AddLogger(context.Background())
 	a := &Archive{}
 	reader, err := a.extractRpmContent(ctx, file)


@@ -4,6 +4,9 @@ import (
 	"context"
 	"io"
+	diskbufferreader "github.com/bill-rich/disk-buffer-reader"
+	logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
 )
@@ -19,7 +22,7 @@ type SpecializedHandler interface {
 	// HandleSpecialized examines the provided file reader within the context and determines if it is a specialized archive.
 	// It returns a reader with any necessary modifications, a boolean indicating if the file was specialized,
 	// and an error if something went wrong during processing.
-	HandleSpecialized(context.Context, io.Reader) (io.Reader, bool, error)
+	HandleSpecialized(logContext.Context, io.Reader) (io.Reader, bool, error)
 }
 type Handler interface {
@@ -35,23 +38,40 @@ type Handler interface {
 // The function returns true if processing was successful and false otherwise.
 // Context is used for cancellation, and the caller is responsible for canceling it if needed.
 func HandleFile(ctx context.Context, file io.Reader, chunkSkel *sources.Chunk, chunksChan chan *sources.Chunk) bool {
+	aCtx := logContext.AddLogger(ctx)
 	for _, h := range DefaultHandlers() {
 		h.New()
-		var (
-			isSpecial bool
-			err       error
-		)
+		// The re-reader is used to reset the file reader after checking if the handler implements SpecializedHandler.
+		// This is necessary because the archive pkg doesn't correctly determine the file type when using
+		// an io.MultiReader, which is used by the SpecializedHandler.
+		reReader, err := diskbufferreader.New(file)
+		if err != nil {
+			aCtx.Logger().Error(err, "error creating re-reader reader")
+			return false
+		}
+		defer reReader.Close()
 		// Check if the handler implements SpecializedHandler and process accordingly.
 		if specialHandler, ok := h.(SpecializedHandler); ok {
-			if file, isSpecial, err = specialHandler.HandleSpecialized(ctx, file); isSpecial && err == nil {
-				return handleChunks(ctx, h.FromFile(ctx, file), chunkSkel, chunksChan)
+			file, isSpecial, err := specialHandler.HandleSpecialized(aCtx, reReader)
+			if isSpecial {
+				return handleChunks(aCtx, h.FromFile(ctx, file), chunkSkel, chunksChan)
 			}
+			if err != nil {
+				aCtx.Logger().Error(err, "error handling file")
+			}
 		}
+		if err := reReader.Reset(); err != nil {
+			aCtx.Logger().Error(err, "error resetting re-reader")
+			return false
+		}
+		reReader.Stop()
 		var isType bool
-		if file, isType = h.IsFiletype(ctx, file); isType {
-			return handleChunks(ctx, h.FromFile(ctx, file), chunkSkel, chunksChan)
+		if file, isType = h.IsFiletype(aCtx, reReader); isType {
+			return handleChunks(aCtx, h.FromFile(ctx, file), chunkSkel, chunksChan)
 		}
 	}
 	return false
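
Condensed, the buffering-and-rewind flow introduced in HandleFile above looks like the following hypothetical helper. It uses only the disk-buffer-reader calls that appear in the hunk (New, Close, Reset, Stop) and the existing Handler/SpecializedHandler interfaces; the function name and the chunking placeholder are illustrative, not part of the commit:

// Hypothetical condensed form of the re-reader flow in HandleFile above.
func scanWithRewind(aCtx logContext.Context, h Handler, file io.Reader) bool {
	// Buffer the stream so it can be re-read after the specialized check.
	reReader, err := diskbufferreader.New(file)
	if err != nil {
		aCtx.Logger().Error(err, "error creating re-reader reader")
		return false
	}
	defer reReader.Close()

	// First pass: let a specialized handler (.deb/.rpm) inspect the stream.
	if specialHandler, ok := h.(SpecializedHandler); ok {
		processed, isSpecial, err := specialHandler.HandleSpecialized(aCtx, reReader)
		if isSpecial {
			_ = processed // hand the processed content to the chunker here
			return true
		}
		if err != nil {
			aCtx.Logger().Error(err, "error handling file")
		}
	}

	// Rewind the buffered bytes so generic type detection starts from the beginning.
	if err := reReader.Reset(); err != nil {
		aCtx.Logger().Error(err, "error resetting re-reader")
		return false
	}
	reReader.Stop() // no further resets needed

	_, isType := h.IsFiletype(aCtx, reReader)
	return isType
}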