grype/internal/file/getter_test.go
Dan Luhring 1e020d7ea0
Detect when a user specifies an empty SBOM (#695)
* Detect when the user specifies empty SBOM file

Signed-off-by: Dan Luhring <dan+github@luhrings.com>

* Fix darwin cert verification failure from Go 1.18

Signed-off-by: Dan Luhring <dan+github@luhrings.com>
2022-03-24 10:12:11 -04:00

261 lines
6.2 KiB
Go

package file
import (
"archive/tar"
"bytes"
"context"
"crypto/x509"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"path"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetter_GetFile(t *testing.T) {
testCases := []struct {
name string
prepareClient func(*http.Client)
assert assert.ErrorAssertionFunc
}{
{
name: "client trusts server's CA",
assert: assert.NoError,
},
{
name: "client doesn't trust server's CA",
prepareClient: removeTrustedCAs,
assert: assertUnknownAuthorityError,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
requestPath := "/foo"
server := newTestServer(t, withResponseForPath(t, requestPath, testFileContent))
t.Cleanup(server.Close)
httpClient := getClient(t, server)
if tc.prepareClient != nil {
tc.prepareClient(httpClient)
}
getter := NewGetter(httpClient)
requestURL := createRequestURL(t, server, requestPath)
tempDir := t.TempDir()
tempFile := path.Join(tempDir, "some-destination-file")
err := getter.GetFile(tempFile, requestURL)
tc.assert(t, err)
})
}
}
func TestGetter_GetToDir_FilterNonArchivesWired(t *testing.T) {
testCases := []struct {
name string
source string
assert assert.ErrorAssertionFunc
}{
{
name: "error out on non-archive sources",
source: "http://localhost/something.txt",
assert: assertErrNonArchiveSource,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
test.assert(t, NewGetter(nil).GetToDir(t.TempDir(), test.source))
})
}
}
func TestGetter_validateHttpSource(t *testing.T) {
testCases := []struct {
name string
source string
assert assert.ErrorAssertionFunc
}{
{
name: "error out on non-archive sources",
source: "http://localhost/something.txt",
assert: assertErrNonArchiveSource,
},
{
name: "filter out non-archive sources with get param",
source: "https://localhost/vulnerability-db_v3_2021-11-21T08:15:44Z.txt?checksum=sha256%3Ac402d01fa909a3fa85a5c6733ef27a3a51a9105b6c62b9152adbd24c08358911",
assert: assertErrNonArchiveSource,
},
{
name: "ignore non http-https input",
source: "s3://bucket/something.txt",
assert: assert.NoError,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
test.assert(t, validateHTTPSource(test.source))
})
}
}
func TestGetter_GetToDir_CertConcerns(t *testing.T) {
testCases := []struct {
name string
prepareClient func(*http.Client)
assert assert.ErrorAssertionFunc
}{
{
name: "client trusts server's CA",
assert: assert.NoError,
},
{
name: "client doesn't trust server's CA",
prepareClient: removeTrustedCAs,
assert: assertUnknownAuthorityError,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
requestPath := "/foo.tar"
tarball := createTarball("foo", testFileContent)
server := newTestServer(t, withResponseForPath(t, requestPath, tarball))
t.Cleanup(server.Close)
httpClient := getClient(t, server)
if tc.prepareClient != nil {
tc.prepareClient(httpClient)
}
getter := NewGetter(httpClient)
requestURL := createRequestURL(t, server, requestPath)
tempDir := t.TempDir()
err := getter.GetToDir(tempDir, requestURL)
tc.assert(t, err)
})
}
}
func assertUnknownAuthorityError(t assert.TestingT, err error, _ ...interface{}) bool {
return assert.ErrorAs(t, err, &x509.UnknownAuthorityError{})
}
func assertErrNonArchiveSource(t assert.TestingT, err error, _ ...interface{}) bool {
return assert.ErrorIs(t, err, ErrNonArchiveSource)
}
func removeTrustedCAs(client *http.Client) {
client.Transport.(*http.Transport).TLSClientConfig.RootCAs = x509.NewCertPool()
}
// createTarball makes a single-file tarball and returns it as a byte slice.
func createTarball(filename string, content []byte) []byte {
tarBuffer := new(bytes.Buffer)
tarWriter := tar.NewWriter(tarBuffer)
tarWriter.WriteHeader(&tar.Header{
Name: filename,
Size: int64(len(content)),
Mode: 0600,
})
tarWriter.Write(content)
tarWriter.Close()
return tarBuffer.Bytes()
}
type muxOption func(mux *http.ServeMux)
func withResponseForPath(t *testing.T, path string, response []byte) muxOption {
t.Helper()
return func(mux *http.ServeMux) {
mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
t.Logf("server handling request: %s %s", req.Method, req.URL)
_, err := w.Write(response)
if err != nil {
t.Fatal(err)
}
})
}
}
func newTestServer(t *testing.T, muxOptions ...muxOption) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
for _, option := range muxOptions {
option(mux)
}
server := httptest.NewTLSServer(mux)
t.Logf("new TLS server listening at %s", getHost(t, server))
return server
}
func createRequestURL(t *testing.T, server *httptest.Server, path string) string {
t.Helper()
// TODO: Figure out how to get this value from the server without hardcoding it here
const testServerCertificateName = "example.com"
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
// Set URL hostname to value from TLS certificate
serverURL.Host = fmt.Sprintf("%s:%s", testServerCertificateName, serverURL.Port())
serverURL.Path = path
return serverURL.String()
}
// getClient returns an http.Client that can be used to contact the test TLS server.
func getClient(t *testing.T, server *httptest.Server) *http.Client {
t.Helper()
httpClient := server.Client()
transport := httpClient.Transport.(*http.Transport)
serverHost := getHost(t, server)
transport.DialContext = func(_ context.Context, _, addr string) (net.Conn, error) {
t.Logf("client dialing %q for host %q", serverHost, addr)
// Ensure the client dials our test server
return net.Dial("tcp", serverHost)
}
return httpClient
}
// getHost extracts the host value from a server URL string.
// e.g. given a server with URL "http://1.2.3.4:5000/foo", getHost returns "1.2.3.4:5000"
func getHost(t *testing.T, server *httptest.Server) string {
t.Helper()
u, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
return u.Hostname() + ":" + u.Port()
}
var testFileContent = []byte("This is the content of a test file!\n")