mirror of
https://github.com/trufflesecurity/trufflehog.git
synced 2024-11-10 07:04:24 +00:00
[analyze] Add client filter to detect successful unsafe HTTP requests (#3305)
* Move analyzer client to its own file * Add analyzer client filter to detect successful unsafe HTTP requests * Close response body in test
This commit is contained in:
parent
1b59a5ecf2
commit
b2da2a6a5c
3 changed files with 221 additions and 66 deletions
|
@ -3,14 +3,10 @@ package analyzers
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/pb/analyzerpb"
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
|
||||
)
|
||||
|
@ -181,68 +177,6 @@ var YellowWriter = color.New(color.FgYellow).SprintFunc()
|
|||
var RedWriter = color.New(color.FgRed).SprintFunc()
|
||||
var DefaultWriter = color.New().SprintFunc()
|
||||
|
||||
type AnalyzeClient struct {
|
||||
http.Client
|
||||
LoggingEnabled bool
|
||||
LogFile string
|
||||
}
|
||||
|
||||
func CreateLogFileName(baseName string) string {
|
||||
// Get the current time
|
||||
currentTime := time.Now()
|
||||
|
||||
// Format the time as "2024_06_30_07_15_30"
|
||||
timeString := currentTime.Format("2006_01_02_15_04_05")
|
||||
|
||||
// Create the log file name
|
||||
logFileName := fmt.Sprintf("%s_%s.log", timeString, baseName)
|
||||
return logFileName
|
||||
}
|
||||
|
||||
func NewAnalyzeClient(cfg *config.Config) *http.Client {
|
||||
if cfg == nil || !cfg.LoggingEnabled {
|
||||
return &http.Client{}
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: LoggingRoundTripper{
|
||||
parent: http.DefaultTransport,
|
||||
logFile: cfg.LogFile,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type LoggingRoundTripper struct {
|
||||
parent http.RoundTripper
|
||||
// TODO: io.Writer
|
||||
logFile string
|
||||
}
|
||||
|
||||
func (r LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
resp, err := r.parent.RoundTrip(req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// TODO: JSON
|
||||
logEntry := fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d\n", startTime.Format(time.RFC3339), req.Method, req.URL.Path, resp.StatusCode)
|
||||
|
||||
// Open log file in append mode.
|
||||
file, err := os.OpenFile(r.logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("failed to open log file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write log entry to file.
|
||||
if _, err := file.WriteString(logEntry); err != nil {
|
||||
return resp, fmt.Errorf("failed to write log entry to file: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// BindAllPermissions creates a Binding for each permission to the given
|
||||
// resource.
|
||||
func BindAllPermissions(r Resource, perms ...Permission) []Binding {
|
||||
|
|
119
pkg/analyzer/analyzers/client.go
Normal file
119
pkg/analyzer/analyzers/client.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package analyzers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
|
||||
)
|
||||
|
||||
type AnalyzeClient struct {
|
||||
http.Client
|
||||
LoggingEnabled bool
|
||||
LogFile string
|
||||
}
|
||||
|
||||
func CreateLogFileName(baseName string) string {
|
||||
// Get the current time
|
||||
currentTime := time.Now()
|
||||
|
||||
// Format the time as "2024_06_30_07_15_30"
|
||||
timeString := currentTime.Format("2006_01_02_15_04_05")
|
||||
|
||||
// Create the log file name
|
||||
logFileName := fmt.Sprintf("%s_%s.log", timeString, baseName)
|
||||
return logFileName
|
||||
}
|
||||
|
||||
func NewAnalyzeClient(cfg *config.Config) *http.Client {
|
||||
client := &http.Client{
|
||||
Transport: AnalyzerRoundTripper{parent: http.DefaultTransport},
|
||||
}
|
||||
if cfg == nil || !cfg.LoggingEnabled {
|
||||
return client
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: LoggingRoundTripper{
|
||||
parent: client.Transport,
|
||||
logFile: cfg.LogFile,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type LoggingRoundTripper struct {
|
||||
parent http.RoundTripper
|
||||
// TODO: io.Writer
|
||||
logFile string
|
||||
}
|
||||
|
||||
func (r LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
resp, parentErr := r.parent.RoundTrip(req)
|
||||
if resp == nil {
|
||||
return resp, parentErr
|
||||
}
|
||||
|
||||
// TODO: JSON
|
||||
var logEntry string
|
||||
if parentErr != nil {
|
||||
logEntry = fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d, Error: %s\n",
|
||||
startTime.Format(time.RFC3339),
|
||||
req.Method,
|
||||
req.URL.Path,
|
||||
resp.StatusCode,
|
||||
parentErr.Error(),
|
||||
)
|
||||
} else {
|
||||
logEntry = fmt.Sprintf("Date: %s, Method: %s, Path: %s, Status: %d\n",
|
||||
startTime.Format(time.RFC3339),
|
||||
req.Method,
|
||||
req.URL.Path,
|
||||
resp.StatusCode,
|
||||
)
|
||||
}
|
||||
|
||||
// Open log file in append mode.
|
||||
file, err := os.OpenFile(r.logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("failed to open log file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write log entry to file.
|
||||
if _, err := file.WriteString(logEntry); err != nil {
|
||||
return resp, fmt.Errorf("failed to write log entry to file: %w", err)
|
||||
}
|
||||
|
||||
return resp, parentErr
|
||||
}
|
||||
|
||||
type AnalyzerRoundTripper struct {
|
||||
parent http.RoundTripper
|
||||
}
|
||||
|
||||
func (r AnalyzerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := r.parent.RoundTrip(req)
|
||||
if err != nil || methodIsSafe(req.Method) {
|
||||
return resp, err
|
||||
}
|
||||
// Check that unsafe methods did NOT return a valid status code.
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return resp, fmt.Errorf("non-safe request returned success")
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// methodIsSafe is a helper method to check whether the HTTP method is safe according to MDN Web Docs.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods#safe_idempotent_and_cacheable_request_methods
|
||||
func methodIsSafe(method string) bool {
|
||||
switch strings.ToUpper(method) {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
102
pkg/analyzer/analyzers/client_test.go
Normal file
102
pkg/analyzer/analyzers/client_test.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package analyzers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAnalyzerClientUnsafeSuccess(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
expectedStatus int
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Safe method (GET)",
|
||||
method: http.MethodGet,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Safe method (HEAD)",
|
||||
method: http.MethodHead,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Safe method (OPTIONS)",
|
||||
method: http.MethodOptions,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Safe method (TRACE)",
|
||||
method: http.MethodTrace,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Unsafe method (POST) with success status",
|
||||
method: http.MethodPost,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Unsafe method (PUT) with success status",
|
||||
method: http.MethodPut,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Unsafe method (DELETE) with success status",
|
||||
method: http.MethodDelete,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Unsafe method (POST) with error status",
|
||||
method: http.MethodPost,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a test server that returns the expected status code
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tc.expectedStatus)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a test request
|
||||
req, err := http.NewRequest(tc.method, server.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test request: %v", err)
|
||||
}
|
||||
|
||||
// Create the AnalyzerRoundTripper with a test client
|
||||
client := NewAnalyzeClient(nil)
|
||||
|
||||
// Perform the request
|
||||
resp, err := client.Do(req)
|
||||
if resp != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
// Check the error
|
||||
if err != nil && !tc.expectedError {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
} else if err == nil && tc.expectedError {
|
||||
t.Errorf("Expected error, but got nil")
|
||||
}
|
||||
|
||||
// Check the response status code
|
||||
if resp != nil && resp.StatusCode != tc.expectedStatus {
|
||||
t.Errorf("Expected status code: %d, but got: %d", tc.expectedStatus, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue