safely join paths derived from archive headers

Signed-off-by: Alex Goodman <alex.goodman@anchore.com>
This commit is contained in:
Alex Goodman 2021-04-14 15:10:09 -04:00
parent d5dfaaba53
commit 484730435b
No known key found for this signature in database
GPG key ID: 5CB45AE22BAB7EA7
2 changed files with 106 additions and 9 deletions

View file

@ -24,6 +24,15 @@ const (
const perFileReadLimit = 2 * GB
type errZipSlipDetected struct {
Prefix string
JoinArgs []string
}
func (e *errZipSlipDetected) Error() string {
return fmt.Sprintf("paths are not allowed to resolve outside of the root prefix (%q). Destination: %q", e.Prefix, e.JoinArgs)
}
type zipTraversalRequest map[string]struct{}
func newZipTraverseRequest(paths ...string) zipTraversalRequest {
@ -169,17 +178,12 @@ func ContentsFromZip(archivePath string, paths ...string) (map[string]string, er
// UnzipToDir extracts a zip archive to a target directory.
func UnzipToDir(archivePath, targetDir string) error {
visitor := func(file *zip.File) error {
// the zip-slip attack protection is still being erroneously detected
// nolint:gosec
expandedFilePath := filepath.Clean(filepath.Join(targetDir, file.Name))
// protect against zip slip attacks (traversing unintended parent paths from maliciously crafted relative-path entries)
if !strings.HasPrefix(expandedFilePath, filepath.Clean(targetDir)+string(os.PathSeparator)) {
return fmt.Errorf("potential zip slip attack: %q", expandedFilePath)
joinedPath, err := safeJoin(targetDir, file.Name)
if err != nil {
return err
}
err := extractSingleFile(file, expandedFilePath, archivePath)
if err != nil {
if err = extractSingleFile(file, joinedPath, archivePath); err != nil {
return err
}
return nil
@ -188,6 +192,20 @@ func UnzipToDir(archivePath, targetDir string) error {
return TraverseFilesInZip(archivePath, visitor)
}
// safeJoin ensures that any destinations do not resolve to a path above the prefix path.
func safeJoin(prefix string, dest ...string) (string, error) {
joinResult := filepath.Join(append([]string{prefix}, dest...)...)
cleanJoinResult := filepath.Clean(joinResult)
if !strings.HasPrefix(cleanJoinResult, filepath.Clean(prefix)) {
return "", &errZipSlipDetected{
Prefix: prefix,
JoinArgs: dest,
}
}
// why not return the clean path? the called may not be expected it from what should only be a join operation.
return joinResult, nil
}
func extractSingleFile(file *zip.File, expandedFilePath, archivePath string) error {
zippedFile, err := file.Open()
if err != nil {

View file

@ -3,6 +3,8 @@ package file
import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
@ -12,6 +14,7 @@ import (
"testing"
"github.com/go-test/deep"
"github.com/stretchr/testify/assert"
)
func equal(r1, r2 io.Reader) (bool, error) {
@ -173,3 +176,79 @@ func TestContentsFromZip(t *testing.T) {
t.Errorf("full result: %s", string(b))
}
}
// looks like there isn't a helper for this yet? https://github.com/stretchr/testify/issues/497
func assertErrorAs(expectedErr interface{}) assert.ErrorAssertionFunc {
return func(t assert.TestingT, actualErr error, i ...interface{}) bool {
return errors.As(actualErr, &expectedErr)
}
}
func TestSafeJoin(t *testing.T) {
tests := []struct {
prefix string
args []string
expected string
errAssertion assert.ErrorAssertionFunc
}{
// go cases...
{
prefix: "/a/place",
args: []string{
"somewhere/else",
},
expected: "/a/place/somewhere/else",
errAssertion: assert.NoError,
},
{
prefix: "/a/place",
args: []string{
"somewhere/../else",
},
expected: "/a/place/else",
errAssertion: assert.NoError,
},
{
prefix: "/a/../place",
args: []string{
"somewhere/else",
},
expected: "/place/somewhere/else",
errAssertion: assert.NoError,
},
// zip slip examples....
{
prefix: "/a/place",
args: []string{
"../../../etc/passwd",
},
expected: "",
errAssertion: assertErrorAs(&errZipSlipDetected{}),
},
{
prefix: "/a/place",
args: []string{
"../",
"../",
},
expected: "",
errAssertion: assertErrorAs(&errZipSlipDetected{}),
},
{
prefix: "/a/place",
args: []string{
"../",
},
expected: "",
errAssertion: assertErrorAs(&errZipSlipDetected{}),
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%+v:%+v", test.prefix, test.args), func(t *testing.T) {
actual, err := safeJoin(test.prefix, test.args...)
test.errAssertion(t, err)
assert.Equal(t, test.expected, actual)
})
}
}