Add linter, fix linting issue

This commit is contained in:
David Stotijn 2021-04-25 16:23:53 +02:00
parent ad3dc0da70
commit ca3a729c36
18 changed files with 442 additions and 231 deletions

55
.golangci.yml Normal file
View file

@ -0,0 +1,55 @@
linters:
presets:
- bugs
- comment
- error
- format
- import
- metalinter
- module
- performance
- sql
- style
- test
- unused
disable:
- exhaustive
- exhaustivestruct
- gochecknoglobals
- gochecknoinits
- godox
- goerr113
- gomnd
- interfacer
- maligned
- nlreturn
- scopelint
- testpackage
- wrapcheck
linters-settings:
errcheck:
ignore: "database/sql:Rollback"
gci:
local-prefixes: github.com/dstotijn/hetty
godot:
capital: true
issues:
exclude-rules:
- linters:
- gosec
# Ignore SHA1 usage.
text: "G(401|505):"
- linters:
- wsl
# Ignore cuddled defer statements.
text: "only one cuddle assignment allowed before defer statement"
- linters:
- nlreturn
# Ignore `break` without leading blank line.
text: "break with no blank line before"
run:
skip-files:
- cmd/hetty/rice-box.go

View file

@ -2,25 +2,27 @@ package main
import (
"crypto/tls"
"errors"
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"strings"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/playground"
rice "github.com/GeertJohan/go.rice"
"github.com/gorilla/mux"
"github.com/mitchellh/go-homedir"
"github.com/dstotijn/hetty/pkg/api"
"github.com/dstotijn/hetty/pkg/db/sqlite"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/proxy"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/playground"
"github.com/gorilla/mux"
"github.com/mitchellh/go-homedir"
)
var (
@ -32,8 +34,16 @@ var (
)
func main() {
flag.StringVar(&caCertFile, "cert", "~/.hetty/hetty_cert.pem", "CA certificate filepath. Creates a new CA certificate if file doesn't exist")
flag.StringVar(&caKeyFile, "key", "~/.hetty/hetty_key.pem", "CA private key filepath. Creates a new CA private key if file doesn't exist")
if err := run(); err != nil {
log.Fatalf("[ERROR]: %v", err)
}
}
func run() error {
flag.StringVar(&caCertFile, "cert", "~/.hetty/hetty_cert.pem",
"CA certificate filepath. Creates a new CA certificate if file doesn't exist")
flag.StringVar(&caKeyFile, "key", "~/.hetty/hetty_key.pem",
"CA private key filepath. Creates a new CA private key if file doesn't exist")
flag.StringVar(&projPath, "projects", "~/.hetty/projects", "Projects directory path")
flag.StringVar(&addr, "addr", ":8080", "TCP address to listen on, in the form \"host:port\"")
flag.StringVar(&adminPath, "adminPath", "", "File path to admin build")
@ -42,32 +52,34 @@ func main() {
// Expand `~` in filepaths.
caCertFile, err := homedir.Expand(caCertFile)
if err != nil {
log.Fatalf("[FATAL] Could not parse CA certificate filepath: %v", err)
return fmt.Errorf("could not parse CA certificate filepath: %w", err)
}
caKeyFile, err := homedir.Expand(caKeyFile)
if err != nil {
log.Fatalf("[FATAL] Could not parse CA private key filepath: %v", err)
return fmt.Errorf("could not parse CA private key filepath: %w", err)
}
projPath, err := homedir.Expand(projPath)
if err != nil {
log.Fatalf("[FATAL] Could not parse projects filepath: %v", err)
return fmt.Errorf("could not parse projects filepath: %w", err)
}
// Load existing CA certificate and key from disk, or generate and write
// to disk if no files exist yet.
caCert, caKey, err := proxy.LoadOrCreateCA(caKeyFile, caCertFile)
if err != nil {
log.Fatalf("[FATAL] Could not create/load CA key pair: %v", err)
return fmt.Errorf("could not create/load CA key pair: %w", err)
}
db, err := sqlite.New(projPath)
if err != nil {
log.Fatalf("[FATAL] Could not initialize database client: %v", err)
return fmt.Errorf("could not initialize database client: %w", err)
}
projService, err := proj.NewService(db)
if err != nil {
log.Fatalf("[FATAL] Could not create new project service: %v", err)
return fmt.Errorf("could not create new project service: %w", err)
}
defer projService.Close()
@ -81,19 +93,21 @@ func main() {
p, err := proxy.NewProxy(caCert, caKey)
if err != nil {
log.Fatalf("[FATAL] Could not create Proxy: %v", err)
return fmt.Errorf("could not create proxy: %w", err)
}
p.UseRequestModifier(reqLogService.RequestModifier)
p.UseResponseModifier(reqLogService.ResponseModifier)
var adminHandler http.Handler
if adminPath == "" {
// Used for embedding with `rice`.
box, err := rice.FindBox("../../admin/dist")
if err != nil {
log.Fatalf("[FATAL] Could not find embedded admin resources: %v", err)
return fmt.Errorf("could not find embedded admin resources: %w", err)
}
adminHandler = http.FileServer(box.HTTPBox())
} else {
adminHandler = http.FileServer(http.Dir(adminPath))
@ -109,11 +123,12 @@ func main() {
// GraphQL server.
adminRouter.Path("/api/playground/").Handler(playground.Handler("GraphQL Playground", "/api/graphql/"))
adminRouter.Path("/api/graphql/").Handler(handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{
RequestLogService: reqLogService,
ProjectService: projService,
ScopeService: scope,
}})))
adminRouter.Path("/api/graphql/").Handler(
handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{
RequestLogService: reqLogService,
ProjectService: projService,
ScopeService: scope,
}})))
// Admin interface.
adminRouter.PathPrefix("").Handler(adminHandler)
@ -128,8 +143,11 @@ func main() {
}
log.Printf("[INFO] Running server on %v ...", addr)
err = s.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
log.Fatalf("[FATAL] HTTP server closed: %v", err)
if err != nil && errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("http server closed unexpected: %w", err)
}
return nil
}

View file

