[analyze] Add client filter to detect successful unsafe HTTP requests (#3305)

* Move analyzer client to its own file

* Add analyzer client filter to detect successful unsafe HTTP requests

* Close response body in test
This commit is contained in:
Miccah 2024-09-18 10:31:21 -07:00 committed by GitHub
parent 1b59a5ecf2
commit b2da2a6a5c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 221 additions and 66 deletions

View file

@ -3,14 +3,10 @@ package analyzers
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/fatih/color"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/pb/analyzerpb"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)
@ -181,68 +177,6 @@ var YellowWriter = color.New(color.FgYellow).SprintFunc()
var RedWriter = color.New(color.FgRed).SprintFunc()
var DefaultWriter = color.New().SprintFunc()
type AnalyzeClient struct {
http.Client
LoggingEnabled bool
LogFile string
}
func CreateLogFileName(baseName string) string {
// Get the current time
currentTime := time.Now()
// Format the time as "2024_06_30_07_15_30"
timeString := currentTime.Format("2006_01_02_15_04_05")
// Create the log file name
logFileName := fmt.Sprintf("%s_%s.log", timeString, baseName)
return logFileName
}
func NewAnalyzeClient(cfg *config.Config) *http.Client {
if cfg == nil || !cfg.LoggingEnabled {
return &http.Client{}
}
return &http.Client{
Transport: LoggingRoundTripper{
parent: http.DefaultTransport,
logFile: cfg.LogFile,
},
}
}
type LoggingRoundTripper struct {
parent http.RoundTripper
// TODO: io.Writer
logFile string
}
func (r LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
startTime := time.Now()
resp, err := r.parent.RoundTrip(req)
if err != nil {
return resp, err
}
// TODO: JSON
logEntry := fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d\n", startTime.Format(time.RFC3339), req.Method, req.URL.Path, resp.StatusCode)
// Open log file in append mode.
file, err := os.OpenFile(r.logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return resp, fmt.Errorf("failed to open log file: %w", err)
}
defer file.Close()
// Write log entry to file.
if _, err := file.WriteString(logEntry); err != nil {
return resp, fmt.Errorf("failed to write log entry to file: %w", err)
}
return resp, nil
}
// BindAllPermissions creates a Binding for each permission to the given
// resource.
func BindAllPermissions(r Resource, perms ...Permission) []Binding {

View file

@ -0,0 +1,119 @@
package analyzers
import (
"fmt"
"net/http"
"os"
"strings"
"time"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
)
type AnalyzeClient struct {
http.Client
LoggingEnabled bool
LogFile string
}
func CreateLogFileName(baseName string) string {
// Get the current time
currentTime := time.Now()
// Format the time as "2024_06_30_07_15_30"
timeString := currentTime.Format("2006_01_02_15_04_05")
// Create the log file name
logFileName := fmt.Sprintf("%s_%s.log", timeString, baseName)
return logFileName
}
func NewAnalyzeClient(cfg *config.Config) *http.Client {
client := &http.Client{
Transport: AnalyzerRoundTripper{parent: http.DefaultTransport},
}
if cfg == nil || !cfg.LoggingEnabled {
return client
}
return &http.Client{
Transport: LoggingRoundTripper{
parent: client.Transport,
logFile: cfg.LogFile,
},
}
}
type LoggingRoundTripper struct {
parent http.RoundTripper
// TODO: io.Writer
logFile string
}
func (r LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
startTime := time.Now()
resp, parentErr := r.parent.RoundTrip(req)
if resp == nil {
return resp, parentErr
}
// TODO: JSON
var logEntry string
if parentErr != nil {
logEntry = fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d, Error: %s\n",
startTime.Format(time.RFC3339),
req.Method,
req.URL.Path,
resp.StatusCode,
parentErr.Error(),
)
} else {
logEntry = fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d\n",
startTime.Format(time.RFC3339),
req.Method,
req.URL.Path,
resp.StatusCode,
)
}
// Open log file in append mode.
file, err := os.OpenFile(r.logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return resp, fmt.Errorf("failed to open log file: %w", err)
}
defer file.Close()
// Write log entry to file.
if _, err := file.WriteString(logEntry); err != nil {
return resp, fmt.Errorf("failed to write log entry to file: %w", err)
}
return resp, parentErr
}
type AnalyzerRoundTripper struct {
parent http.RoundTripper
}
func (r AnalyzerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := r.parent.RoundTrip(req)
if err != nil || methodIsSafe(req.Method) {
return resp, err
}
// Check that unsafe methods did NOT return a valid status code.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return resp, fmt.Errorf("non-safe request returned success")
}
return resp, nil
}
// methodIsSafe is a helper method to check whether the HTTP method is safe according to MDN Web Docs.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods#safe_idempotent_and_cacheable_request_methods
func methodIsSafe(method string) bool {
switch strings.ToUpper(method) {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
return true
default:
return false
}
}

View file

@ -0,0 +1,102 @@
package analyzers
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestAnalyzerClientUnsafeSuccess(t *testing.T) {
testCases := []struct {
name string
method string
expectedStatus int
expectedError bool
}{
{
name: "Safe method (GET)",
method: http.MethodGet,
expectedStatus: http.StatusOK,
expectedError: false,
},
{
name: "Safe method (HEAD)",
method: http.MethodHead,
expectedStatus: http.StatusOK,
expectedError: false,
},
{
name: "Safe method (OPTIONS)",
method: http.MethodOptions,
expectedStatus: http.StatusOK,
expectedError: false,
},
{
name: "Safe method (TRACE)",
method: http.MethodTrace,
expectedStatus: http.StatusOK,
expectedError: false,
},
{
name: "Unsafe method (POST) with success status",
method: http.MethodPost,
expectedStatus: http.StatusOK,
expectedError: true,
},
{
name: "Unsafe method (PUT) with success status",
method: http.MethodPut,
expectedStatus: http.StatusOK,
expectedError: true,
},
{
name: "Unsafe method (DELETE) with success status",
method: http.MethodDelete,
expectedStatus: http.StatusOK,
expectedError: true,
},
{
name: "Unsafe method (POST) with error status",
method: http.MethodPost,
expectedStatus: http.StatusInternalServerError,
expectedError: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a test server that returns the expected status code
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.expectedStatus)
}))
defer server.Close()
// Create a test request
req, err := http.NewRequest(tc.method, server.URL, nil)
if err != nil {
t.Fatalf("Failed to create test request: %v", err)
}
// Create the AnalyzerRoundTripper with a test client
client := NewAnalyzeClient(nil)
// Perform the request
resp, err := client.Do(req)
if resp != nil {
_ = resp.Body.Close()
}
// Check the error
if err != nil && !tc.expectedError {
t.Errorf("Unexpected error: %v", err)
} else if err == nil && tc.expectedError {
t.Errorf("Expected error, but got nil")
}
// Check the response status code
if resp != nil && resp.StatusCode != tc.expectedStatus {
t.Errorf("Expected status code: %d, but got: %d", tc.expectedStatus, resp.StatusCode)
}
})
}
}