mirror of
https://github.com/anchore/grype
synced 2024-11-10 06:34:13 +00:00
41cfd42de6
Signed-off-by: Keith Zantow <kzantow@gmail.com>
268 lines
6.3 KiB
Go
268 lines
6.3 KiB
Go
package file
|
|
|
|
import (
|
|
"archive/tar"
|
|
"bytes"
|
|
"context"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"path"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/anchore/clio"
|
|
)
|
|
|
|
func TestGetter_GetFile(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
prepareClient func(*http.Client)
|
|
assert assert.ErrorAssertionFunc
|
|
}{
|
|
{
|
|
name: "client trusts server's CA",
|
|
assert: assert.NoError,
|
|
},
|
|
{
|
|
name: "client doesn't trust server's CA",
|
|
prepareClient: removeTrustedCAs,
|
|
assert: assertUnknownAuthorityError,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
requestPath := "/foo"
|
|
|
|
server := newTestServer(t, withResponseForPath(t, requestPath, testFileContent))
|
|
t.Cleanup(server.Close)
|
|
|
|
httpClient := getClient(t, server)
|
|
if tc.prepareClient != nil {
|
|
tc.prepareClient(httpClient)
|
|
}
|
|
|
|
getter := NewGetter(testID, httpClient)
|
|
requestURL := createRequestURL(t, server, requestPath)
|
|
|
|
tempDir := t.TempDir()
|
|
tempFile := path.Join(tempDir, "some-destination-file")
|
|
|
|
err := getter.GetFile(tempFile, requestURL)
|
|
tc.assert(t, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetter_GetToDir_FilterNonArchivesWired(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
source string
|
|
assert assert.ErrorAssertionFunc
|
|
}{
|
|
{
|
|
name: "error out on non-archive sources",
|
|
source: "http://localhost/something.txt",
|
|
assert: assertErrNonArchiveSource,
|
|
},
|
|
}
|
|
|
|
for _, test := range testCases {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
test.assert(t, NewGetter(testID, nil).GetToDir(t.TempDir(), test.source))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetter_validateHttpSource(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
source string
|
|
assert assert.ErrorAssertionFunc
|
|
}{
|
|
{
|
|
name: "error out on non-archive sources",
|
|
source: "http://localhost/something.txt",
|
|
assert: assertErrNonArchiveSource,
|
|
},
|
|
{
|
|
name: "filter out non-archive sources with get param",
|
|
source: "https://localhost/vulnerability-db_v3_2021-11-21T08:15:44Z.txt?checksum=sha256%3Ac402d01fa909a3fa85a5c6733ef27a3a51a9105b6c62b9152adbd24c08358911",
|
|
assert: assertErrNonArchiveSource,
|
|
},
|
|
{
|
|
name: "ignore non http-https input",
|
|
source: "s3://bucket/something.txt",
|
|
assert: assert.NoError,
|
|
},
|
|
}
|
|
|
|
for _, test := range testCases {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
test.assert(t, validateHTTPSource(test.source))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetter_GetToDir_CertConcerns(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
prepareClient func(*http.Client)
|
|
assert assert.ErrorAssertionFunc
|
|
}{
|
|
|
|
{
|
|
name: "client trusts server's CA",
|
|
assert: assert.NoError,
|
|
},
|
|
{
|
|
name: "client doesn't trust server's CA",
|
|
prepareClient: removeTrustedCAs,
|
|
assert: assertUnknownAuthorityError,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
requestPath := "/foo.tar"
|
|
tarball := createTarball("foo", testFileContent)
|
|
|
|
server := newTestServer(t, withResponseForPath(t, requestPath, tarball))
|
|
t.Cleanup(server.Close)
|
|
|
|
httpClient := getClient(t, server)
|
|
if tc.prepareClient != nil {
|
|
tc.prepareClient(httpClient)
|
|
}
|
|
|
|
getter := NewGetter(testID, httpClient)
|
|
requestURL := createRequestURL(t, server, requestPath)
|
|
|
|
tempDir := t.TempDir()
|
|
|
|
err := getter.GetToDir(tempDir, requestURL)
|
|
tc.assert(t, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func assertUnknownAuthorityError(t assert.TestingT, err error, _ ...interface{}) bool {
|
|
return assert.ErrorAs(t, err, &x509.UnknownAuthorityError{})
|
|
}
|
|
|
|
func assertErrNonArchiveSource(t assert.TestingT, err error, _ ...interface{}) bool {
|
|
return assert.ErrorIs(t, err, ErrNonArchiveSource)
|
|
}
|
|
|
|
func removeTrustedCAs(client *http.Client) {
|
|
client.Transport.(*http.Transport).TLSClientConfig.RootCAs = x509.NewCertPool()
|
|
}
|
|
|
|
// createTarball makes a single-file tarball and returns it as a byte slice.
|
|
func createTarball(filename string, content []byte) []byte {
|
|
tarBuffer := new(bytes.Buffer)
|
|
tarWriter := tar.NewWriter(tarBuffer)
|
|
tarWriter.WriteHeader(&tar.Header{
|
|
Name: filename,
|
|
Size: int64(len(content)),
|
|
Mode: 0600,
|
|
})
|
|
tarWriter.Write(content)
|
|
tarWriter.Close()
|
|
|
|
return tarBuffer.Bytes()
|
|
}
|
|
|
|
type muxOption func(mux *http.ServeMux)
|
|
|
|
func withResponseForPath(t *testing.T, path string, response []byte) muxOption {
|
|
t.Helper()
|
|
|
|
return func(mux *http.ServeMux) {
|
|
mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
|
|
t.Logf("server handling request: %s %s", req.Method, req.URL)
|
|
|
|
_, err := w.Write(response)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
var testID = clio.Identification{
|
|
Name: "test-app",
|
|
Version: "v0.5.3",
|
|
}
|
|
|
|
func newTestServer(t *testing.T, muxOptions ...muxOption) *httptest.Server {
|
|
t.Helper()
|
|
|
|
mux := http.NewServeMux()
|
|
for _, option := range muxOptions {
|
|
option(mux)
|
|
}
|
|
|
|
server := httptest.NewTLSServer(mux)
|
|
t.Logf("new TLS server listening at %s", getHost(t, server))
|
|
|
|
return server
|
|
}
|
|
|
|
func createRequestURL(t *testing.T, server *httptest.Server, path string) string {
|
|
t.Helper()
|
|
|
|
// TODO: Figure out how to get this value from the server without hardcoding it here
|
|
const testServerCertificateName = "example.com"
|
|
|
|
serverURL, err := url.Parse(server.URL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Set URL hostname to value from TLS certificate
|
|
serverURL.Host = fmt.Sprintf("%s:%s", testServerCertificateName, serverURL.Port())
|
|
|
|
serverURL.Path = path
|
|
|
|
return serverURL.String()
|
|
}
|
|
|
|
// getClient returns an http.Client that can be used to contact the test TLS server.
|
|
func getClient(t *testing.T, server *httptest.Server) *http.Client {
|
|
t.Helper()
|
|
|
|
httpClient := server.Client()
|
|
transport := httpClient.Transport.(*http.Transport)
|
|
|
|
serverHost := getHost(t, server)
|
|
|
|
transport.DialContext = func(_ context.Context, _, addr string) (net.Conn, error) {
|
|
t.Logf("client dialing %q for host %q", serverHost, addr)
|
|
|
|
// Ensure the client dials our test server
|
|
return net.Dial("tcp", serverHost)
|
|
}
|
|
|
|
return httpClient
|
|
}
|
|
|
|
// getHost extracts the host value from a server URL string.
|
|
// e.g. given a server with URL "http://1.2.3.4:5000/foo", getHost returns "1.2.3.4:5000"
|
|
func getHost(t *testing.T, server *httptest.Server) string {
|
|
t.Helper()
|
|
|
|
u, err := url.Parse(server.URL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return u.Hostname() + ":" + u.Port()
|
|
}
|
|
|
|
var testFileContent = []byte("This is the content of a test file!\n")
|