diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..6229e88 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,55 @@ +linters: + presets: + - bugs + - comment + - error + - format + - import + - metalinter + - module + - performance + - sql + - style + - test + - unused + disable: + - exhaustive + - exhaustivestruct + - gochecknoglobals + - gochecknoinits + - godox + - goerr113 + - gomnd + - interfacer + - maligned + - nlreturn + - scopelint + - testpackage + - wrapcheck + +linters-settings: + errcheck: + ignore: "database/sql:Rollback" + gci: + local-prefixes: github.com/dstotijn/hetty + godot: + capital: true + +issues: + exclude-rules: + - linters: + - gosec + # Ignore SHA1 usage. + text: "G(401|505):" + - linters: + - wsl + # Ignore cuddled defer statements. + text: "only one cuddle assignment allowed before defer statement" + - linters: + - nlreturn + # Ignore `break` without leading blank line. + text: "break with no blank line before" + +run: + skip-files: + - cmd/hetty/rice-box.go \ No newline at end of file diff --git a/cmd/hetty/main.go b/cmd/hetty/main.go index d7fbef8..e9e7377 100644 --- a/cmd/hetty/main.go +++ b/cmd/hetty/main.go @@ -2,25 +2,27 @@ package main import ( "crypto/tls" + "errors" "flag" + "fmt" "log" "net" "net/http" "os" "strings" + "github.com/99designs/gqlgen/graphql/handler" + "github.com/99designs/gqlgen/graphql/playground" rice "github.com/GeertJohan/go.rice" + "github.com/gorilla/mux" + "github.com/mitchellh/go-homedir" + "github.com/dstotijn/hetty/pkg/api" "github.com/dstotijn/hetty/pkg/db/sqlite" "github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/proxy" "github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/scope" - - "github.com/99designs/gqlgen/graphql/handler" - "github.com/99designs/gqlgen/graphql/playground" - "github.com/gorilla/mux" - "github.com/mitchellh/go-homedir" ) var ( @@ -32,8 +34,16 @@ var ( ) func main() { - flag.StringVar(&caCertFile, "cert", "~/.hetty/hetty_cert.pem", "CA certificate filepath. Creates a new CA certificate if file doesn't exist") - flag.StringVar(&caKeyFile, "key", "~/.hetty/hetty_key.pem", "CA private key filepath. Creates a new CA private key if file doesn't exist") + if err := run(); err != nil { + log.Fatalf("[ERROR]: %v", err) + } +} + +func run() error { + flag.StringVar(&caCertFile, "cert", "~/.hetty/hetty_cert.pem", + "CA certificate filepath. Creates a new CA certificate if file doesn't exist") + flag.StringVar(&caKeyFile, "key", "~/.hetty/hetty_key.pem", + "CA private key filepath. Creates a new CA private key if file doesn't exist") flag.StringVar(&projPath, "projects", "~/.hetty/projects", "Projects directory path") flag.StringVar(&addr, "addr", ":8080", "TCP address to listen on, in the form \"host:port\"") flag.StringVar(&adminPath, "adminPath", "", "File path to admin build") @@ -42,32 +52,34 @@ func main() { // Expand `~` in filepaths. caCertFile, err := homedir.Expand(caCertFile) if err != nil { - log.Fatalf("[FATAL] Could not parse CA certificate filepath: %v", err) + return fmt.Errorf("could not parse CA certificate filepath: %w", err) } + caKeyFile, err := homedir.Expand(caKeyFile) if err != nil { - log.Fatalf("[FATAL] Could not parse CA private key filepath: %v", err) + return fmt.Errorf("could not parse CA private key filepath: %w", err) } + projPath, err := homedir.Expand(projPath) if err != nil { - log.Fatalf("[FATAL] Could not parse projects filepath: %v", err) + return fmt.Errorf("could not parse projects filepath: %w", err) } // Load existing CA certificate and key from disk, or generate and write // to disk if no files exist yet. caCert, caKey, err := proxy.LoadOrCreateCA(caKeyFile, caCertFile) if err != nil { - log.Fatalf("[FATAL] Could not create/load CA key pair: %v", err) + return fmt.Errorf("could not create/load CA key pair: %w", err) } db, err := sqlite.New(projPath) if err != nil { - log.Fatalf("[FATAL] Could not initialize database client: %v", err) + return fmt.Errorf("could not initialize database client: %w", err) } projService, err := proj.NewService(db) if err != nil { - log.Fatalf("[FATAL] Could not create new project service: %v", err) + return fmt.Errorf("could not create new project service: %w", err) } defer projService.Close() @@ -81,19 +93,21 @@ func main() { p, err := proxy.NewProxy(caCert, caKey) if err != nil { - log.Fatalf("[FATAL] Could not create Proxy: %v", err) + return fmt.Errorf("could not create proxy: %w", err) } p.UseRequestModifier(reqLogService.RequestModifier) p.UseResponseModifier(reqLogService.ResponseModifier) var adminHandler http.Handler + if adminPath == "" { // Used for embedding with `rice`. box, err := rice.FindBox("../../admin/dist") if err != nil { - log.Fatalf("[FATAL] Could not find embedded admin resources: %v", err) + return fmt.Errorf("could not find embedded admin resources: %w", err) } + adminHandler = http.FileServer(box.HTTPBox()) } else { adminHandler = http.FileServer(http.Dir(adminPath)) @@ -109,11 +123,12 @@ func main() { // GraphQL server. adminRouter.Path("/api/playground/").Handler(playground.Handler("GraphQL Playground", "/api/graphql/")) - adminRouter.Path("/api/graphql/").Handler(handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{ - RequestLogService: reqLogService, - ProjectService: projService, - ScopeService: scope, - }}))) + adminRouter.Path("/api/graphql/").Handler( + handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{ + RequestLogService: reqLogService, + ProjectService: projService, + ScopeService: scope, + }}))) // Admin interface. adminRouter.PathPrefix("").Handler(adminHandler) @@ -128,8 +143,11 @@ func main() { } log.Printf("[INFO] Running server on %v ...", addr) + err = s.ListenAndServe() - if err != nil && err != http.ErrServerClosed { - log.Fatalf("[FATAL] HTTP server closed: %v", err) + if err != nil && errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("http server closed unexpected: %w", err) } + + return nil } diff --git a/pkg/api/resolvers.go b/pkg/api/resolvers.go index 2e0da14..27c88ce 100644 --- a/pkg/api/resolvers.go +++ b/pkg/api/resolvers.go @@ -4,16 +4,18 @@ package api import ( "context" + "errors" "fmt" "regexp" "strings" "github.com/99designs/gqlgen/graphql" + "github.com/vektah/gqlparser/v2/gqlerror" + "github.com/dstotijn/hetty/pkg/proj" "github.com/dstotijn/hetty/pkg/reqlog" "github.com/dstotijn/hetty/pkg/scope" "github.com/dstotijn/hetty/pkg/search" - "github.com/vektah/gqlparser/v2/gqlerror" ) type Resolver struct { @@ -22,20 +24,22 @@ type Resolver struct { ScopeService *scope.Scope } -type queryResolver struct{ *Resolver } -type mutationResolver struct{ *Resolver } +type ( + queryResolver struct{ *Resolver } + mutationResolver struct{ *Resolver } +) func (r *Resolver) Query() QueryResolver { return &queryResolver{r} } func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} } func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error) { reqs, err := r.RequestLogService.FindRequests(ctx) - if err == proj.ErrNoProject { + if errors.Is(err, proj.ErrNoProject) { return nil, noActiveProjectErr(ctx) + } else if err != nil { + return nil, fmt.Errorf("could not query repository for requests: %w", err) } - if err != nil { - return nil, fmt.Errorf("could not query repository for requests: %v", err) - } + logs := make([]HTTPRequestLog, len(reqs)) for i, req := range reqs { @@ -43,6 +47,7 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, if err != nil { return nil, err } + logs[i] = req } @@ -51,12 +56,12 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, func (r *queryResolver) HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error) { log, err := r.RequestLogService.FindRequestLogByID(ctx, id) - if err == reqlog.ErrRequestNotFound { + if errors.Is(err, reqlog.ErrRequestNotFound) { return nil, nil + } else if err != nil { + return nil, fmt.Errorf("could not get request by ID: %w", err) } - if err != nil { - return nil, fmt.Errorf("could not get request by ID: %v", err) - } + req, err := parseRequestLog(log) if err != nil { return nil, err @@ -89,6 +94,7 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) { if req.Request.Header != nil { log.Headers = make([]HTTPHeader, 0) + for key, values := range req.Request.Header { for _, value := range values { log.Headers = append(log.Headers, HTTPHeader{ @@ -106,15 +112,19 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) { StatusCode: req.Response.Response.StatusCode, } statusReasonSubs := strings.SplitN(req.Response.Response.Status, " ", 2) + if len(statusReasonSubs) == 2 { log.Response.StatusReason = statusReasonSubs[1] } + if len(req.Response.Body) > 0 { resBody := string(req.Response.Body) log.Response.Body = &resBody } + if req.Response.Response.Header != nil { log.Response.Headers = make([]HTTPHeader, 0) + for key, values := range req.Response.Response.Header { for _, value := range values { log.Response.Headers = append(log.Response.Headers, HTTPHeader{ @@ -131,12 +141,12 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) { func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Project, error) { p, err := r.ProjectService.Open(ctx, name) - if err == proj.ErrInvalidName { + if errors.Is(err, proj.ErrInvalidName) { return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.") + } else if err != nil { + return nil, fmt.Errorf("could not open project: %w", err) } - if err != nil { - return nil, fmt.Errorf("could not open project: %v", err) - } + return &Project{ Name: p.Name, IsActive: p.IsActive, @@ -145,11 +155,10 @@ func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Proje func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) { p, err := r.ProjectService.ActiveProject() - if err == proj.ErrNoProject { + if errors.Is(err, proj.ErrNoProject) { return nil, nil - } - if err != nil { - return nil, fmt.Errorf("could not open project: %v", err) + } else if err != nil { + return nil, fmt.Errorf("could not open project: %w", err) } return &Project{ @@ -161,7 +170,7 @@ func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) { func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) { p, err := r.ProjectService.Projects() if err != nil { - return nil, fmt.Errorf("could not get projects: %v", err) + return nil, fmt.Errorf("could not get projects: %w", err) } projects := make([]Project, len(p)) @@ -184,21 +193,25 @@ func regexpToStringPtr(r *regexp.Regexp) *string { if r == nil { return nil } + s := r.String() + return &s } func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResult, error) { if err := r.ProjectService.Close(); err != nil { - return nil, fmt.Errorf("could not close project: %v", err) + return nil, fmt.Errorf("could not close project: %w", err) } + return &CloseProjectResult{true}, nil } func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*DeleteProjectResult, error) { if err := r.ProjectService.Delete(name); err != nil { - return nil, fmt.Errorf("could not delete project: %v", err) + return nil, fmt.Errorf("could not delete project: %w", err) } + return &DeleteProjectResult{ Success: true, }, nil @@ -206,33 +219,40 @@ func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*Del func (r *mutationResolver) ClearHTTPRequestLog(ctx context.Context) (*ClearHTTPRequestLogResult, error) { if err := r.RequestLogService.ClearRequests(ctx); err != nil { - return nil, fmt.Errorf("could not clear request log: %v", err) + return nil, fmt.Errorf("could not clear request log: %w", err) } + return &ClearHTTPRequestLogResult{true}, nil } func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput) ([]ScopeRule, error) { rules := make([]scope.Rule, len(input)) + for i, rule := range input { u, err := stringPtrToRegexp(rule.URL) if err != nil { - return nil, fmt.Errorf("invalid URL in scope rule: %v", err) + return nil, fmt.Errorf("invalid URL in scope rule: %w", err) } + var headerKey, headerValue *regexp.Regexp + if rule.Header != nil { headerKey, err = stringPtrToRegexp(rule.Header.Key) if err != nil { - return nil, fmt.Errorf("invalid header key in scope rule: %v", err) + return nil, fmt.Errorf("invalid header key in scope rule: %w", err) } + headerValue, err = stringPtrToRegexp(rule.Header.Key) if err != nil { - return nil, fmt.Errorf("invalid header value in scope rule: %v", err) + return nil, fmt.Errorf("invalid header value in scope rule: %w", err) } } + body, err := stringPtrToRegexp(rule.Body) if err != nil { - return nil, fmt.Errorf("invalid body in scope rule: %v", err) + return nil, fmt.Errorf("invalid body in scope rule: %w", err) } + rules[i] = scope.Rule{ URL: u, Header: scope.Header{ @@ -244,7 +264,7 @@ func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput) } if err := r.ScopeService.SetRules(ctx, rules); err != nil { - return nil, fmt.Errorf("could not set scope: %v", err) + return nil, fmt.Errorf("could not set scope: %w", err) } return scopeToScopeRules(rules), nil @@ -260,14 +280,14 @@ func (r *mutationResolver) SetHTTPRequestLogFilter( ) (*HTTPRequestLogFilter, error) { filter, err := findRequestsFilterFromInput(input) if err != nil { - return nil, fmt.Errorf("could not parse request log filter: %v", err) + return nil, fmt.Errorf("could not parse request log filter: %w", err) } + err = r.RequestLogService.SetRequestLogFilter(ctx, filter) - if err == proj.ErrNoProject { + if errors.Is(err, proj.ErrNoProject) { return nil, noActiveProjectErr(ctx) - } - if err != nil { - return nil, fmt.Errorf("could not set request log filter: %v", err) + } else if err != nil { + return nil, fmt.Errorf("could not set request log filter: %w", err) } return findReqFilterToHTTPReqLogFilter(filter), nil @@ -277,6 +297,7 @@ func stringPtrToRegexp(s *string) (*regexp.Regexp, error) { if s == nil { return nil, nil } + return regexp.Compile(*s) } @@ -290,8 +311,10 @@ func scopeToScopeRules(rules []scope.Rule) []ScopeRule { Value: regexpToStringPtr(rule.Header.Value), } } + scopeRules[i].Body = regexpToStringPtr(rule.Body) } + return scopeRules } @@ -299,14 +322,17 @@ func findRequestsFilterFromInput(input *HTTPRequestLogFilterInput) (filter reqlo if input == nil { return } + if input.OnlyInScope != nil { filter.OnlyInScope = *input.OnlyInScope } + if input.SearchExpression != nil && *input.SearchExpression != "" { expr, err := search.ParseQuery(*input.SearchExpression) if err != nil { - return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %v", err) + return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %w", err) } + filter.RawSearchExpr = *input.SearchExpression filter.SearchExpr = expr } @@ -319,6 +345,7 @@ func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *H if findReqFilter == empty { return nil } + httpReqLogFilter := &HTTPRequestLogFilter{ OnlyInScope: findReqFilter.OnlyInScope, } diff --git a/pkg/db/sqlite/dto.go b/pkg/db/sqlite/dto.go index f3acac5..764864d 100644 --- a/pkg/db/sqlite/dto.go +++ b/pkg/db/sqlite/dto.go @@ -43,7 +43,7 @@ func (u *reqURL) Scan(value interface{}) error { parsed, err := url.Parse(rawURL) if err != nil { - return fmt.Errorf("sqlite: could not parse URL: %v", err) + return fmt.Errorf("sqlite: could not parse URL: %w", err) } *u = reqURL(*parsed) diff --git a/pkg/db/sqlite/search.go b/pkg/db/sqlite/search.go index 0eb4cdd..bdcc299 100644 --- a/pkg/db/sqlite/search.go +++ b/pkg/db/sqlite/search.go @@ -6,6 +6,7 @@ import ( "sort" sq "github.com/Masterminds/squirrel" + "github.com/dstotijn/hetty/pkg/search" ) @@ -57,20 +58,24 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) { if err != nil { return nil, err } + right, err := parseSearchExpr(expr.Right) if err != nil { return nil, err } + return sq.And{left, right}, nil case search.TokOpOr: left, err := parseSearchExpr(expr.Left) if err != nil { return nil, err } + right, err := parseSearchExpr(expr.Right) if err != nil { return nil, err } + return sq.Or{left, right}, nil } @@ -78,6 +83,7 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) { if !ok { return nil, errors.New("left operand must be a string literal") } + right, ok := expr.Right.(*search.StringLiteral) if !ok { return nil, errors.New("right operand must be a string literal") @@ -113,14 +119,17 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) { func parseStringLiteral(strLiteral *search.StringLiteral) (sq.Sqlizer, error) { // Sorting is not necessary, but makes it easier to do assertions in tests. sortedKeys := make([]string, 0, len(stringLiteralMap)) + for _, v := range stringLiteralMap { sortedKeys = append(sortedKeys, v) } + sort.Strings(sortedKeys) or := make(sq.Or, len(stringLiteralMap)) for i, value := range sortedKeys { or[i] = sq.Like{value: "%" + strLiteral.Value + "%"} } + return or, nil } diff --git a/pkg/db/sqlite/search_test.go b/pkg/db/sqlite/search_test.go index 15f5169..7b4ccf1 100644 --- a/pkg/db/sqlite/search_test.go +++ b/pkg/db/sqlite/search_test.go @@ -10,6 +10,8 @@ import ( ) func TestParseSearchExpr(t *testing.T) { + t.Parallel() + tests := []struct { name string searchExpr search.Expression @@ -206,6 +208,8 @@ func TestParseSearchExpr(t *testing.T) { } func assertError(t *testing.T, exp, got error) { + t.Helper() + switch { case exp == nil && got != nil: t.Fatalf("expected: nil, got: %v", got) diff --git a/pkg/db/sqlite/sqlite.go b/pkg/db/sqlite/sqlite.go index b6a7827..9faf511 100644 --- a/pkg/db/sqlite/sqlite.go +++ b/pkg/db/sqlite/sqlite.go @@ -15,14 +15,14 @@ import ( "strings" "time" - "github.com/dstotijn/hetty/pkg/proj" - "github.com/dstotijn/hetty/pkg/reqlog" - "github.com/dstotijn/hetty/pkg/scope" - "github.com/99designs/gqlgen/graphql" sq "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" "github.com/mattn/go-sqlite3" + + "github.com/dstotijn/hetty/pkg/proj" + "github.com/dstotijn/hetty/pkg/reqlog" + "github.com/dstotijn/hetty/pkg/scope" ) var regexpFn = func(pattern string, value interface{}) (bool, error) { @@ -55,10 +55,7 @@ type httpRequestLogsQuery struct { func init() { sql.Register("sqlite3_with_regexp", &sqlite3.SQLiteDriver{ ConnectHook: func(conn *sqlite3.SQLiteConn) error { - if err := conn.RegisterFunc("regexp", regexpFn, false); err != nil { - return err - } - return nil + return conn.RegisterFunc("regexp", regexpFn, false) }, }) } @@ -66,9 +63,10 @@ func init() { func New(dbPath string) (*Client, error) { if _, err := os.Stat(dbPath); os.IsNotExist(err) { if err := os.MkdirAll(dbPath, 0755); err != nil { - return nil, fmt.Errorf("proj: could not create project directory: %v", err) + return nil, fmt.Errorf("proj: could not create project directory: %w", err) } } + return &Client{ dbPath: dbPath, }, nil @@ -85,17 +83,18 @@ func (c *Client) OpenProject(name string) error { dbPath := filepath.Join(c.dbPath, name+".db") dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode()) + db, err := sqlx.Open("sqlite3_with_regexp", dsn) if err != nil { - return fmt.Errorf("sqlite: could not open database: %v", err) + return fmt.Errorf("sqlite: could not open database: %w", err) } if err := db.Ping(); err != nil { - return fmt.Errorf("sqlite: could not ping database: %v", err) + return fmt.Errorf("sqlite: could not ping database: %w", err) } if err := prepareSchema(db); err != nil { - return fmt.Errorf("sqlite: could not prepare schema: %v", err) + return fmt.Errorf("sqlite: could not prepare schema: %w", err) } c.db = db @@ -107,10 +106,11 @@ func (c *Client) OpenProject(name string) error { func (c *Client) Projects() ([]proj.Project, error) { files, err := ioutil.ReadDir(c.dbPath) if err != nil { - return nil, fmt.Errorf("sqlite: could not read projects directory: %v", err) + return nil, fmt.Errorf("sqlite: could not read projects directory: %w", err) } projects := make([]proj.Project, len(files)) + for i, file := range files { projName := strings.TrimSuffix(file.Name(), ".db") projects[i] = proj.Project{ @@ -132,7 +132,7 @@ func prepareSchema(db *sqlx.DB) error { timestamp DATETIME )`) if err != nil { - return fmt.Errorf("could not create http_requests table: %v", err) + return fmt.Errorf("could not create http_requests table: %w", err) } _, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_responses ( @@ -145,7 +145,7 @@ func prepareSchema(db *sqlx.DB) error { timestamp DATETIME )`) if err != nil { - return fmt.Errorf("could not create http_responses table: %v", err) + return fmt.Errorf("could not create http_responses table: %w", err) } _, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_headers ( @@ -156,7 +156,7 @@ func prepareSchema(db *sqlx.DB) error { value TEXT )`) if err != nil { - return fmt.Errorf("could not create http_headers table: %v", err) + return fmt.Errorf("could not create http_headers table: %w", err) } _, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings ( @@ -164,7 +164,7 @@ func prepareSchema(db *sqlx.DB) error { settings TEXT )`) if err != nil { - return fmt.Errorf("could not create settings table: %v", err) + return fmt.Errorf("could not create settings table: %w", err) } return nil @@ -175,8 +175,9 @@ func (c *Client) Close() error { if c.db == nil { return nil } + if err := c.db.Close(); err != nil { - return fmt.Errorf("sqlite: could not close database: %v", err) + return fmt.Errorf("sqlite: could not close database: %w", err) } c.db = nil @@ -187,7 +188,7 @@ func (c *Client) Close() error { func (c *Client) DeleteProject(name string) error { if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil { - return fmt.Errorf("sqlite: could not remove database file: %v", err) + return fmt.Errorf("sqlite: could not remove database file: %w", err) } return nil @@ -219,10 +220,12 @@ func (c *Client) ClearRequestLogs(ctx context.Context) error { if c.db == nil { return proj.ErrNoProject } + _, err := c.db.Exec("DELETE FROM http_requests") if err != nil { - return fmt.Errorf("sqlite: could not delete requests: %v", err) + return fmt.Errorf("sqlite: could not delete requests: %w", err) } + return nil } @@ -247,11 +250,13 @@ func (c *Client) FindRequestLogs( if filter.OnlyInScope && scope != nil { var ruleExpr []sq.Sqlizer + for _, rule := range scope.Rules() { if rule.URL != nil { ruleExpr = append(ruleExpr, sq.Expr("regexp(?, req.url)", rule.URL.String())) } } + if len(ruleExpr) > 0 { reqQuery = reqQuery.Where(sq.Or(ruleExpr)) } @@ -260,37 +265,41 @@ func (c *Client) FindRequestLogs( if filter.SearchExpr != nil { sqlizer, err := parseSearchExpr(filter.SearchExpr) if err != nil { - return nil, fmt.Errorf("sqlite: could not parse search expression: %v", err) + return nil, fmt.Errorf("sqlite: could not parse search expression: %w", err) } + reqQuery = reqQuery.Where(sqlizer) } sql, args, err := reqQuery.ToSql() if err != nil { - return nil, fmt.Errorf("sqlite: could not parse query: %v", err) + return nil, fmt.Errorf("sqlite: could not parse query: %w", err) } rows, err := c.db.QueryxContext(ctx, sql, args...) if err != nil { - return nil, fmt.Errorf("sqlite: could not execute query: %v", err) + return nil, fmt.Errorf("sqlite: could not execute query: %w", err) } defer rows.Close() for rows.Next() { var dto httpRequest + err = rows.StructScan(&dto) if err != nil { - return nil, fmt.Errorf("sqlite: could not scan row: %v", err) + return nil, fmt.Errorf("sqlite: could not scan row: %w", err) } + reqLogs = append(reqLogs, dto.toRequestLog()) } + if err := rows.Err(); err != nil { - return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err) + return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err) } - rows.Close() + defer rows.Close() if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil { - return nil, fmt.Errorf("sqlite: could not query headers: %v", err) + return nil, fmt.Errorf("sqlite: could not query headers: %w", err) } return reqLogs, nil @@ -300,35 +309,38 @@ func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Reque if c.db == nil { return reqlog.Request{}, proj.ErrNoProject } - httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx) + httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx) reqQuery := sq. Select(httpReqLogsQuery.requestCols...). From("http_requests req"). Where("req.id = ?") + if httpReqLogsQuery.joinResponse { reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id") } reqSQL, _, err := reqQuery.ToSql() if err != nil { - return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %v", err) + return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %w", err) } row := c.db.QueryRowxContext(ctx, reqSQL, id) - var dto httpRequest - err = row.StructScan(&dto) - if err == sql.ErrNoRows { - return reqlog.Request{}, reqlog.ErrRequestNotFound - } - if err != nil { - return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %v", err) - } - reqLog := dto.toRequestLog() + var dto httpRequest + + err = row.StructScan(&dto) + if errors.Is(err, sql.ErrNoRows) { + return reqlog.Request{}, reqlog.ErrRequestNotFound + } else if err != nil { + return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %w", err) + } + + reqLog := dto.toRequestLog() reqLogs := []reqlog.Request{reqLog} + if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil { - return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %v", err) + return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %w", err) } return reqLogs[0], nil @@ -352,8 +364,9 @@ func (c *Client) AddRequestLog( tx, err := c.db.BeginTxx(ctx, nil) if err != nil { - return nil, fmt.Errorf("sqlite: could not start transaction: %v", err) + return nil, fmt.Errorf("sqlite: could not start transaction: %w", err) } + defer tx.Rollback() reqStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_requests ( @@ -364,7 +377,7 @@ func (c *Client) AddRequestLog( timestamp ) VALUES (?, ?, ?, ?, ?)`) if err != nil { - return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err) + return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err) } defer reqStmt.Close() @@ -376,13 +389,14 @@ func (c *Client) AddRequestLog( reqLog.Timestamp, ) if err != nil { - return nil, fmt.Errorf("sqlite: could not execute statement: %v", err) + return nil, fmt.Errorf("sqlite: could not execute statement: %w", err) } reqID, err := result.LastInsertId() if err != nil { - return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err) + return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err) } + reqLog.ID = reqID headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers ( @@ -391,17 +405,17 @@ func (c *Client) AddRequestLog( value ) VALUES (?, ?, ?)`) if err != nil { - return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err) + return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err) } defer headerStmt.Close() err = insertHeaders(ctx, headerStmt, reqID, reqLog.Request.Header) if err != nil { - return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err) + return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err) } if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err) + return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err) } return reqLog, nil @@ -424,9 +438,10 @@ func (c *Client) AddResponseLog( Body: body, Timestamp: timestamp, } + tx, err := c.db.BeginTx(ctx, nil) if err != nil { - return nil, fmt.Errorf("sqlite: could not start transaction: %v", err) + return nil, fmt.Errorf("sqlite: could not start transaction: %w", err) } defer tx.Rollback() @@ -439,7 +454,7 @@ func (c *Client) AddResponseLog( timestamp ) VALUES (?, ?, ?, ?, ?, ?)`) if err != nil { - return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err) + return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err) } defer resStmt.Close() @@ -457,13 +472,14 @@ func (c *Client) AddResponseLog( resLog.Timestamp, ) if err != nil { - return nil, fmt.Errorf("sqlite: could not execute statement: %v", err) + return nil, fmt.Errorf("sqlite: could not execute statement: %w", err) } resID, err := result.LastInsertId() if err != nil { - return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err) + return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err) } + resLog.ID = resID headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers ( @@ -472,17 +488,17 @@ func (c *Client) AddResponseLog( value ) VALUES (?, ?, ?)`) if err != nil { - return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err) + return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err) } defer headerStmt.Close() err = insertHeaders(ctx, headerStmt, resID, resLog.Response.Header) if err != nil { - return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err) + return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err) } if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err) + return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err) } return resLog, nil @@ -496,14 +512,14 @@ func (c *Client) UpsertSettings(ctx context.Context, module string, settings int jsonSettings, err := json.Marshal(settings) if err != nil { - return fmt.Errorf("sqlite: could not encode settings as JSON: %v", err) + return fmt.Errorf("sqlite: could not encode settings as JSON: %w", err) } _, err = c.db.ExecContext(ctx, `INSERT INTO settings (module, settings) VALUES (?, ?) ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings) if err != nil { - return fmt.Errorf("sqlite: could not insert scope settings: %v", err) + return fmt.Errorf("sqlite: could not insert scope settings: %w", err) } return nil @@ -515,17 +531,18 @@ func (c *Client) FindSettingsByModule(ctx context.Context, module string, settin } var jsonSettings []byte + row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module) + err := row.Scan(&jsonSettings) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return proj.ErrNoSettings - } - if err != nil { - return fmt.Errorf("sqlite: could not scan row: %v", err) + } else if err != nil { + return fmt.Errorf("sqlite: could not scan row: %w", err) } if err := json.Unmarshal(jsonSettings, &settings); err != nil { - return fmt.Errorf("sqlite: could not decode settings from JSON: %v", err) + return fmt.Errorf("sqlite: could not decode settings from JSON: %w", err) } return nil @@ -535,42 +552,46 @@ func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.H for key, values := range headers { for _, value := range values { if _, err := stmt.ExecContext(ctx, id, key, value); err != nil { - return fmt.Errorf("could not execute statement: %v", err) + return fmt.Errorf("could not execute statement: %w", err) } } } + return nil } func findHeaders(ctx context.Context, stmt *sql.Stmt, id int64) (http.Header, error) { headers := make(http.Header) + rows, err := stmt.QueryContext(ctx, id) if err != nil { - return nil, fmt.Errorf("sqlite: could not execute query: %v", err) + return nil, fmt.Errorf("sqlite: could not execute query: %w", err) } defer rows.Close() for rows.Next() { var key, value string - err := rows.Scan( - &key, - &value, - ) + + err := rows.Scan(&key, &value) if err != nil { - return nil, fmt.Errorf("sqlite: could not scan row: %v", err) + return nil, fmt.Errorf("sqlite: could not scan row: %w", err) } + headers.Add(key, value) } + if err := rows.Err(); err != nil { - return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err) + return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err) } return headers, nil } func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery { - var joinResponse bool - var reqHeaderCols, resHeaderCols []string + var ( + joinResponse bool + reqHeaderCols, resHeaderCols []string + ) opCtx := graphql.GetOperationContext(ctx) reqFields := graphql.CollectFieldsCtx(ctx, nil) @@ -580,6 +601,7 @@ func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery { if col, ok := reqFieldToColumnMap[reqField.Name]; ok { reqCols = append(reqCols, "req."+col) } + if reqField.Name == "headers" { headerFields := graphql.CollectFields(opCtx, reqField.Selections, nil) for _, headerField := range headerFields { @@ -588,19 +610,23 @@ func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery { } } } + if reqField.Name == "response" { joinResponse = true resFields := graphql.CollectFields(opCtx, reqField.Selections, nil) + for _, resField := range resFields { if resField.Name == "headers" { reqCols = append(reqCols, "res.id AS res_id") headerFields := graphql.CollectFields(opCtx, resField.Selections, nil) + for _, headerField := range headerFields { if col, ok := headerFieldToColumnMap[headerField.Name]; ok { resHeaderCols = append(resHeaderCols, col) } } } + if col, ok := resFieldToColumnMap[resField.Name]; ok { reqCols = append(reqCols, "res."+col) } @@ -627,18 +653,21 @@ func (c *Client) queryHeaders( From("http_headers").Where("req_id = ?"). ToSql() if err != nil { - return fmt.Errorf("could not parse request headers query: %v", err) + return fmt.Errorf("could not parse request headers query: %w", err) } + reqHeadersStmt, err := c.db.PrepareContext(ctx, reqHeadersQuery) if err != nil { - return fmt.Errorf("could not prepare statement: %v", err) + return fmt.Errorf("could not prepare statement: %w", err) } defer reqHeadersStmt.Close() + for i := range reqLogs { headers, err := findHeaders(ctx, reqHeadersStmt, reqLogs[i].ID) if err != nil { - return fmt.Errorf("could not query request headers: %v", err) + return fmt.Errorf("could not query request headers: %w", err) } + reqLogs[i].Request.Header = headers } } @@ -649,21 +678,25 @@ func (c *Client) queryHeaders( From("http_headers").Where("res_id = ?"). ToSql() if err != nil { - return fmt.Errorf("could not parse response headers query: %v", err) + return fmt.Errorf("could not parse response headers query: %w", err) } + resHeadersStmt, err := c.db.PrepareContext(ctx, resHeadersQuery) if err != nil { - return fmt.Errorf("could not prepare statement: %v", err) + return fmt.Errorf("could not prepare statement: %w", err) } defer resHeadersStmt.Close() + for i := range reqLogs { if reqLogs[i].Response == nil { continue } + headers, err := findHeaders(ctx, resHeadersStmt, reqLogs[i].Response.ID) if err != nil { - return fmt.Errorf("could not query response headers: %v", err) + return fmt.Errorf("could not query response headers: %w", err) } + reqLogs[i].Response.Response.Header = headers } } diff --git a/pkg/proj/proj.go b/pkg/proj/proj.go index df69a42..e26b666 100644 --- a/pkg/proj/proj.go +++ b/pkg/proj/proj.go @@ -9,8 +9,10 @@ import ( "sync" ) -type OnProjectOpenFn func(name string) error -type OnProjectCloseFn func(name string) error +type ( + OnProjectOpenFn func(name string) error + OnProjectCloseFn func(name string) error +) // Service is used for managing projects. type Service struct { @@ -47,8 +49,9 @@ func (svc *Service) Close() error { defer svc.mu.Unlock() closedProject := svc.activeProject + if err := svc.repo.Close(); err != nil { - return fmt.Errorf("proj: could not close project: %v", err) + return fmt.Errorf("proj: could not close project: %w", err) } svc.activeProject = "" @@ -63,12 +66,13 @@ func (svc *Service) Delete(name string) error { if name == "" { return errors.New("proj: name cannot be empty") } + if svc.activeProject == name { return fmt.Errorf("proj: project (%v) is active", name) } if err := svc.repo.DeleteProject(name); err != nil { - return fmt.Errorf("proj: could not delete project: %v", err) + return fmt.Errorf("proj: could not delete project: %w", err) } return nil @@ -85,11 +89,11 @@ func (svc *Service) Open(ctx context.Context, name string) (Project, error) { defer svc.mu.Unlock() if err := svc.repo.Close(); err != nil { - return Project{}, fmt.Errorf("proj: could not close previously open database: %v", err) + return Project{}, fmt.Errorf("proj: could not close previously open database: %w", err) } if err := svc.repo.OpenProject(name); err != nil { - return Project{}, fmt.Errorf("proj: could not open database: %v", err) + return Project{}, fmt.Errorf("proj: could not open database: %w", err) } svc.activeProject = name @@ -115,7 +119,7 @@ func (svc *Service) ActiveProject() (Project, error) { func (svc *Service) Projects() ([]Project, error) { projects, err := svc.repo.Projects() if err != nil { - return nil, fmt.Errorf("proj: could not get projects: %v", err) + return nil, fmt.Errorf("proj: could not get projects: %w", err) } return projects, nil diff --git a/pkg/proxy/cert.go b/pkg/proxy/cert.go index 262cb66..fa7fb5b 100644 --- a/pkg/proxy/cert.go +++ b/pkg/proxy/cert.go @@ -25,7 +25,7 @@ import ( var MaxSerialNumber = big.NewInt(0).SetBytes(bytes.Repeat([]byte{255}, 20)) // CertConfig is a set of configuration values that are used to build TLS configs -// capable of MITM +// capable of MITM. type CertConfig struct { ca *x509.Certificate caPriv crypto.PrivateKey @@ -40,6 +40,7 @@ func NewCertConfig(ca *x509.Certificate, caPrivKey crypto.PrivateKey) (*CertConf if err != nil { return nil, err } + pub := priv.Public() // Subject Key Identifier support for end entity certificate. @@ -48,6 +49,7 @@ func NewCertConfig(ca *x509.Certificate, caPrivKey crypto.PrivateKey) (*CertConf if err != nil { return nil, err } + h := sha1.New() h.Write(pkixPubKey) keyID := h.Sum(nil) @@ -67,58 +69,69 @@ func LoadOrCreateCA(caKeyFile, caCertFile string) (*x509.Certificate, *rsa.Priva if err == nil { caCert, err := x509.ParseCertificate(tlsCA.Certificate[0]) if err != nil { - return nil, nil, fmt.Errorf("proxy: could not parse CA: %v", err) + return nil, nil, fmt.Errorf("proxy: could not parse CA: %w", err) } + caKey, ok := tlsCA.PrivateKey.(*rsa.PrivateKey) if !ok { return nil, nil, errors.New("proxy: private key is not RSA") } + return caCert, caKey, nil } + if !os.IsNotExist(err) { - return nil, nil, fmt.Errorf("proxy: could not load CA key pair: %v", err) + return nil, nil, fmt.Errorf("proxy: could not load CA key pair: %w", err) } // Create directories for files if they don't exist yet. keyDir, _ := filepath.Split(caKeyFile) if keyDir != "" { if _, err := os.Stat(keyDir); os.IsNotExist(err) { - os.MkdirAll(keyDir, 0755) + if err := os.MkdirAll(keyDir, 0755); err != nil { + return nil, nil, fmt.Errorf("proxy: could not create directory for CA key: %w", err) + } } } + keyDir, _ = filepath.Split(caCertFile) if keyDir != "" { if _, err := os.Stat("keyDir"); os.IsNotExist(err) { - os.MkdirAll(keyDir, 0755) + if err := os.MkdirAll(keyDir, 0755); err != nil { + return nil, nil, fmt.Errorf("proxy: could not create directory for CA cert: %w", err) + } } } // Create new CA keypair. - caCert, caKey, err := NewCA("Hetty", "Hetty CA", time.Duration(365*24*time.Hour)) + caCert, caKey, err := NewCA("Hetty", "Hetty CA", 365*24*time.Hour) if err != nil { - return nil, nil, fmt.Errorf("proxy: could not generate new CA keypair: %v", err) + return nil, nil, fmt.Errorf("proxy: could not generate new CA keypair: %w", err) } // Open CA certificate and key files for writing. certOut, err := os.Create(caCertFile) if err != nil { - return nil, nil, fmt.Errorf("proxy: could not open cert file for writing: %v", err) + return nil, nil, fmt.Errorf("proxy: could not open cert file for writing: %w", err) } + keyOut, err := os.OpenFile(caKeyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { - return nil, nil, fmt.Errorf("proxy: could not open key file for writing: %v", err) + return nil, nil, fmt.Errorf("proxy: could not open key file for writing: %w", err) } // Write PEM blocks to CA certificate and key files. if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw}); err != nil { - return nil, nil, fmt.Errorf("proxy: could not write CA certificate to disk: %v", err) + return nil, nil, fmt.Errorf("proxy: could not write CA certificate to disk: %w", err) } + privBytes, err := x509.MarshalPKCS8PrivateKey(caKey) if err != nil { - return nil, nil, fmt.Errorf("proxy: could not convert private key to DER format: %v", err) + return nil, nil, fmt.Errorf("proxy: could not convert private key to DER format: %w", err) } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { - return nil, nil, fmt.Errorf("proxy: could not write CA key to disk: %v", err) + return nil, nil, fmt.Errorf("proxy: could not write CA key to disk: %w", err) } return caCert, caKey, nil @@ -130,6 +143,7 @@ func NewCA(name, organization string, validity time.Duration) (*x509.Certificate if err != nil { return nil, nil, err } + pub := priv.Public() // Subject Key Identifier support for end entity certificate. @@ -138,6 +152,7 @@ func NewCA(name, organization string, validity time.Duration) (*x509.Certificate if err != nil { return nil, nil, err } + h := sha1.New() h.Write(pkixpub) keyID := h.Sum(nil) @@ -187,8 +202,10 @@ func (c *CertConfig) TLSConfig() *tls.Config { if clientHello.ServerName == "" { return nil, errors.New("missing server name (SNI)") } + return c.cert(clientHello.ServerName) }, + MinVersion: tls.VersionTLS12, NextProtos: []string{"http/1.1"}, } } diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 8cec97a..a19400b 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -5,13 +5,12 @@ import ( "crypto" "crypto/tls" "crypto/x509" + "errors" "fmt" "log" "net" "net/http" "net/http/httputil" - - "github.com/dstotijn/hetty/pkg/scope" ) type contextKey int @@ -27,8 +26,6 @@ type Proxy struct { // TODO: Add mutex for modifier funcs. reqModifiers []RequestModifyMiddleware resModifiers []ResponseModifyMiddleware - - scope *scope.Scope } // NewProxy returns a new Proxy. @@ -55,7 +52,7 @@ func NewProxy(ca *x509.Certificate, key crypto.PrivateKey) (*Proxy, error) { func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodConnect { - p.handleConnect(w, r) + p.handleConnect(w) return } @@ -103,11 +100,12 @@ func (p *Proxy) modifyResponse(res *http.Response) error { // handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel. // During the TLS handshake with the client, we use the proxy's CA config to // create a certificate on-the-fly. -func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) { +func (p *Proxy) handleConnect(w http.ResponseWriter) { hj, ok := w.(http.Hijacker) if !ok { log.Printf("[ERROR] handleConnect: ResponseWriter is not a http.Hijacker (type: %T)", w) - writeError(w, r, http.StatusServiceUnavailable) + writeError(w, http.StatusServiceUnavailable) + return } @@ -116,7 +114,8 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) { clientConn, _, err := hj.Hijack() if err != nil { log.Printf("[ERROR] Hijacking client connection failed: %v", err) - writeError(w, r, http.StatusServiceUnavailable) + writeError(w, http.StatusServiceUnavailable) + return } defer clientConn.Close() @@ -127,14 +126,15 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) { log.Printf("[ERROR] Securing client connection failed: %v", err) return } - clientConnNotify := ConnNotify{clientConn, make(chan struct{})} + clientConnNotify := ConnNotify{clientConn, make(chan struct{})} l := &OnceAcceptListener{clientConnNotify.Conn} err = http.Serve(l, p) - if err != nil && err != ErrAlreadyAccepted { + if err != nil && !errors.Is(err, ErrAlreadyAccepted) { log.Printf("[ERROR] Serving HTTP request failed: %v", err) } + <-clientConnNotify.closed } @@ -144,20 +144,22 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) { tlsConn := tls.Server(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { tlsConn.Close() - return nil, fmt.Errorf("handshake error: %v", err) + return nil, fmt.Errorf("handshake error: %w", err) } return tlsConn, nil } func errorHandler(w http.ResponseWriter, r *http.Request, err error) { - if err == context.Canceled { + if errors.Is(err, context.Canceled) { return } + log.Printf("[ERROR]: Proxy error: %v", err) + w.WriteHeader(http.StatusBadGateway) } -func writeError(w http.ResponseWriter, r *http.Request, code int) { +func writeError(w http.ResponseWriter, code int) { http.Error(w, http.StatusText(code), code) } diff --git a/pkg/reqlog/repo.go b/pkg/reqlog/repo.go index 7126a43..96d1b32 100644 --- a/pkg/reqlog/repo.go +++ b/pkg/reqlog/repo.go @@ -16,7 +16,7 @@ type Repository interface { FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error) FindRequestLogByID(ctx context.Context, id int64) (Request, error) AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error) - AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error) + AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error) // nolint:lll ClearRequestLogs(ctx context.Context) error UpsertSettings(ctx context.Context, module string, settings interface{}) error FindSettingsByModule(ctx context.Context, module string, settings interface{}) error diff --git a/pkg/reqlog/reqlog.go b/pkg/reqlog/reqlog.go index 7d933a7..363939f 100644 --- a/pkg/reqlog/reqlog.go +++ b/pkg/reqlog/reqlog.go @@ -24,9 +24,7 @@ const LogBypassedKey contextKey = 0 const moduleName = "reqlog" -var ( - ErrRequestNotFound = errors.New("reqlog: request not found") -) +var ErrRequestNotFound = errors.New("reqlog: request not found") type Request struct { ID int64 @@ -74,12 +72,13 @@ func NewService(cfg Config) *Service { cfg.ProjectService.OnProjectOpen(func(_ string) error { err := svc.loadSettings() - if err == proj.ErrNoSettings { + if errors.Is(err, proj.ErrNoSettings) { return nil } if err != nil { - return fmt.Errorf("reqlog: could not load settings: %v", err) + return fmt.Errorf("reqlog: could not load settings: %w", err) } + return nil }) cfg.ProjectService.OnProjectClose(func(_ string) error { @@ -126,12 +125,13 @@ func (svc *Service) addResponse( if res.Header.Get("Content-Encoding") == "gzip" { gzipReader, err := gzip.NewReader(bytes.NewBuffer(body)) if err != nil { - return nil, fmt.Errorf("reqlog: could not create gzip reader: %v", err) + return nil, fmt.Errorf("reqlog: could not create gzip reader: %w", err) } defer gzipReader.Close() + body, err = ioutil.ReadAll(gzipReader) if err != nil { - return nil, fmt.Errorf("reqlog: could not read gzipped response body: %v", err) + return nil, fmt.Errorf("reqlog: could not read gzipped response body: %w", err) } } @@ -141,18 +141,23 @@ func (svc *Service) addResponse( func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc { return func(req *http.Request) { now := time.Now() + next(req) clone := req.Clone(req.Context()) + var body []byte + if req.Body != nil { // TODO: Use io.LimitReader. var err error + body, err = ioutil.ReadAll(req.Body) if err != nil { log.Printf("[ERROR] Could not read request body for logging: %v", err) return } + req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) } @@ -161,19 +166,21 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) { ctx := context.WithValue(req.Context(), LogBypassedKey, true) *req = *req.WithContext(ctx) + return } reqLog, err := svc.addRequest(req.Context(), *clone, body, now) - if err == proj.ErrNoProject { + if errors.Is(err, proj.ErrNoProject) { ctx := context.WithValue(req.Context(), LogBypassedKey, true) *req = *req.WithContext(ctx) + return - } - if err != nil { + } else if err != nil { log.Printf("[ERROR] Could not store request log: %v", err) return } + ctx := context.WithValue(req.Context(), proxy.ReqIDKey, reqLog.ID) *req = *req.WithContext(ctx) } @@ -182,6 +189,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc { return func(res *http.Response) error { now := time.Now() + if err := next(res); err != nil { return err } @@ -200,8 +208,9 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon // TODO: Use io.LimitReader. body, err := ioutil.ReadAll(res.Body) if err != nil { - return fmt.Errorf("reqlog: could not read response body: %v", err) + return fmt.Errorf("reqlog: could not read response body: %w", err) } + res.Body = ioutil.NopCloser(bytes.NewBuffer(body)) go func() { @@ -220,6 +229,7 @@ func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error { OnlyInScope bool RawSearchExpr string } + if err := json.Unmarshal(b, &dto); err != nil { return err } @@ -234,6 +244,7 @@ func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error { if err != nil { return err } + filter.SearchExpr = expr } diff --git a/pkg/scope/scope.go b/pkg/scope/scope.go index 9d8dd03..2bcb5f0 100644 --- a/pkg/scope/scope.go +++ b/pkg/scope/scope.go @@ -3,6 +3,7 @@ package scope import ( "context" "encoding/json" + "errors" "fmt" "net/http" "regexp" @@ -38,12 +39,13 @@ func New(repo Repository, projService *proj.Service) *Scope { projService.OnProjectOpen(func(_ string) error { err := s.load(context.Background()) - if err == proj.ErrNoSettings { + if errors.Is(err, proj.ErrNoSettings) { return nil } if err != nil { - return fmt.Errorf("scope: could not load scope: %v", err) + return fmt.Errorf("scope: could not load scope: %w", err) } + return nil }) projService.OnProjectClose(func(_ string) error { @@ -57,6 +59,7 @@ func New(repo Repository, projService *proj.Service) *Scope { func (s *Scope) Rules() []Rule { s.mu.RLock() defer s.mu.RUnlock() + return s.rules } @@ -65,12 +68,12 @@ func (s *Scope) load(ctx context.Context) error { defer s.mu.Unlock() var rules []Rule + err := s.repo.FindSettingsByModule(ctx, moduleName, &rules) - if err == proj.ErrNoSettings { + if errors.Is(err, proj.ErrNoSettings) { return err - } - if err != nil { - return fmt.Errorf("scope: could not load scope settings: %v", err) + } else if err != nil { + return fmt.Errorf("scope: could not load scope settings: %w", err) } s.rules = rules @@ -89,7 +92,7 @@ func (s *Scope) SetRules(ctx context.Context, rules []Rule) error { defer s.mu.Unlock() if err := s.repo.UpsertSettings(ctx, moduleName, rules); err != nil { - return fmt.Errorf("scope: cannot set rules in repository: %v", err) + return fmt.Errorf("scope: cannot set rules in repository: %w", err) } s.rules = rules @@ -100,6 +103,7 @@ func (s *Scope) SetRules(ctx context.Context, rules []Rule) error { func (s *Scope) Match(req *http.Request, body []byte) bool { s.mu.RLock() defer s.mu.RUnlock() + for _, rule := range s.rules { if matches := rule.Match(req, body); matches { return true @@ -118,11 +122,13 @@ func (r Rule) Match(req *http.Request, body []byte) bool { for key, values := range req.Header { var keyMatches, valueMatches bool + if r.Header.Key != nil { if matches := r.Header.Key.MatchString(key); matches { keyMatches = true } } + if r.Header.Value != nil { for _, value := range values { if matches := r.Header.Value.MatchString(value); matches { @@ -154,15 +160,17 @@ func (r Rule) Match(req *http.Request, body []byte) bool { // MarshalJSON implements json.Marshaler. func (r Rule) MarshalJSON() ([]byte, error) { - type headerDTO struct { - Key string - Value string - } - type ruleDTO struct { - URL string - Header headerDTO - Body string - } + type ( + headerDTO struct { + Key string + Value string + } + ruleDTO struct { + URL string + Header headerDTO + Body string + } + ) dto := ruleDTO{ URL: regexpToString(r.URL), @@ -178,15 +186,17 @@ func (r Rule) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (r *Rule) UnmarshalJSON(data []byte) error { - type headerDTO struct { - Key string - Value string - } - type ruleDTO struct { - URL string - Header headerDTO - Body string - } + type ( + headerDTO struct { + Key string + Value string + } + ruleDTO struct { + URL string + Header headerDTO + Body string + } + ) var dto ruleDTO if err := json.Unmarshal(data, &dto); err != nil { @@ -197,14 +207,17 @@ func (r *Rule) UnmarshalJSON(data []byte) error { if err != nil { return err } + headerKey, err := stringToRegexp(dto.Header.Key) if err != nil { return err } + headerValue, err := stringToRegexp(dto.Header.Value) if err != nil { return err } + body, err := stringToRegexp(dto.Body) if err != nil { return err @@ -226,6 +239,7 @@ func regexpToString(r *regexp.Regexp) string { if r == nil { return "" } + return r.String() } @@ -233,5 +247,6 @@ func stringToRegexp(s string) (*regexp.Regexp, error) { if s == "" { return nil, nil } + return regexp.Compile(s) } diff --git a/pkg/search/ast.go b/pkg/search/ast.go index 7a4647c..a0598a8 100644 --- a/pkg/search/ast.go +++ b/pkg/search/ast.go @@ -11,7 +11,6 @@ type PrefixExpression struct { Right Expression } -func (pe *PrefixExpression) expressionNode() {} func (pe *PrefixExpression) String() string { b := strings.Builder{} b.WriteString("(") @@ -29,7 +28,6 @@ type InfixExpression struct { Right Expression } -func (ie *InfixExpression) expressionNode() {} func (ie *InfixExpression) String() string { b := strings.Builder{} b.WriteString("(") @@ -47,7 +45,6 @@ type StringLiteral struct { Value string } -func (sl *StringLiteral) expressionNode() {} func (sl *StringLiteral) String() string { return sl.Value } diff --git a/pkg/search/lexer.go b/pkg/search/lexer.go index 2af09c8..4eb0c48 100644 --- a/pkg/search/lexer.go +++ b/pkg/search/lexer.go @@ -17,21 +17,21 @@ const eof = 0 // Token types. const ( - // Flow + // Flow. TokInvalid TokenType = iota TokEOF TokParenOpen TokParenClose - // Literals + // Literals. TokString - // Boolean operators + // Boolean operators. TokOpNot TokOpAnd TokOpOr - // Comparison operators + // Comparison operators. TokOpEq TokOpNotEq TokOpGt @@ -98,6 +98,7 @@ func (tt TokenType) String() string { if typeString, ok := tokenTypeStrings[tt]; ok { return typeString } + return "" } @@ -113,6 +114,7 @@ func (l *Lexer) read() (r rune) { l.width = 0 return eof } + r, l.width = utf8.DecodeRuneInString(l.input[l.pos:]) l.pos += l.width @@ -124,6 +126,7 @@ func (l *Lexer) emit(tokenType TokenType) { Type: tokenType, Literal: l.input[l.start:l.pos], } + l.start = l.pos } @@ -159,6 +162,7 @@ func begin(l *Lexer) stateFn { l.backup() l.emit(TokOpEq) } + return begin case '!': switch next := l.read(); next { @@ -169,6 +173,7 @@ func begin(l *Lexer) stateFn { default: return l.errorf("invalid rune %v", r) } + return begin case '<': if next := l.read(); next == '=' { @@ -177,6 +182,7 @@ func begin(l *Lexer) stateFn { l.backup() l.emit(TokOpLt) } + return begin case '>': if next := l.read(); next == '=' { @@ -185,6 +191,7 @@ func begin(l *Lexer) stateFn { l.backup() l.emit(TokOpGt) } + return begin case '(': l.emit(TokParenOpen) @@ -231,15 +238,18 @@ func unquotedString(l *Lexer) stateFn { case r == eof: l.backup() l.emitUnquotedString() + return begin case unicode.IsSpace(r): l.backup() l.emitUnquotedString() l.skip() + return begin case isReserved(r): l.backup() l.emitUnquotedString() + return begin } } @@ -251,6 +261,7 @@ func (l *Lexer) emitUnquotedString() { l.emit(tokType) return } + l.emit(TokString) } @@ -260,5 +271,6 @@ func isReserved(r rune) bool { return true } } + return false } diff --git a/pkg/search/lexer_test.go b/pkg/search/lexer_test.go index 8811cd4..c51217a 100644 --- a/pkg/search/lexer_test.go +++ b/pkg/search/lexer_test.go @@ -3,6 +3,8 @@ package search import "testing" func TestNextToken(t *testing.T) { + t.Parallel() + tests := []struct { name string input string diff --git a/pkg/search/parser.go b/pkg/search/parser.go index 96e73f5..a3642f5 100644 --- a/pkg/search/parser.go +++ b/pkg/search/parser.go @@ -18,8 +18,10 @@ const ( precGroup ) -type prefixParser func(*Parser) (Expression, error) -type infixParser func(*Parser, Expression) (Expression, error) +type ( + prefixParser func(*Parser) (Expression, error) + infixParser func(*Parser, Expression) (Expression, error) +) var ( prefixParsers = map[TokenType]prefixParser{} @@ -77,7 +79,6 @@ func NewParser(l *Lexer) *Parser { p.nextToken() return p - } func ParseQuery(input string) (expr Expression, err error) { @@ -91,18 +92,20 @@ func ParseQuery(input string) (expr Expression, err error) { for !p.curTokenIs(TokEOF) { right, err := p.parseExpression(precLowest) - if err != nil { - return nil, fmt.Errorf("search: could not parse expression: %v", err) - } - if expr == nil { + + switch { + case err != nil: + return nil, fmt.Errorf("search: could not parse expression: %w", err) + case expr == nil: expr = right - } else { + default: expr = &InfixExpression{ Operator: TokOpAnd, Left: expr, Right: right, } } + p.nextToken() } @@ -122,18 +125,11 @@ func (p *Parser) peekTokenIs(t TokenType) bool { return p.peek.Type == t } -func (p *Parser) expectPeek(t TokenType) error { - if !p.peekTokenIs(t) { - return fmt.Errorf("expected next token to be %v, got %v", t, p.peek.Type) - } - p.nextToken() - return nil -} - func (p *Parser) curPrecedence() precedence { if p, ok := tokenPrecedences[p.cur.Type]; ok { return p } + return precLowest } @@ -141,6 +137,7 @@ func (p *Parser) peekPrecedence() precedence { if p, ok := tokenPrecedences[p.peek.Type]; ok { return p } + return precLowest } @@ -152,7 +149,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) { expr, err := prefixParser(p) if err != nil { - return nil, fmt.Errorf("could not parse expression prefix: %v", err) + return nil, fmt.Errorf("could not parse expression prefix: %w", err) } for !p.peekTokenIs(eof) && prec < p.peekPrecedence() { @@ -165,7 +162,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) { expr, err = infixParser(p, expr) if err != nil { - return nil, fmt.Errorf("could not parse infix expression: %v", err) + return nil, fmt.Errorf("could not parse infix expression: %w", err) } } @@ -181,8 +178,9 @@ func parsePrefixExpression(p *Parser) (Expression, error) { right, err := p.parseExpression(precPrefix) if err != nil { - return nil, fmt.Errorf("could not parse expression for right operand: %v", err) + return nil, fmt.Errorf("could not parse expression for right operand: %w", err) } + expr.Right = right return expr, nil @@ -199,8 +197,9 @@ func parseInfixExpression(p *Parser, left Expression) (Expression, error) { right, err := p.parseExpression(prec) if err != nil { - return nil, fmt.Errorf("could not parse expression for right operand: %v", err) + return nil, fmt.Errorf("could not parse expression for right operand: %w", err) } + expr.Right = right return expr, nil @@ -215,17 +214,19 @@ func parseGroupedExpression(p *Parser) (Expression, error) { expr, err := p.parseExpression(precLowest) if err != nil { - return nil, fmt.Errorf("could not parse grouped expression: %v", err) + return nil, fmt.Errorf("could not parse grouped expression: %w", err) } for p.nextToken(); !p.curTokenIs(TokParenClose); p.nextToken() { if p.curTokenIs(TokEOF) { return nil, fmt.Errorf("unexpected EOF: unmatched parentheses") } + right, err := p.parseExpression(precLowest) if err != nil { - return nil, fmt.Errorf("could not parse expression: %v", err) + return nil, fmt.Errorf("could not parse expression: %w", err) } + expr = &InfixExpression{ Operator: TokOpAnd, Left: expr, diff --git a/pkg/search/parser_test.go b/pkg/search/parser_test.go index 0a8eab2..eeb0dcd 100644 --- a/pkg/search/parser_test.go +++ b/pkg/search/parser_test.go @@ -7,6 +7,8 @@ import ( ) func TestParseQuery(t *testing.T) { + t.Parallel() + tests := []struct { name string input string @@ -233,6 +235,8 @@ func TestParseQuery(t *testing.T) { } func assertError(t *testing.T, exp, got error) { + t.Helper() + switch { case exp == nil && got != nil: t.Fatalf("expected: nil, got: %v", got)