@ -4,16 +4,18 @@ package api
import (
"context"
"errors"
"fmt"
"regexp"
"strings"
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/dstotijn/hetty/pkg/search"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type Resolver struct {
@ -22,20 +24,22 @@ type Resolver struct {
ScopeService *scope.Scope
}
type queryResolver struct{ *Resolver }
type mutationResolver struct{ *Resolver }
type (
queryResolver struct{ *Resolver }
mutationResolver struct{ *Resolver }
)
func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error) {
reqs, err := r.RequestLogService.FindRequests(ctx)
if err == proj.ErrNoProject {
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
} else if err != nil {
return nil, fmt.Errorf("could not query repository for requests: %w", err)
}
if err != nil {
return nil, fmt.Errorf("could not query repository for requests: %v", err)
}
logs := make([]HTTPRequestLog, len(reqs))
for i, req := range reqs {
@ -43,6 +47,7 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
if err != nil {
return nil, err
}
logs[i] = req
}
@ -51,12 +56,12 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error) {
log, err := r.RequestLogService.FindRequestLogByID(ctx, id)
if err == reqlog.ErrRequestNotFound {
if errors.Is(err, reqlog.ErrRequestNotFound) {
return nil, nil
} else if err != nil {
return nil, fmt.Errorf("could not get request by ID: %w", err)
}
if err != nil {
return nil, fmt.Errorf("could not get request by ID: %v", err)
}
req, err := parseRequestLog(log)
if err != nil {
return nil, err
@ -89,6 +94,7 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
if req.Request.Header != nil {
log.Headers = make([]HTTPHeader, 0)
for key, values := range req.Request.Header {
for _, value := range values {
log.Headers = append(log.Headers, HTTPHeader{
@ -106,15 +112,19 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
StatusCode: req.Response.Response.StatusCode,
}
statusReasonSubs := strings.SplitN(req.Response.Response.Status, " ", 2)
if len(statusReasonSubs) == 2 {
log.Response.StatusReason = statusReasonSubs[1]
}
if len(req.Response.Body) > 0 {
resBody := string(req.Response.Body)
log.Response.Body = &resBody
}
if req.Response.Response.Header != nil {
log.Response.Headers = make([]HTTPHeader, 0)
for key, values := range req.Response.Response.Header {
for _, value := range values {
log.Response.Headers = append(log.Response.Headers, HTTPHeader{
@ -131,12 +141,12 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Project, error) {
p, err := r.ProjectService.Open(ctx, name)
if err == proj.ErrInvalidName {
if errors.Is(err, proj.ErrInvalidName) {
return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.")
} else if err != nil {
return nil, fmt.Errorf("could not open project: %w", err)
}
if err != nil {
return nil, fmt.Errorf("could not open project: %v", err)
}
return &Project{
Name: p.Name,
IsActive: p.IsActive,
@ -145,11 +155,10 @@ func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Proje
func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
p, err := r.ProjectService.ActiveProject()
if err == proj.ErrNoProject {
if errors.Is(err, proj.ErrNoProject) {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("could not open project: %v", err)
} else if err != nil {
return nil, fmt.Errorf("could not open project: %w", err)
}
return &Project{
@ -161,7 +170,7 @@ func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
p, err := r.ProjectService.Projects()
if err != nil {
return nil, fmt.Errorf("could not get projects: %v", err)
return nil, fmt.Errorf("could not get projects: %w", err)
}
projects := make([]Project, len(p))
@ -184,21 +193,25 @@ func regexpToStringPtr(r *regexp.Regexp) *string {
if r == nil {
return nil
}
s := r.String()
return &s
}
func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResult, error) {
if err := r.ProjectService.Close(); err != nil {
return nil, fmt.Errorf("could not close project: %v", err)
return nil, fmt.Errorf("could not close project: %w", err)
}
return &CloseProjectResult{true}, nil
}
func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*DeleteProjectResult, error) {
if err := r.ProjectService.Delete(name); err != nil {
return nil, fmt.Errorf("could not delete project: %v", err)
return nil, fmt.Errorf("could not delete project: %w", err)
}
return &DeleteProjectResult{
Success: true,
}, nil
@ -206,33 +219,40 @@ func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*Del
func (r *mutationResolver) ClearHTTPRequestLog(ctx context.Context) (*ClearHTTPRequestLogResult, error) {
if err := r.RequestLogService.ClearRequests(ctx); err != nil {
return nil, fmt.Errorf("could not clear request log: %v", err)
return nil, fmt.Errorf("could not clear request log: %w", err)
}
return &ClearHTTPRequestLogResult{true}, nil
}
func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput) ([]ScopeRule, error) {
rules := make([]scope.Rule, len(input))
for i, rule := range input {
u, err := stringPtrToRegexp(rule.URL)
if err != nil {
return nil, fmt.Errorf("invalid URL in scope rule: %v", err)
return nil, fmt.Errorf("invalid URL in scope rule: %w", err)
}
var headerKey, headerValue *regexp.Regexp
if rule.Header != nil {
headerKey, err = stringPtrToRegexp(rule.Header.Key)
if err != nil {
return nil, fmt.Errorf("invalid header key in scope rule: %v", err)
return nil, fmt.Errorf("invalid header key in scope rule: %w", err)
}
headerValue, err = stringPtrToRegexp(rule.Header.Key)
if err != nil {
return nil, fmt.Errorf("invalid header value in scope rule: %v", err)
return nil, fmt.Errorf("invalid header value in scope rule: %w", err)
}
}
body, err := stringPtrToRegexp(rule.Body)
if err != nil {
return nil, fmt.Errorf("invalid body in scope rule: %v", err)
return nil, fmt.Errorf("invalid body in scope rule: %w", err)
}
rules[i] = scope.Rule{
URL: u,
Header: scope.Header{
@ -244,7 +264,7 @@ func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput)
}
if err := r.ScopeService.SetRules(ctx, rules); err != nil {
return nil, fmt.Errorf("could not set scope: %v", err)
return nil, fmt.Errorf("could not set scope: %w", err)
}
return scopeToScopeRules(rules), nil
@ -260,14 +280,14 @@ func (r *mutationResolver) SetHTTPRequestLogFilter(
) (*HTTPRequestLogFilter, error) {
filter, err := findRequestsFilterFromInput(input)
if err != nil {
return nil, fmt.Errorf("could not parse request log filter: %v", err)
return nil, fmt.Errorf("could not parse request log filter: %w", err)
}
err = r.RequestLogService.SetRequestLogFilter(ctx, filter)
if err == proj.ErrNoProject {
if errors.Is(err, proj.ErrNoProject) {
return nil, noActiveProjectErr(ctx)
}
if err != nil {
return nil, fmt.Errorf("could not set request log filter: %v", err)
} else if err != nil {
return nil, fmt.Errorf("could not set request log filter: %w", err)
}
return findReqFilterToHTTPReqLogFilter(filter), nil
@ -277,6 +297,7 @@ func stringPtrToRegexp(s *string) (*regexp.Regexp, error) {
if s == nil {
return nil, nil
}
return regexp.Compile(*s)
}
@ -290,8 +311,10 @@ func scopeToScopeRules(rules []scope.Rule) []ScopeRule {
Value: regexpToStringPtr(rule.Header.Value),
}
}
scopeRules[i].Body = regexpToStringPtr(rule.Body)
}
return scopeRules
}
@ -299,14 +322,17 @@ func findRequestsFilterFromInput(input *HTTPRequestLogFilterInput) (filter reqlo
if input == nil {
return
}
if input.OnlyInScope != nil {
filter.OnlyInScope = *input.OnlyInScope
}
if input.SearchExpression != nil && *input.SearchExpression != "" {
expr, err := search.ParseQuery(*input.SearchExpression)
if err != nil {
return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %v", err)
return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %w", err)
}
filter.RawSearchExpr = *input.SearchExpression
filter.SearchExpr = expr
}
@ -319,6 +345,7 @@ func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *H
if findReqFilter == empty {
return nil
}
httpReqLogFilter := &HTTPRequestLogFilter{
OnlyInScope: findReqFilter.OnlyInScope,
}

View file

@ -43,7 +43,7 @@ func (u *reqURL) Scan(value interface{}) error {
parsed, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("sqlite: could not parse URL: %v", err)
return fmt.Errorf("sqlite: could not parse URL: %w", err)
}
*u = reqURL(*parsed)

View file

@ -6,6 +6,7 @@ import (
"sort"
sq "github.com/Masterminds/squirrel"
"github.com/dstotijn/hetty/pkg/search"
)
@ -57,20 +58,24 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
if err != nil {
return nil, err
}
right, err := parseSearchExpr(expr.Right)
if err != nil {
return nil, err
}
return sq.And{left, right}, nil
case search.TokOpOr:
left, err := parseSearchExpr(expr.Left)
if err != nil {
return nil, err
}
right, err := parseSearchExpr(expr.Right)
if err != nil {
return nil, err
}
return sq.Or{left, right}, nil
}
@ -78,6 +83,7 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
if !ok {
return nil, errors.New("left operand must be a string literal")
}
right, ok := expr.Right.(*search.StringLiteral)
if !ok {
return nil, errors.New("right operand must be a string literal")
@ -113,14 +119,17 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
func parseStringLiteral(strLiteral *search.StringLiteral) (sq.Sqlizer, error) {
// Sorting is not necessary, but makes it easier to do assertions in tests.
sortedKeys := make([]string, 0, len(stringLiteralMap))
for _, v := range stringLiteralMap {
sortedKeys = append(sortedKeys, v)
}
sort.Strings(sortedKeys)
or := make(sq.Or, len(stringLiteralMap))
for i, value := range sortedKeys {
or[i] = sq.Like{value: "%" + strLiteral.Value + "%"}
}
return or, nil
}

View file

@ -10,6 +10,8 @@ import (
)
func TestParseSearchExpr(t *testing.T) {
t.Parallel()
tests := []struct {
name string
searchExpr search.Expression
@ -206,6 +208,8 @@ func TestParseSearchExpr(t *testing.T) {
}
func assertError(t *testing.T, exp, got error) {
t.Helper()
switch {
case exp == nil && got != nil:
t.Fatalf("expected: nil, got: %v", got)

View file

@ -15,14 +15,14 @@ import (
"strings"
"time"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
"github.com/99designs/gqlgen/graphql"
sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
"github.com/dstotijn/hetty/pkg/proj"
"github.com/dstotijn/hetty/pkg/reqlog"
"github.com/dstotijn/hetty/pkg/scope"
)
var regexpFn = func(pattern string, value interface{}) (bool, error) {
@ -55,10 +55,7 @@ type httpRequestLogsQuery struct {
func init() {
sql.Register("sqlite3_with_regexp", &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
if err := conn.RegisterFunc("regexp", regexpFn, false); err != nil {
return err
}
return nil
return conn.RegisterFunc("regexp", regexpFn, false)
},
})
}
@ -66,9 +63,10 @@ func init() {
func New(dbPath string) (*Client, error) {
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
if err := os.MkdirAll(dbPath, 0755); err != nil {
return nil, fmt.Errorf("proj: could not create project directory: %v", err)
return nil, fmt.Errorf("proj: could not create project directory: %w", err)
}
}
return &Client{
dbPath: dbPath,
}, nil
@ -85,17 +83,18 @@ func (c *Client) OpenProject(name string) error {
dbPath := filepath.Join(c.dbPath, name+".db")
dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode())
db, err := sqlx.Open("sqlite3_with_regexp", dsn)
if err != nil {
return fmt.Errorf("sqlite: could not open database: %v", err)
return fmt.Errorf("sqlite: could not open database: %w", err)
}
if err := db.Ping(); err != nil {
return fmt.Errorf("sqlite: could not ping database: %v", err)
return fmt.Errorf("sqlite: could not ping database: %w", err)
}
if err := prepareSchema(db); err != nil {
return fmt.Errorf("sqlite: could not prepare schema: %v", err)
return fmt.Errorf("sqlite: could not prepare schema: %w", err)
}
c.db = db
@ -107,10 +106,11 @@ func (c *Client) OpenProject(name string) error {
func (c *Client) Projects() ([]proj.Project, error) {
files, err := ioutil.ReadDir(c.dbPath)
if err != nil {
return nil, fmt.Errorf("sqlite: could not read projects directory: %v", err)
return nil, fmt.Errorf("sqlite: could not read projects directory: %w", err)
}
projects := make([]proj.Project, len(files))
for i, file := range files {
projName := strings.TrimSuffix(file.Name(), ".db")
projects[i] = proj.Project{
@ -132,7 +132,7 @@ func prepareSchema(db *sqlx.DB) error {
timestamp DATETIME
)`)
if err != nil {
return fmt.Errorf("could not create http_requests table: %v", err)
return fmt.Errorf("could not create http_requests table: %w", err)
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_responses (
@ -145,7 +145,7 @@ func prepareSchema(db *sqlx.DB) error {
timestamp DATETIME
)`)
if err != nil {
return fmt.Errorf("could not create http_responses table: %v", err)
return fmt.Errorf("could not create http_responses table: %w", err)
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_headers (
@ -156,7 +156,7 @@ func prepareSchema(db *sqlx.DB) error {
value TEXT
)`)
if err != nil {
return fmt.Errorf("could not create http_headers table: %v", err)
return fmt.Errorf("could not create http_headers table: %w", err)
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings (
@ -164,7 +164,7 @@ func prepareSchema(db *sqlx.DB) error {
settings TEXT
)`)
if err != nil {
return fmt.Errorf("could not create settings table: %v", err)
return fmt.Errorf("could not create settings table: %w", err)
}
return nil
@ -175,8 +175,9 @@ func (c *Client) Close() error {
if c.db == nil {
return nil
}
if err := c.db.Close(); err != nil {
return fmt.Errorf("sqlite: could not close database: %v", err)
return fmt.Errorf("sqlite: could not close database: %w", err)
}
c.db = nil
@ -187,7 +188,7 @@ func (c *Client) Close() error {
func (c *Client) DeleteProject(name string) error {
if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil {
return fmt.Errorf("sqlite: could not remove database file: %v", err)
return fmt.Errorf("sqlite: could not remove database file: %w", err)
}
return nil
@ -219,10 +220,12 @@ func (c *Client) ClearRequestLogs(ctx context.Context) error {
if c.db == nil {
return proj.ErrNoProject
}
_, err := c.db.Exec("DELETE FROM http_requests")
if err != nil {
return fmt.Errorf("sqlite: could not delete requests: %v", err)
return fmt.Errorf("sqlite: could not delete requests: %w", err)
}
return nil
}
@ -247,11 +250,13 @@ func (c *Client) FindRequestLogs(
if filter.OnlyInScope && scope != nil {
var ruleExpr []sq.Sqlizer
for _, rule := range scope.Rules() {
if rule.URL != nil {
ruleExpr = append(ruleExpr, sq.Expr("regexp(?, req.url)", rule.URL.String()))
}
}
if len(ruleExpr) > 0 {
reqQuery = reqQuery.Where(sq.Or(ruleExpr))
}
@ -260,37 +265,41 @@ func (c *Client) FindRequestLogs(
if filter.SearchExpr != nil {
sqlizer, err := parseSearchExpr(filter.SearchExpr)
if err != nil {
return nil, fmt.Errorf("sqlite: could not parse search expression: %v", err)
return nil, fmt.Errorf("sqlite: could not parse search expression: %w", err)
}
reqQuery = reqQuery.Where(sqlizer)
}
sql, args, err := reqQuery.ToSql()
if err != nil {
return nil, fmt.Errorf("sqlite: could not parse query: %v", err)
return nil, fmt.Errorf("sqlite: could not parse query: %w", err)
}
rows, err := c.db.QueryxContext(ctx, sql, args...)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
return nil, fmt.Errorf("sqlite: could not execute query: %w", err)
}
defer rows.Close()
for rows.Next() {
var dto httpRequest
err = rows.StructScan(&dto)
if err != nil {
return nil, fmt.Errorf("sqlite: could not scan row: %v", err)
return nil, fmt.Errorf("sqlite: could not scan row: %w", err)
}
reqLogs = append(reqLogs, dto.toRequestLog())
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err)
return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err)
}
rows.Close()
defer rows.Close()
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
return nil, fmt.Errorf("sqlite: could not query headers: %v", err)
return nil, fmt.Errorf("sqlite: could not query headers: %w", err)
}
return reqLogs, nil
@ -300,35 +309,38 @@ func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Reque
if c.db == nil {
return reqlog.Request{}, proj.ErrNoProject
}
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
reqQuery := sq.
Select(httpReqLogsQuery.requestCols...).
From("http_requests req").
Where("req.id = ?")
if httpReqLogsQuery.joinResponse {
reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id")
}
reqSQL, _, err := reqQuery.ToSql()
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %v", err)
return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %w", err)
}
row := c.db.QueryRowxContext(ctx, reqSQL, id)
var dto httpRequest
err = row.StructScan(&dto)
if err == sql.ErrNoRows {
return reqlog.Request{}, reqlog.ErrRequestNotFound
}
if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %v", err)
}
reqLog := dto.toRequestLog()
var dto httpRequest
err = row.StructScan(&dto)
if errors.Is(err, sql.ErrNoRows) {
return reqlog.Request{}, reqlog.ErrRequestNotFound
} else if err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %w", err)
}
reqLog := dto.toRequestLog()
reqLogs := []reqlog.Request{reqLog}
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %v", err)
return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %w", err)
}
return reqLogs[0], nil
@ -352,8 +364,9 @@ func (c *Client) AddRequestLog(
tx, err := c.db.BeginTxx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite: could not start transaction: %v", err)
return nil, fmt.Errorf("sqlite: could not start transaction: %w", err)
}
defer tx.Rollback()
reqStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_requests (
@ -364,7 +377,7 @@ func (c *Client) AddRequestLog(
timestamp
) VALUES (?, ?, ?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
}
defer reqStmt.Close()
@ -376,13 +389,14 @@ func (c *Client) AddRequestLog(
reqLog.Timestamp,
)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute statement: %v", err)
return nil, fmt.Errorf("sqlite: could not execute statement: %w", err)
}
reqID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err)
return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err)
}
reqLog.ID = reqID
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
@ -391,17 +405,17 @@ func (c *Client) AddRequestLog(
value
) VALUES (?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
}
defer headerStmt.Close()
err = insertHeaders(ctx, headerStmt, reqID, reqLog.Request.Header)
if err != nil {
return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err)
return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err)
return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err)
}
return reqLog, nil
@ -424,9 +438,10 @@ func (c *Client) AddResponseLog(
Body: body,
Timestamp: timestamp,
}
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite: could not start transaction: %v", err)
return nil, fmt.Errorf("sqlite: could not start transaction: %w", err)
}
defer tx.Rollback()
@ -439,7 +454,7 @@ func (c *Client) AddResponseLog(
timestamp
) VALUES (?, ?, ?, ?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
}
defer resStmt.Close()
@ -457,13 +472,14 @@ func (c *Client) AddResponseLog(
resLog.Timestamp,
)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute statement: %v", err)
return nil, fmt.Errorf("sqlite: could not execute statement: %w", err)
}
resID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err)
return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err)
}
resLog.ID = resID
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
@ -472,17 +488,17 @@ func (c *Client) AddResponseLog(
value
) VALUES (?, ?, ?)`)
if err != nil {
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
}
defer headerStmt.Close()
err = insertHeaders(ctx, headerStmt, resID, resLog.Response.Header)
if err != nil {
return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err)
return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err)
return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err)
}
return resLog, nil
@ -496,14 +512,14 @@ func (c *Client) UpsertSettings(ctx context.Context, module string, settings int
jsonSettings, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("sqlite: could not encode settings as JSON: %v", err)
return fmt.Errorf("sqlite: could not encode settings as JSON: %w", err)
}
_, err = c.db.ExecContext(ctx,
`INSERT INTO settings (module, settings) VALUES (?, ?)
ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings)
if err != nil {
return fmt.Errorf("sqlite: could not insert scope settings: %v", err)
return fmt.Errorf("sqlite: could not insert scope settings: %w", err)
}
return nil
@ -515,17 +531,18 @@ func (c *Client) FindSettingsByModule(ctx context.Context, module string, settin
}
var jsonSettings []byte
row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module)
err := row.Scan(&jsonSettings)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return proj.ErrNoSettings
}
if err != nil {
return fmt.Errorf("sqlite: could not scan row: %v", err)
} else if err != nil {
return fmt.Errorf("sqlite: could not scan row: %w", err)
}
if err := json.Unmarshal(jsonSettings, &settings); err != nil {
return fmt.Errorf("sqlite: could not decode settings from JSON: %v", err)
return fmt.Errorf("sqlite: could not decode settings from JSON: %w", err)
}
return nil
@ -535,42 +552,46 @@ func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.H
for key, values := range headers {
for _, value := range values {
if _, err := stmt.ExecContext(ctx, id, key, value); err != nil {
return fmt.Errorf("could not execute statement: %v", err)
return fmt.Errorf("could not execute statement: %w", err)
}
}
}
return nil
}
func findHeaders(ctx context.Context, stmt *sql.Stmt, id int64) (http.Header, error) {
headers := make(http.Header)
rows, err := stmt.QueryContext(ctx, id)
if err != nil {
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
return nil, fmt.Errorf("sqlite: could not execute query: %w", err)
}
defer rows.Close()
for rows.Next() {
var key, value string
err := rows.Scan(
&key,
&value,
)
err := rows.Scan(&key, &value)
if err != nil {
return nil, fmt.Errorf("sqlite: could not scan row: %v", err)
return nil, fmt.Errorf("sqlite: could not scan row: %w", err)
}
headers.Add(key, value)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err)
return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err)
}
return headers, nil
}
func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
var joinResponse bool
var reqHeaderCols, resHeaderCols []string
var (
joinResponse bool
reqHeaderCols, resHeaderCols []string
)
opCtx := graphql.GetOperationContext(ctx)
reqFields := graphql.CollectFieldsCtx(ctx, nil)
@ -580,6 +601,7 @@ func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
if col, ok := reqFieldToColumnMap[reqField.Name]; ok {
reqCols = append(reqCols, "req."+col)
}
if reqField.Name == "headers" {
headerFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
for _, headerField := range headerFields {
@ -588,19 +610,23 @@ func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
}
}
}
if reqField.Name == "response" {
joinResponse = true
resFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
for _, resField := range resFields {
if resField.Name == "headers" {
reqCols = append(reqCols, "res.id AS res_id")
headerFields := graphql.CollectFields(opCtx, resField.Selections, nil)
for _, headerField := range headerFields {
if col, ok := headerFieldToColumnMap[headerField.Name]; ok {
resHeaderCols = append(resHeaderCols, col)
}
}
}
if col, ok := resFieldToColumnMap[resField.Name]; ok {
reqCols = append(reqCols, "res."+col)
}
@ -627,18 +653,21 @@ func (c *Client) queryHeaders(
From("http_headers").Where("req_id = ?").
ToSql()
if err != nil {
return fmt.Errorf("could not parse request headers query: %v", err)
return fmt.Errorf("could not parse request headers query: %w", err)
}
reqHeadersStmt, err := c.db.PrepareContext(ctx, reqHeadersQuery)
if err != nil {
return fmt.Errorf("could not prepare statement: %v", err)
return fmt.Errorf("could not prepare statement: %w", err)
}
defer reqHeadersStmt.Close()
for i := range reqLogs {
headers, err := findHeaders(ctx, reqHeadersStmt, reqLogs[i].ID)
if err != nil {
return fmt.Errorf("could not query request headers: %v", err)
return fmt.Errorf("could not query request headers: %w", err)
}
reqLogs[i].Request.Header = headers
}
}
@ -649,21 +678,25 @@ func (c *Client) queryHeaders(
From("http_headers").Where("res_id = ?").
ToSql()
if err != nil {
return fmt.Errorf("could not parse response headers query: %v", err)
return fmt.Errorf("could not parse response headers query: %w", err)
}
resHeadersStmt, err := c.db.PrepareContext(ctx, resHeadersQuery)
if err != nil {
return fmt.Errorf("could not prepare statement: %v", err)
return fmt.Errorf("could not prepare statement: %w", err)
}
defer resHeadersStmt.Close()
for i := range reqLogs {
if reqLogs[i].Response == nil {
continue
}
headers, err := findHeaders(ctx, resHeadersStmt, reqLogs[i].Response.ID)
if err != nil {
return fmt.Errorf("could not query response headers: %v", err)
return fmt.Errorf("could not query response headers: %w", err)
}
reqLogs[i].Response.Response.Header = headers
}
}

View file

@ -9,8 +9,10 @@ import (
"sync"
)
type OnProjectOpenFn func(name string) error
type OnProjectCloseFn func(name string) error
type (
OnProjectOpenFn func(name string) error
OnProjectCloseFn func(name string) error
)
// Service is used for managing projects.
type Service struct {
@ -47,8 +49,9 @@ func (svc *Service) Close() error {
defer svc.mu.Unlock()
closedProject := svc.activeProject
if err := svc.repo.Close(); err != nil {
return fmt.Errorf("proj: could not close project: %v", err)
return fmt.Errorf("proj: could not close project: %w", err)
}
svc.activeProject = ""
@ -63,12 +66,13 @@ func (svc *Service) Delete(name string) error {
if name == "" {
return errors.New("proj: name cannot be empty")
}
if svc.activeProject == name {
return fmt.Errorf("proj: project (%v) is active", name)
}
if err := svc.repo.DeleteProject(name); err != nil {
return fmt.Errorf("proj: could not delete project: %v", err)
return fmt.Errorf("proj: could not delete project: %w", err)
}
return nil
@ -85,11 +89,11 @@ func (svc *Service) Open(ctx context.Context, name string) (Project, error) {
defer svc.mu.Unlock()
if err := svc.repo.Close(); err != nil {
return Project{}, fmt.Errorf("proj: could not close previously open database: %v", err)
return Project{}, fmt.Errorf("proj: could not close previously open database: %w", err)
}
if err := svc.repo.OpenProject(name); err != nil {
return Project{}, fmt.Errorf("proj: could not open database: %v", err)
return Project{}, fmt.Errorf("proj: could not open database: %w", err)
}
svc.activeProject = name
@ -115,7 +119,7 @@ func (svc *Service) ActiveProject() (Project, error) {
func (svc *Service) Projects() ([]Project, error) {
projects, err := svc.repo.Projects()
if err != nil {
return nil, fmt.Errorf("proj: could not get projects: %v", err)
return nil, fmt.Errorf("proj: could not get projects: %w", err)
}
return projects, nil

View file

@ -25,7 +25,7 @@ import (
var MaxSerialNumber = big.NewInt(0).SetBytes(bytes.Repeat([]byte{255}, 20))
// CertConfig is a set of configuration values that are used to build TLS configs
// capable of MITM
// capable of MITM.
type CertConfig struct {
ca *x509.Certificate
caPriv crypto.PrivateKey
@ -40,6 +40,7 @@ func NewCertConfig(ca *x509.Certificate, caPrivKey crypto.PrivateKey) (*CertConf
if err != nil {
return nil, err
}
pub := priv.Public()
// Subject Key Identifier support for end entity certificate.
@ -48,6 +49,7 @@ func NewCertConfig(ca *x509.Certificate, caPrivKey crypto.PrivateKey) (*CertConf
if err != nil {
return nil, err
}
h := sha1.New()
h.Write(pkixPubKey)
keyID := h.Sum(nil)
@ -67,58 +69,69 @@ func LoadOrCreateCA(caKeyFile, caCertFile string) (*x509.Certificate, *rsa.Priva
if err == nil {
caCert, err := x509.ParseCertificate(tlsCA.Certificate[0])
if err != nil {
return nil, nil, fmt.Errorf("proxy: could not parse CA: %v", err)
return nil, nil, fmt.Errorf("proxy: could not parse CA: %w", err)
}
caKey, ok := tlsCA.PrivateKey.(*rsa.PrivateKey)
if !ok {
return nil, nil, errors.New("proxy: private key is not RSA")
}
return caCert, caKey, nil
}
if !os.IsNotExist(err) {
return nil, nil, fmt.Errorf("proxy: could not load CA key pair: %v", err)
return nil, nil, fmt.Errorf("proxy: could not load CA key pair: %w", err)
}
// Create directories for files if they don't exist yet.
keyDir, _ := filepath.Split(caKeyFile)
if keyDir != "" {
if _, err := os.Stat(keyDir); os.IsNotExist(err) {
os.MkdirAll(keyDir, 0755)
if err := os.MkdirAll(keyDir, 0755); err != nil {
return nil, nil, fmt.Errorf("proxy: could not create directory for CA key: %w", err)
}
}
}
keyDir, _ = filepath.Split(caCertFile)
if keyDir != "" {
if _, err := os.Stat("keyDir"); os.IsNotExist(err) {
os.MkdirAll(keyDir, 0755)
if err := os.MkdirAll(keyDir, 0755); err != nil {
return nil, nil, fmt.Errorf("proxy: could not create directory for CA cert: %w", err)
}
}
}
// Create new CA keypair.
caCert, caKey, err := NewCA("Hetty", "Hetty CA", time.Duration(365*24*time.Hour))
caCert, caKey, err := NewCA("Hetty", "Hetty CA", 365*24*time.Hour)
if err != nil {
return nil, nil, fmt.Errorf("proxy: could not generate new CA keypair: %v", err)
return nil, nil, fmt.Errorf("proxy: could not generate new CA keypair: %w", err)
}
// Open CA certificate and key files for writing.
certOut, err := os.Create(caCertFile)
if err != nil {
return nil, nil, fmt.Errorf("proxy: could not open cert file for writing: %v", err)
return nil, nil, fmt.Errorf("proxy: could not open cert file for writing: %w", err)
}
keyOut, err := os.OpenFile(caKeyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return nil, nil, fmt.Errorf("proxy: could not open key file for writing: %v", err)
return nil, nil, fmt.Errorf("proxy: could not open key file for writing: %w", err)
}
// Write PEM blocks to CA certificate and key files.
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw}); err != nil {
return nil, nil, fmt.Errorf("proxy: could not write CA certificate to disk: %v", err)
return nil, nil, fmt.Errorf("proxy: could not write CA certificate to disk: %w", err)
}
privBytes, err := x509.MarshalPKCS8PrivateKey(caKey)
if err != nil {
return nil, nil, fmt.Errorf("proxy: could not convert private key to DER format: %v", err)
return nil, nil, fmt.Errorf("proxy: could not convert private key to DER format: %w", err)
}
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
return nil, nil, fmt.Errorf("proxy: could not write CA key to disk: %v", err)
return nil, nil, fmt.Errorf("proxy: could not write CA key to disk: %w", err)
}
return caCert, caKey, nil
@ -130,6 +143,7 @@ func NewCA(name, organization string, validity time.Duration) (*x509.Certificate
if err != nil {
return nil, nil, err
}
pub := priv.Public()
// Subject Key Identifier support for end entity certificate.
@ -138,6 +152,7 @@ func NewCA(name, organization string, validity time.Duration) (*x509.Certificate
if err != nil {
return nil, nil, err
}
h := sha1.New()
h.Write(pkixpub)
keyID := h.Sum(nil)
@ -187,8 +202,10 @@ func (c *CertConfig) TLSConfig() *tls.Config {
if clientHello.ServerName == "" {
return nil, errors.New("missing server name (SNI)")
}
return c.cert(clientHello.ServerName)
},
MinVersion: tls.VersionTLS12,
NextProtos: []string{"http/1.1"},
}
}

View file

@ -5,13 +5,12 @@ import (
"crypto"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/http/httputil"
"github.com/dstotijn/hetty/pkg/scope"
)
type contextKey int
@ -27,8 +26,6 @@ type Proxy struct {
// TODO: Add mutex for modifier funcs.
reqModifiers []RequestModifyMiddleware
resModifiers []ResponseModifyMiddleware
scope *scope.Scope
}
// NewProxy returns a new Proxy.
@ -55,7 +52,7 @@ func NewProxy(ca *x509.Certificate, key crypto.PrivateKey) (*Proxy, error) {
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
p.handleConnect(w, r)
p.handleConnect(w)
return
}
@ -103,11 +100,12 @@ func (p *Proxy) modifyResponse(res *http.Response) error {
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
// During the TLS handshake with the client, we use the proxy's CA config to
// create a certificate on-the-fly.
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
func (p *Proxy) handleConnect(w http.ResponseWriter) {
hj, ok := w.(http.Hijacker)
if !ok {
log.Printf("[ERROR] handleConnect: ResponseWriter is not a http.Hijacker (type: %T)", w)
writeError(w, r, http.StatusServiceUnavailable)
writeError(w, http.StatusServiceUnavailable)
return
}
@ -116,7 +114,8 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
clientConn, _, err := hj.Hijack()
if err != nil {
log.Printf("[ERROR] Hijacking client connection failed: %v", err)
writeError(w, r, http.StatusServiceUnavailable)
writeError(w, http.StatusServiceUnavailable)
return
}
defer clientConn.Close()
@ -127,14 +126,15 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
log.Printf("[ERROR] Securing client connection failed: %v", err)
return
}
clientConnNotify := ConnNotify{clientConn, make(chan struct{})}
clientConnNotify := ConnNotify{clientConn, make(chan struct{})}
l := &OnceAcceptListener{clientConnNotify.Conn}
err = http.Serve(l, p)
if err != nil && err != ErrAlreadyAccepted {
if err != nil && !errors.Is(err, ErrAlreadyAccepted) {
log.Printf("[ERROR] Serving HTTP request failed: %v", err)
}
<-clientConnNotify.closed
}
@ -144,20 +144,22 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
tlsConn := tls.Server(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("handshake error: %v", err)
return nil, fmt.Errorf("handshake error: %w", err)
}
return tlsConn, nil
}
func errorHandler(w http.ResponseWriter, r *http.Request, err error) {
if err == context.Canceled {
if errors.Is(err, context.Canceled) {
return
}
log.Printf("[ERROR]: Proxy error: %v", err)
w.WriteHeader(http.StatusBadGateway)
}
func writeError(w http.ResponseWriter, r *http.Request, code int) {
func writeError(w http.ResponseWriter, code int) {
http.Error(w, http.StatusText(code), code)
}

View file

@ -16,7 +16,7 @@ type Repository interface {
FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error)
FindRequestLogByID(ctx context.Context, id int64) (Request, error)
AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error)
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error)
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error) // nolint:lll
ClearRequestLogs(ctx context.Context) error
UpsertSettings(ctx context.Context, module string, settings interface{}) error
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error

View file

@ -24,9 +24,7 @@ const LogBypassedKey contextKey = 0
const moduleName = "reqlog"
var (
ErrRequestNotFound = errors.New("reqlog: request not found")
)
var ErrRequestNotFound = errors.New("reqlog: request not found")
type Request struct {
ID int64
@ -74,12 +72,13 @@ func NewService(cfg Config) *Service {
cfg.ProjectService.OnProjectOpen(func(_ string) error {
err := svc.loadSettings()
if err == proj.ErrNoSettings {
if errors.Is(err, proj.ErrNoSettings) {
return nil
}
if err != nil {
return fmt.Errorf("reqlog: could not load settings: %v", err)
return fmt.Errorf("reqlog: could not load settings: %w", err)
}
return nil
})
cfg.ProjectService.OnProjectClose(func(_ string) error {
@ -126,12 +125,13 @@ func (svc *Service) addResponse(
if res.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("reqlog: could not create gzip reader: %v", err)
return nil, fmt.Errorf("reqlog: could not create gzip reader: %w", err)
}
defer gzipReader.Close()
body, err = ioutil.ReadAll(gzipReader)
if err != nil {
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %v", err)
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %w", err)
}
}
@ -141,18 +141,23 @@ func (svc *Service) addResponse(
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
return func(req *http.Request) {
now := time.Now()
next(req)
clone := req.Clone(req.Context())
var body []byte
if req.Body != nil {
// TODO: Use io.LimitReader.
var err error
body, err = ioutil.ReadAll(req.Body)
if err != nil {
log.Printf("[ERROR] Could not read request body for logging: %v", err)
return
}
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
}
@ -161,19 +166,21 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
return
}
reqLog, err := svc.addRequest(req.Context(), *clone, body, now)
if err == proj.ErrNoProject {
if errors.Is(err, proj.ErrNoProject) {
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
*req = *req.WithContext(ctx)
return
}
if err != nil {
} else if err != nil {
log.Printf("[ERROR] Could not store request log: %v", err)
return
}
ctx := context.WithValue(req.Context(), proxy.ReqIDKey, reqLog.ID)
*req = *req.WithContext(ctx)
}
@ -182,6 +189,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
return func(res *http.Response) error {
now := time.Now()
if err := next(res); err != nil {
return err
}
@ -200,8 +208,9 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
// TODO: Use io.LimitReader.
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("reqlog: could not read response body: %v", err)
return fmt.Errorf("reqlog: could not read response body: %w", err)
}
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
go func() {
@ -220,6 +229,7 @@ func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error {
OnlyInScope bool
RawSearchExpr string
}
if err := json.Unmarshal(b, &dto); err != nil {
return err
}
@ -234,6 +244,7 @@ func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
filter.SearchExpr = expr
}

View file

@ -3,6 +3,7 @@ package scope
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"regexp"
@ -38,12 +39,13 @@ func New(repo Repository, projService *proj.Service) *Scope {
projService.OnProjectOpen(func(_ string) error {
err := s.load(context.Background())
if err == proj.ErrNoSettings {
if errors.Is(err, proj.ErrNoSettings) {
return nil
}
if err != nil {
return fmt.Errorf("scope: could not load scope: %v", err)
return fmt.Errorf("scope: could not load scope: %w", err)
}
return nil
})
projService.OnProjectClose(func(_ string) error {
@ -57,6 +59,7 @@ func New(repo Repository, projService *proj.Service) *Scope {
func (s *Scope) Rules() []Rule {
s.mu.RLock()
defer s.mu.RUnlock()
return s.rules
}
@ -65,12 +68,12 @@ func (s *Scope) load(ctx context.Context) error {
defer s.mu.Unlock()
var rules []Rule
err := s.repo.FindSettingsByModule(ctx, moduleName, &rules)
if err == proj.ErrNoSettings {
if errors.Is(err, proj.ErrNoSettings) {
return err
}
if err != nil {
return fmt.Errorf("scope: could not load scope settings: %v", err)
} else if err != nil {
return fmt.Errorf("scope: could not load scope settings: %w", err)
}
s.rules = rules
@ -89,7 +92,7 @@ func (s *Scope) SetRules(ctx context.Context, rules []Rule) error {
defer s.mu.Unlock()
if err := s.repo.UpsertSettings(ctx, moduleName, rules); err != nil {
return fmt.Errorf("scope: cannot set rules in repository: %v", err)
return fmt.Errorf("scope: cannot set rules in repository: %w", err)
}
s.rules = rules
@ -100,6 +103,7 @@ func (s *Scope) SetRules(ctx context.Context, rules []Rule) error {
func (s *Scope) Match(req *http.Request, body []byte) bool {
s.mu.RLock()
defer s.mu.RUnlock()
for _, rule := range s.rules {
if matches := rule.Match(req, body); matches {
return true
@ -118,11 +122,13 @@ func (r Rule) Match(req *http.Request, body []byte) bool {
for key, values := range req.Header {
var keyMatches, valueMatches bool
if r.Header.Key != nil {
if matches := r.Header.Key.MatchString(key); matches {
keyMatches = true
}
}
if r.Header.Value != nil {
for _, value := range values {
if matches := r.Header.Value.MatchString(value); matches {
@ -154,15 +160,17 @@ func (r Rule) Match(req *http.Request, body []byte) bool {
// MarshalJSON implements json.Marshaler.
func (r Rule) MarshalJSON() ([]byte, error) {
type headerDTO struct {
Key string
Value string
}
type ruleDTO struct {
URL string
Header headerDTO
Body string
}
type (
headerDTO struct {
Key string
Value string
}
ruleDTO struct {
URL string
Header headerDTO
Body string
}
)
dto := ruleDTO{
URL: regexpToString(r.URL),
@ -178,15 +186,17 @@ func (r Rule) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.
func (r *Rule) UnmarshalJSON(data []byte) error {
type headerDTO struct {
Key string
Value string
}
type ruleDTO struct {
URL string
Header headerDTO
Body string
}
type (
headerDTO struct {
Key string
Value string
}
ruleDTO struct {
URL string
Header headerDTO
Body string
}
)
var dto ruleDTO
if err := json.Unmarshal(data, &dto); err != nil {
@ -197,14 +207,17 @@ func (r *Rule) UnmarshalJSON(data []byte) error {
if err != nil {
return err
}
headerKey, err := stringToRegexp(dto.Header.Key)
if err != nil {
return err
}
headerValue, err := stringToRegexp(dto.Header.Value)
if err != nil {
return err
}
body, err := stringToRegexp(dto.Body)
if err != nil {
return err
@ -226,6 +239,7 @@ func regexpToString(r *regexp.Regexp) string {
if r == nil {
return ""
}
return r.String()
}
@ -233,5 +247,6 @@ func stringToRegexp(s string) (*regexp.Regexp, error) {
if s == "" {
return nil, nil
}
return regexp.Compile(s)
}

View file

@ -11,7 +11,6 @@ type PrefixExpression struct {
Right Expression
}
func (pe *PrefixExpression) expressionNode() {}
func (pe *PrefixExpression) String() string {
b := strings.Builder{}
b.WriteString("(")
@ -29,7 +28,6 @@ type InfixExpression struct {
Right Expression
}
func (ie *InfixExpression) expressionNode() {}
func (ie *InfixExpression) String() string {
b := strings.Builder{}
b.WriteString("(")
@ -47,7 +45,6 @@ type StringLiteral struct {
Value string
}
func (sl *StringLiteral) expressionNode() {}
func (sl *StringLiteral) String() string {
return sl.Value
}

View file

@ -17,21 +17,21 @@ const eof = 0
// Token types.
const (
// Flow
// Flow.
TokInvalid TokenType = iota
TokEOF
TokParenOpen
TokParenClose
// Literals
// Literals.
TokString
// Boolean operators
// Boolean operators.
TokOpNot
TokOpAnd
TokOpOr
// Comparison operators
// Comparison operators.
TokOpEq
TokOpNotEq
TokOpGt
@ -98,6 +98,7 @@ func (tt TokenType) String() string {
if typeString, ok := tokenTypeStrings[tt]; ok {
return typeString
}
return "<unknown>"
}
@ -113,6 +114,7 @@ func (l *Lexer) read() (r rune) {
l.width = 0
return eof
}
r, l.width = utf8.DecodeRuneInString(l.input[l.pos:])
l.pos += l.width
@ -124,6 +126,7 @@ func (l *Lexer) emit(tokenType TokenType) {
Type: tokenType,
Literal: l.input[l.start:l.pos],
}
l.start = l.pos
}
@ -159,6 +162,7 @@ func begin(l *Lexer) stateFn {
l.backup()
l.emit(TokOpEq)
}
return begin
case '!':
switch next := l.read(); next {
@ -169,6 +173,7 @@ func begin(l *Lexer) stateFn {
default:
return l.errorf("invalid rune %v", r)
}
return begin
case '<':
if next := l.read(); next == '=' {
@ -177,6 +182,7 @@ func begin(l *Lexer) stateFn {
l.backup()
l.emit(TokOpLt)
}
return begin
case '>':
if next := l.read(); next == '=' {
@ -185,6 +191,7 @@ func begin(l *Lexer) stateFn {
l.backup()
l.emit(TokOpGt)
}
return begin
case '(':
l.emit(TokParenOpen)
@ -231,15 +238,18 @@ func unquotedString(l *Lexer) stateFn {
case r == eof:
l.backup()
l.emitUnquotedString()
return begin
case unicode.IsSpace(r):
l.backup()
l.emitUnquotedString()
l.skip()
return begin
case isReserved(r):
l.backup()
l.emitUnquotedString()
return begin
}
}
@ -251,6 +261,7 @@ func (l *Lexer) emitUnquotedString() {
l.emit(tokType)
return
}
l.emit(TokString)
}
@ -260,5 +271,6 @@ func isReserved(r rune) bool {
return true
}
}
return false
}

View file

@ -3,6 +3,8 @@ package search
import "testing"
func TestNextToken(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string

View file

@ -18,8 +18,10 @@ const (
precGroup
)
type prefixParser func(*Parser) (Expression, error)
type infixParser func(*Parser, Expression) (Expression, error)
type (
prefixParser func(*Parser) (Expression, error)
infixParser func(*Parser, Expression) (Expression, error)
)
var (
prefixParsers = map[TokenType]prefixParser{}
@ -77,7 +79,6 @@ func NewParser(l *Lexer) *Parser {
p.nextToken()
return p
}
func ParseQuery(input string) (expr Expression, err error) {
@ -91,18 +92,20 @@ func ParseQuery(input string) (expr Expression, err error) {
for !p.curTokenIs(TokEOF) {
right, err := p.parseExpression(precLowest)
if err != nil {
return nil, fmt.Errorf("search: could not parse expression: %v", err)
}
if expr == nil {
switch {
case err != nil:
return nil, fmt.Errorf("search: could not parse expression: %w", err)
case expr == nil:
expr = right
} else {
default:
expr = &InfixExpression{
Operator: TokOpAnd,
Left: expr,
Right: right,
}
}
p.nextToken()
}
@ -122,18 +125,11 @@ func (p *Parser) peekTokenIs(t TokenType) bool {
return p.peek.Type == t
}
func (p *Parser) expectPeek(t TokenType) error {
if !p.peekTokenIs(t) {
return fmt.Errorf("expected next token to be %v, got %v", t, p.peek.Type)
}
p.nextToken()
return nil
}
func (p *Parser) curPrecedence() precedence {
if p, ok := tokenPrecedences[p.cur.Type]; ok {
return p
}
return precLowest
}
@ -141,6 +137,7 @@ func (p *Parser) peekPrecedence() precedence {
if p, ok := tokenPrecedences[p.peek.Type]; ok {
return p
}
return precLowest
}
@ -152,7 +149,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) {
expr, err := prefixParser(p)
if err != nil {
return nil, fmt.Errorf("could not parse expression prefix: %v", err)
return nil, fmt.Errorf("could not parse expression prefix: %w", err)
}
for !p.peekTokenIs(eof) && prec < p.peekPrecedence() {
@ -165,7 +162,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) {
expr, err = infixParser(p, expr)
if err != nil {
return nil, fmt.Errorf("could not parse infix expression: %v", err)
return nil, fmt.Errorf("could not parse infix expression: %w", err)
}
}
@ -181,8 +178,9 @@ func parsePrefixExpression(p *Parser) (Expression, error) {
right, err := p.parseExpression(precPrefix)
if err != nil {
return nil, fmt.Errorf("could not parse expression for right operand: %v", err)
return nil, fmt.Errorf("could not parse expression for right operand: %w", err)
}
expr.Right = right
return expr, nil
@ -199,8 +197,9 @@ func parseInfixExpression(p *Parser, left Expression) (Expression, error) {
right, err := p.parseExpression(prec)
if err != nil {
return nil, fmt.Errorf("could not parse expression for right operand: %v", err)
return nil, fmt.Errorf("could not parse expression for right operand: %w", err)
}
expr.Right = right
return expr, nil
@ -215,17 +214,19 @@ func parseGroupedExpression(p *Parser) (Expression, error) {
expr, err := p.parseExpression(precLowest)
if err != nil {
return nil, fmt.Errorf("could not parse grouped expression: %v", err)
return nil, fmt.Errorf("could not parse grouped expression: %w", err)
}
for p.nextToken(); !p.curTokenIs(TokParenClose); p.nextToken() {
if p.curTokenIs(TokEOF) {
return nil, fmt.Errorf("unexpected EOF: unmatched parentheses")
}
right, err := p.parseExpression(precLowest)
if err != nil {
return nil, fmt.Errorf("could not parse expression: %v", err)
return nil, fmt.Errorf("could not parse expression: %w", err)
}
expr = &InfixExpression{
Operator: TokOpAnd,
Left: expr,

View file

@ -7,6 +7,8 @@ import (
)
func TestParseQuery(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
@ -233,6 +235,8 @@ func TestParseQuery(t *testing.T) {
}
func assertError(t *testing.T, exp, got error) {
t.Helper()
switch {
case exp == nil && got != nil:
t.Fatalf("expected: nil, got: %v", got)