mirror of
https://github.com/dstotijn/hetty
synced 2024-11-10 06:04:19 +00:00
Add linter, fix linting issue
This commit is contained in:
parent
ad3dc0da70
commit
ca3a729c36
18 changed files with 442 additions and 231 deletions
55
.golangci.yml
Normal file
55
.golangci.yml
Normal file
|
@ -0,0 +1,55 @@
|
|||
linters:
|
||||
presets:
|
||||
- bugs
|
||||
- comment
|
||||
- error
|
||||
- format
|
||||
- import
|
||||
- metalinter
|
||||
- module
|
||||
- performance
|
||||
- sql
|
||||
- style
|
||||
- test
|
||||
- unused
|
||||
disable:
|
||||
- exhaustive
|
||||
- exhaustivestruct
|
||||
- gochecknoglobals
|
||||
- gochecknoinits
|
||||
- godox
|
||||
- goerr113
|
||||
- gomnd
|
||||
- interfacer
|
||||
- maligned
|
||||
- nlreturn
|
||||
- scopelint
|
||||
- testpackage
|
||||
- wrapcheck
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
ignore: "database/sql:Rollback"
|
||||
gci:
|
||||
local-prefixes: github.com/dstotijn/hetty
|
||||
godot:
|
||||
capital: true
|
||||
|
||||
issues:
|
||||
exclude-rules:
|
||||
- linters:
|
||||
- gosec
|
||||
# Ignore SHA1 usage.
|
||||
text: "G(401|505):"
|
||||
- linters:
|
||||
- wsl
|
||||
# Ignore cuddled defer statements.
|
||||
text: "only one cuddle assignment allowed before defer statement"
|
||||
- linters:
|
||||
- nlreturn
|
||||
# Ignore `break` without leading blank line.
|
||||
text: "break with no blank line before"
|
||||
|
||||
run:
|
||||
skip-files:
|
||||
- cmd/hetty/rice-box.go
|
|
@ -2,25 +2,27 @@ package main
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql/handler"
|
||||
"github.com/99designs/gqlgen/graphql/playground"
|
||||
rice "github.com/GeertJohan/go.rice"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/api"
|
||||
"github.com/dstotijn/hetty/pkg/db/sqlite"
|
||||
"github.com/dstotijn/hetty/pkg/proj"
|
||||
"github.com/dstotijn/hetty/pkg/proxy"
|
||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql/handler"
|
||||
"github.com/99designs/gqlgen/graphql/playground"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -32,8 +34,16 @@ var (
|
|||
)
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&caCertFile, "cert", "~/.hetty/hetty_cert.pem", "CA certificate filepath. Creates a new CA certificate if file doesn't exist")
|
||||
flag.StringVar(&caKeyFile, "key", "~/.hetty/hetty_key.pem", "CA private key filepath. Creates a new CA private key if file doesn't exist")
|
||||
if err := run(); err != nil {
|
||||
log.Fatalf("[ERROR]: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func run() error {
|
||||
flag.StringVar(&caCertFile, "cert", "~/.hetty/hetty_cert.pem",
|
||||
"CA certificate filepath. Creates a new CA certificate if file doesn't exist")
|
||||
flag.StringVar(&caKeyFile, "key", "~/.hetty/hetty_key.pem",
|
||||
"CA private key filepath. Creates a new CA private key if file doesn't exist")
|
||||
flag.StringVar(&projPath, "projects", "~/.hetty/projects", "Projects directory path")
|
||||
flag.StringVar(&addr, "addr", ":8080", "TCP address to listen on, in the form \"host:port\"")
|
||||
flag.StringVar(&adminPath, "adminPath", "", "File path to admin build")
|
||||
|
@ -42,32 +52,34 @@ func main() {
|
|||
// Expand `~` in filepaths.
|
||||
caCertFile, err := homedir.Expand(caCertFile)
|
||||
if err != nil {
|
||||
log.Fatalf("[FATAL] Could not parse CA certificate filepath: %v", err)
|
||||
return fmt.Errorf("could not parse CA certificate filepath: %w", err)
|
||||
}
|
||||
|
||||
caKeyFile, err := homedir.Expand(caKeyFile)
|
||||
if err != nil {
|
||||
log.Fatalf("[FATAL] Could not parse CA private key filepath: %v", err)
|
||||
return fmt.Errorf("could not parse CA private key filepath: %w", err)
|
||||
}
|
||||
|
||||
projPath, err := homedir.Expand(projPath)
|
||||
if err != nil {
|
||||
log.Fatalf("[FATAL] Could not parse projects filepath: %v", err)
|
||||
return fmt.Errorf("could not parse projects filepath: %w", err)
|
||||
}
|
||||
|
||||
// Load existing CA certificate and key from disk, or generate and write
|
||||
// to disk if no files exist yet.
|
||||
caCert, caKey, err := proxy.LoadOrCreateCA(caKeyFile, caCertFile)
|
||||
if err != nil {
|
||||
log.Fatalf("[FATAL] Could not create/load CA key pair: %v", err)
|
||||
return fmt.Errorf("could not create/load CA key pair: %w", err)
|
||||
}
|
||||
|
||||
db, err := sqlite.New(projPath)
|
||||
if err != nil {
|
||||
log.Fatalf("[FATAL] Could not initialize database client: %v", err)
|
||||
return fmt.Errorf("could not initialize database client: %w", err)
|
||||
}
|
||||
|
||||
projService, err := proj.NewService(db)
|
||||
if err != nil {
|
||||
log.Fatalf("[FATAL] Could not create new project service: %v", err)
|
||||
return fmt.Errorf("could not create new project service: %w", err)
|
||||
}
|
||||
defer projService.Close()
|
||||
|
||||
|
@ -81,19 +93,21 @@ func main() {
|
|||
|
||||
p, err := proxy.NewProxy(caCert, caKey)
|
||||
if err != nil {
|
||||
log.Fatalf("[FATAL] Could not create Proxy: %v", err)
|
||||
return fmt.Errorf("could not create proxy: %w", err)
|
||||
}
|
||||
|
||||
p.UseRequestModifier(reqLogService.RequestModifier)
|
||||
p.UseResponseModifier(reqLogService.ResponseModifier)
|
||||
|
||||
var adminHandler http.Handler
|
||||
|
||||
if adminPath == "" {
|
||||
// Used for embedding with `rice`.
|
||||
box, err := rice.FindBox("../../admin/dist")
|
||||
if err != nil {
|
||||
log.Fatalf("[FATAL] Could not find embedded admin resources: %v", err)
|
||||
return fmt.Errorf("could not find embedded admin resources: %w", err)
|
||||
}
|
||||
|
||||
adminHandler = http.FileServer(box.HTTPBox())
|
||||
} else {
|
||||
adminHandler = http.FileServer(http.Dir(adminPath))
|
||||
|
@ -109,11 +123,12 @@ func main() {
|
|||
|
||||
// GraphQL server.
|
||||
adminRouter.Path("/api/playground/").Handler(playground.Handler("GraphQL Playground", "/api/graphql/"))
|
||||
adminRouter.Path("/api/graphql/").Handler(handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{
|
||||
RequestLogService: reqLogService,
|
||||
ProjectService: projService,
|
||||
ScopeService: scope,
|
||||
}})))
|
||||
adminRouter.Path("/api/graphql/").Handler(
|
||||
handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{
|
||||
RequestLogService: reqLogService,
|
||||
ProjectService: projService,
|
||||
ScopeService: scope,
|
||||
}})))
|
||||
|
||||
// Admin interface.
|
||||
adminRouter.PathPrefix("").Handler(adminHandler)
|
||||
|
@ -128,8 +143,11 @@ func main() {
|
|||
}
|
||||
|
||||
log.Printf("[INFO] Running server on %v ...", addr)
|
||||
|
||||
err = s.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("[FATAL] HTTP server closed: %v", err)
|
||||
if err != nil && errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("http server closed unexpected: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,16 +4,18 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql"
|
||||
"github.com/vektah/gqlparser/v2/gqlerror"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/proj"
|
||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
"github.com/dstotijn/hetty/pkg/search"
|
||||
"github.com/vektah/gqlparser/v2/gqlerror"
|
||||
)
|
||||
|
||||
type Resolver struct {
|
||||
|
@ -22,20 +24,22 @@ type Resolver struct {
|
|||
ScopeService *scope.Scope
|
||||
}
|
||||
|
||||
type queryResolver struct{ *Resolver }
|
||||
type mutationResolver struct{ *Resolver }
|
||||
type (
|
||||
queryResolver struct{ *Resolver }
|
||||
mutationResolver struct{ *Resolver }
|
||||
)
|
||||
|
||||
func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }
|
||||
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
|
||||
|
||||
func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error) {
|
||||
reqs, err := r.RequestLogService.FindRequests(ctx)
|
||||
if err == proj.ErrNoProject {
|
||||
if errors.Is(err, proj.ErrNoProject) {
|
||||
return nil, noActiveProjectErr(ctx)
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("could not query repository for requests: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not query repository for requests: %v", err)
|
||||
}
|
||||
|
||||
logs := make([]HTTPRequestLog, len(reqs))
|
||||
|
||||
for i, req := range reqs {
|
||||
|
@ -43,6 +47,7 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs[i] = req
|
||||
}
|
||||
|
||||
|
@ -51,12 +56,12 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
|
|||
|
||||
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error) {
|
||||
log, err := r.RequestLogService.FindRequestLogByID(ctx, id)
|
||||
if err == reqlog.ErrRequestNotFound {
|
||||
if errors.Is(err, reqlog.ErrRequestNotFound) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("could not get request by ID: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get request by ID: %v", err)
|
||||
}
|
||||
|
||||
req, err := parseRequestLog(log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -89,6 +94,7 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
|
|||
|
||||
if req.Request.Header != nil {
|
||||
log.Headers = make([]HTTPHeader, 0)
|
||||
|
||||
for key, values := range req.Request.Header {
|
||||
for _, value := range values {
|
||||
log.Headers = append(log.Headers, HTTPHeader{
|
||||
|
@ -106,15 +112,19 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
|
|||
StatusCode: req.Response.Response.StatusCode,
|
||||
}
|
||||
statusReasonSubs := strings.SplitN(req.Response.Response.Status, " ", 2)
|
||||
|
||||
if len(statusReasonSubs) == 2 {
|
||||
log.Response.StatusReason = statusReasonSubs[1]
|
||||
}
|
||||
|
||||
if len(req.Response.Body) > 0 {
|
||||
resBody := string(req.Response.Body)
|
||||
log.Response.Body = &resBody
|
||||
}
|
||||
|
||||
if req.Response.Response.Header != nil {
|
||||
log.Response.Headers = make([]HTTPHeader, 0)
|
||||
|
||||
for key, values := range req.Response.Response.Header {
|
||||
for _, value := range values {
|
||||
log.Response.Headers = append(log.Response.Headers, HTTPHeader{
|
||||
|
@ -131,12 +141,12 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
|
|||
|
||||
func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Project, error) {
|
||||
p, err := r.ProjectService.Open(ctx, name)
|
||||
if err == proj.ErrInvalidName {
|
||||
if errors.Is(err, proj.ErrInvalidName) {
|
||||
return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.")
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("could not open project: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open project: %v", err)
|
||||
}
|
||||
|
||||
return &Project{
|
||||
Name: p.Name,
|
||||
IsActive: p.IsActive,
|
||||
|
@ -145,11 +155,10 @@ func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Proje
|
|||
|
||||
func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
|
||||
p, err := r.ProjectService.ActiveProject()
|
||||
if err == proj.ErrNoProject {
|
||||
if errors.Is(err, proj.ErrNoProject) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open project: %v", err)
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("could not open project: %w", err)
|
||||
}
|
||||
|
||||
return &Project{
|
||||
|
@ -161,7 +170,7 @@ func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
|
|||
func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
|
||||
p, err := r.ProjectService.Projects()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get projects: %v", err)
|
||||
return nil, fmt.Errorf("could not get projects: %w", err)
|
||||
}
|
||||
|
||||
projects := make([]Project, len(p))
|
||||
|
@ -184,21 +193,25 @@ func regexpToStringPtr(r *regexp.Regexp) *string {
|
|||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s := r.String()
|
||||
|
||||
return &s
|
||||
}
|
||||
|
||||
func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResult, error) {
|
||||
if err := r.ProjectService.Close(); err != nil {
|
||||
return nil, fmt.Errorf("could not close project: %v", err)
|
||||
return nil, fmt.Errorf("could not close project: %w", err)
|
||||
}
|
||||
|
||||
return &CloseProjectResult{true}, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*DeleteProjectResult, error) {
|
||||
if err := r.ProjectService.Delete(name); err != nil {
|
||||
return nil, fmt.Errorf("could not delete project: %v", err)
|
||||
return nil, fmt.Errorf("could not delete project: %w", err)
|
||||
}
|
||||
|
||||
return &DeleteProjectResult{
|
||||
Success: true,
|
||||
}, nil
|
||||
|
@ -206,33 +219,40 @@ func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*Del
|
|||
|
||||
func (r *mutationResolver) ClearHTTPRequestLog(ctx context.Context) (*ClearHTTPRequestLogResult, error) {
|
||||
if err := r.RequestLogService.ClearRequests(ctx); err != nil {
|
||||
return nil, fmt.Errorf("could not clear request log: %v", err)
|
||||
return nil, fmt.Errorf("could not clear request log: %w", err)
|
||||
}
|
||||
|
||||
return &ClearHTTPRequestLogResult{true}, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput) ([]ScopeRule, error) {
|
||||
rules := make([]scope.Rule, len(input))
|
||||
|
||||
for i, rule := range input {
|
||||
u, err := stringPtrToRegexp(rule.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL in scope rule: %v", err)
|
||||
return nil, fmt.Errorf("invalid URL in scope rule: %w", err)
|
||||
}
|
||||
|
||||
var headerKey, headerValue *regexp.Regexp
|
||||
|
||||
if rule.Header != nil {
|
||||
headerKey, err = stringPtrToRegexp(rule.Header.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid header key in scope rule: %v", err)
|
||||
return nil, fmt.Errorf("invalid header key in scope rule: %w", err)
|
||||
}
|
||||
|
||||
headerValue, err = stringPtrToRegexp(rule.Header.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid header value in scope rule: %v", err)
|
||||
return nil, fmt.Errorf("invalid header value in scope rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
body, err := stringPtrToRegexp(rule.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid body in scope rule: %v", err)
|
||||
return nil, fmt.Errorf("invalid body in scope rule: %w", err)
|
||||
}
|
||||
|
||||
rules[i] = scope.Rule{
|
||||
URL: u,
|
||||
Header: scope.Header{
|
||||
|
@ -244,7 +264,7 @@ func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput)
|
|||
}
|
||||
|
||||
if err := r.ScopeService.SetRules(ctx, rules); err != nil {
|
||||
return nil, fmt.Errorf("could not set scope: %v", err)
|
||||
return nil, fmt.Errorf("could not set scope: %w", err)
|
||||
}
|
||||
|
||||
return scopeToScopeRules(rules), nil
|
||||
|
@ -260,14 +280,14 @@ func (r *mutationResolver) SetHTTPRequestLogFilter(
|
|||
) (*HTTPRequestLogFilter, error) {
|
||||
filter, err := findRequestsFilterFromInput(input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse request log filter: %v", err)
|
||||
return nil, fmt.Errorf("could not parse request log filter: %w", err)
|
||||
}
|
||||
|
||||
err = r.RequestLogService.SetRequestLogFilter(ctx, filter)
|
||||
if err == proj.ErrNoProject {
|
||||
if errors.Is(err, proj.ErrNoProject) {
|
||||
return nil, noActiveProjectErr(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not set request log filter: %v", err)
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("could not set request log filter: %w", err)
|
||||
}
|
||||
|
||||
return findReqFilterToHTTPReqLogFilter(filter), nil
|
||||
|
@ -277,6 +297,7 @@ func stringPtrToRegexp(s *string) (*regexp.Regexp, error) {
|
|||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return regexp.Compile(*s)
|
||||
}
|
||||
|
||||
|
@ -290,8 +311,10 @@ func scopeToScopeRules(rules []scope.Rule) []ScopeRule {
|
|||
Value: regexpToStringPtr(rule.Header.Value),
|
||||
}
|
||||
}
|
||||
|
||||
scopeRules[i].Body = regexpToStringPtr(rule.Body)
|
||||
}
|
||||
|
||||
return scopeRules
|
||||
}
|
||||
|
||||
|
@ -299,14 +322,17 @@ func findRequestsFilterFromInput(input *HTTPRequestLogFilterInput) (filter reqlo
|
|||
if input == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if input.OnlyInScope != nil {
|
||||
filter.OnlyInScope = *input.OnlyInScope
|
||||
}
|
||||
|
||||
if input.SearchExpression != nil && *input.SearchExpression != "" {
|
||||
expr, err := search.ParseQuery(*input.SearchExpression)
|
||||
if err != nil {
|
||||
return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %v", err)
|
||||
return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %w", err)
|
||||
}
|
||||
|
||||
filter.RawSearchExpr = *input.SearchExpression
|
||||
filter.SearchExpr = expr
|
||||
}
|
||||
|
@ -319,6 +345,7 @@ func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *H
|
|||
if findReqFilter == empty {
|
||||
return nil
|
||||
}
|
||||
|
||||
httpReqLogFilter := &HTTPRequestLogFilter{
|
||||
OnlyInScope: findReqFilter.OnlyInScope,
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ func (u *reqURL) Scan(value interface{}) error {
|
|||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite: could not parse URL: %v", err)
|
||||
return fmt.Errorf("sqlite: could not parse URL: %w", err)
|
||||
}
|
||||
|
||||
*u = reqURL(*parsed)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"sort"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/search"
|
||||
)
|
||||
|
||||
|
@ -57,20 +58,24 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
right, err := parseSearchExpr(expr.Right)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sq.And{left, right}, nil
|
||||
case search.TokOpOr:
|
||||
left, err := parseSearchExpr(expr.Left)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
right, err := parseSearchExpr(expr.Right)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sq.Or{left, right}, nil
|
||||
}
|
||||
|
||||
|
@ -78,6 +83,7 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
|
|||
if !ok {
|
||||
return nil, errors.New("left operand must be a string literal")
|
||||
}
|
||||
|
||||
right, ok := expr.Right.(*search.StringLiteral)
|
||||
if !ok {
|
||||
return nil, errors.New("right operand must be a string literal")
|
||||
|
@ -113,14 +119,17 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
|
|||
func parseStringLiteral(strLiteral *search.StringLiteral) (sq.Sqlizer, error) {
|
||||
// Sorting is not necessary, but makes it easier to do assertions in tests.
|
||||
sortedKeys := make([]string, 0, len(stringLiteralMap))
|
||||
|
||||
for _, v := range stringLiteralMap {
|
||||
sortedKeys = append(sortedKeys, v)
|
||||
}
|
||||
|
||||
sort.Strings(sortedKeys)
|
||||
|
||||
or := make(sq.Or, len(stringLiteralMap))
|
||||
for i, value := range sortedKeys {
|
||||
or[i] = sq.Like{value: "%" + strLiteral.Value + "%"}
|
||||
}
|
||||
|
||||
return or, nil
|
||||
}
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
)
|
||||
|
||||
func TestParseSearchExpr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
searchExpr search.Expression
|
||||
|
@ -206,6 +208,8 @@ func TestParseSearchExpr(t *testing.T) {
|
|||
}
|
||||
|
||||
func assertError(t *testing.T, exp, got error) {
|
||||
t.Helper()
|
||||
|
||||
switch {
|
||||
case exp == nil && got != nil:
|
||||
t.Fatalf("expected: nil, got: %v", got)
|
||||
|
|
|
@ -15,14 +15,14 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/proj"
|
||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql"
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/proj"
|
||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
)
|
||||
|
||||
var regexpFn = func(pattern string, value interface{}) (bool, error) {
|
||||
|
@ -55,10 +55,7 @@ type httpRequestLogsQuery struct {
|
|||
func init() {
|
||||
sql.Register("sqlite3_with_regexp", &sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
if err := conn.RegisterFunc("regexp", regexpFn, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return conn.RegisterFunc("regexp", regexpFn, false)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@ -66,9 +63,10 @@ func init() {
|
|||
func New(dbPath string) (*Client, error) {
|
||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(dbPath, 0755); err != nil {
|
||||
return nil, fmt.Errorf("proj: could not create project directory: %v", err)
|
||||
return nil, fmt.Errorf("proj: could not create project directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
dbPath: dbPath,
|
||||
}, nil
|
||||
|
@ -85,17 +83,18 @@ func (c *Client) OpenProject(name string) error {
|
|||
|
||||
dbPath := filepath.Join(c.dbPath, name+".db")
|
||||
dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode())
|
||||
|
||||
db, err := sqlx.Open("sqlite3_with_regexp", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite: could not open database: %v", err)
|
||||
return fmt.Errorf("sqlite: could not open database: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return fmt.Errorf("sqlite: could not ping database: %v", err)
|
||||
return fmt.Errorf("sqlite: could not ping database: %w", err)
|
||||
}
|
||||
|
||||
if err := prepareSchema(db); err != nil {
|
||||
return fmt.Errorf("sqlite: could not prepare schema: %v", err)
|
||||
return fmt.Errorf("sqlite: could not prepare schema: %w", err)
|
||||
}
|
||||
|
||||
c.db = db
|
||||
|
@ -107,10 +106,11 @@ func (c *Client) OpenProject(name string) error {
|
|||
func (c *Client) Projects() ([]proj.Project, error) {
|
||||
files, err := ioutil.ReadDir(c.dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not read projects directory: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not read projects directory: %w", err)
|
||||
}
|
||||
|
||||
projects := make([]proj.Project, len(files))
|
||||
|
||||
for i, file := range files {
|
||||
projName := strings.TrimSuffix(file.Name(), ".db")
|
||||
projects[i] = proj.Project{
|
||||
|
@ -132,7 +132,7 @@ func prepareSchema(db *sqlx.DB) error {
|
|||
timestamp DATETIME
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create http_requests table: %v", err)
|
||||
return fmt.Errorf("could not create http_requests table: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_responses (
|
||||
|
@ -145,7 +145,7 @@ func prepareSchema(db *sqlx.DB) error {
|
|||
timestamp DATETIME
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create http_responses table: %v", err)
|
||||
return fmt.Errorf("could not create http_responses table: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_headers (
|
||||
|
@ -156,7 +156,7 @@ func prepareSchema(db *sqlx.DB) error {
|
|||
value TEXT
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create http_headers table: %v", err)
|
||||
return fmt.Errorf("could not create http_headers table: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings (
|
||||
|
@ -164,7 +164,7 @@ func prepareSchema(db *sqlx.DB) error {
|
|||
settings TEXT
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create settings table: %v", err)
|
||||
return fmt.Errorf("could not create settings table: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -175,8 +175,9 @@ func (c *Client) Close() error {
|
|||
if c.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.db.Close(); err != nil {
|
||||
return fmt.Errorf("sqlite: could not close database: %v", err)
|
||||
return fmt.Errorf("sqlite: could not close database: %w", err)
|
||||
}
|
||||
|
||||
c.db = nil
|
||||
|
@ -187,7 +188,7 @@ func (c *Client) Close() error {
|
|||
|
||||
func (c *Client) DeleteProject(name string) error {
|
||||
if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil {
|
||||
return fmt.Errorf("sqlite: could not remove database file: %v", err)
|
||||
return fmt.Errorf("sqlite: could not remove database file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -219,10 +220,12 @@ func (c *Client) ClearRequestLogs(ctx context.Context) error {
|
|||
if c.db == nil {
|
||||
return proj.ErrNoProject
|
||||
}
|
||||
|
||||
_, err := c.db.Exec("DELETE FROM http_requests")
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite: could not delete requests: %v", err)
|
||||
return fmt.Errorf("sqlite: could not delete requests: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -247,11 +250,13 @@ func (c *Client) FindRequestLogs(
|
|||
|
||||
if filter.OnlyInScope && scope != nil {
|
||||
var ruleExpr []sq.Sqlizer
|
||||
|
||||
for _, rule := range scope.Rules() {
|
||||
if rule.URL != nil {
|
||||
ruleExpr = append(ruleExpr, sq.Expr("regexp(?, req.url)", rule.URL.String()))
|
||||
}
|
||||
}
|
||||
|
||||
if len(ruleExpr) > 0 {
|
||||
reqQuery = reqQuery.Where(sq.Or(ruleExpr))
|
||||
}
|
||||
|
@ -260,37 +265,41 @@ func (c *Client) FindRequestLogs(
|
|||
if filter.SearchExpr != nil {
|
||||
sqlizer, err := parseSearchExpr(filter.SearchExpr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not parse search expression: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not parse search expression: %w", err)
|
||||
}
|
||||
|
||||
reqQuery = reqQuery.Where(sqlizer)
|
||||
}
|
||||
|
||||
sql, args, err := reqQuery.ToSql()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not parse query: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not parse query: %w", err)
|
||||
}
|
||||
|
||||
rows, err := c.db.QueryxContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not execute query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var dto httpRequest
|
||||
|
||||
err = rows.StructScan(&dto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not scan row: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not scan row: %w", err)
|
||||
}
|
||||
|
||||
reqLogs = append(reqLogs, dto.toRequestLog())
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err)
|
||||
}
|
||||
rows.Close()
|
||||
defer rows.Close()
|
||||
|
||||
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not query headers: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not query headers: %w", err)
|
||||
}
|
||||
|
||||
return reqLogs, nil
|
||||
|
@ -300,35 +309,38 @@ func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Reque
|
|||
if c.db == nil {
|
||||
return reqlog.Request{}, proj.ErrNoProject
|
||||
}
|
||||
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
|
||||
|
||||
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
|
||||
reqQuery := sq.
|
||||
Select(httpReqLogsQuery.requestCols...).
|
||||
From("http_requests req").
|
||||
Where("req.id = ?")
|
||||
|
||||
if httpReqLogsQuery.joinResponse {
|
||||
reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id")
|
||||
}
|
||||
|
||||
reqSQL, _, err := reqQuery.ToSql()
|
||||
if err != nil {
|
||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %v", err)
|
||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %w", err)
|
||||
}
|
||||
|
||||
row := c.db.QueryRowxContext(ctx, reqSQL, id)
|
||||
var dto httpRequest
|
||||
err = row.StructScan(&dto)
|
||||
if err == sql.ErrNoRows {
|
||||
return reqlog.Request{}, reqlog.ErrRequestNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %v", err)
|
||||
}
|
||||
reqLog := dto.toRequestLog()
|
||||
|
||||
var dto httpRequest
|
||||
|
||||
err = row.StructScan(&dto)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return reqlog.Request{}, reqlog.ErrRequestNotFound
|
||||
} else if err != nil {
|
||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %w", err)
|
||||
}
|
||||
|
||||
reqLog := dto.toRequestLog()
|
||||
reqLogs := []reqlog.Request{reqLog}
|
||||
|
||||
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
|
||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %v", err)
|
||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %w", err)
|
||||
}
|
||||
|
||||
return reqLogs[0], nil
|
||||
|
@ -352,8 +364,9 @@ func (c *Client) AddRequestLog(
|
|||
|
||||
tx, err := c.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not start transaction: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not start transaction: %w", err)
|
||||
}
|
||||
|
||||
defer tx.Rollback()
|
||||
|
||||
reqStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_requests (
|
||||
|
@ -364,7 +377,7 @@ func (c *Client) AddRequestLog(
|
|||
timestamp
|
||||
) VALUES (?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
|
||||
}
|
||||
defer reqStmt.Close()
|
||||
|
||||
|
@ -376,13 +389,14 @@ func (c *Client) AddRequestLog(
|
|||
reqLog.Timestamp,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not execute statement: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not execute statement: %w", err)
|
||||
}
|
||||
|
||||
reqID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err)
|
||||
}
|
||||
|
||||
reqLog.ID = reqID
|
||||
|
||||
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
|
||||
|
@ -391,17 +405,17 @@ func (c *Client) AddRequestLog(
|
|||
value
|
||||
) VALUES (?, ?, ?)`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
|
||||
}
|
||||
defer headerStmt.Close()
|
||||
|
||||
err = insertHeaders(ctx, headerStmt, reqID, reqLog.Request.Header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return reqLog, nil
|
||||
|
@ -424,9 +438,10 @@ func (c *Client) AddResponseLog(
|
|||
Body: body,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not start transaction: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not start transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
|
@ -439,7 +454,7 @@ func (c *Client) AddResponseLog(
|
|||
timestamp
|
||||
) VALUES (?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
|
||||
}
|
||||
defer resStmt.Close()
|
||||
|
||||
|
@ -457,13 +472,14 @@ func (c *Client) AddResponseLog(
|
|||
resLog.Timestamp,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not execute statement: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not execute statement: %w", err)
|
||||
}
|
||||
|
||||
resID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err)
|
||||
}
|
||||
|
||||
resLog.ID = resID
|
||||
|
||||
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
|
||||
|
@ -472,17 +488,17 @@ func (c *Client) AddResponseLog(
|
|||
value
|
||||
) VALUES (?, ?, ?)`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
|
||||
}
|
||||
defer headerStmt.Close()
|
||||
|
||||
err = insertHeaders(ctx, headerStmt, resID, resLog.Response.Header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return resLog, nil
|
||||
|
@ -496,14 +512,14 @@ func (c *Client) UpsertSettings(ctx context.Context, module string, settings int
|
|||
|
||||
jsonSettings, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite: could not encode settings as JSON: %v", err)
|
||||
return fmt.Errorf("sqlite: could not encode settings as JSON: %w", err)
|
||||
}
|
||||
|
||||
_, err = c.db.ExecContext(ctx,
|
||||
`INSERT INTO settings (module, settings) VALUES (?, ?)
|
||||
ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite: could not insert scope settings: %v", err)
|
||||
return fmt.Errorf("sqlite: could not insert scope settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -515,17 +531,18 @@ func (c *Client) FindSettingsByModule(ctx context.Context, module string, settin
|
|||
}
|
||||
|
||||
var jsonSettings []byte
|
||||
|
||||
row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module)
|
||||
|
||||
err := row.Scan(&jsonSettings)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return proj.ErrNoSettings
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("sqlite: could not scan row: %v", err)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("sqlite: could not scan row: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonSettings, &settings); err != nil {
|
||||
return fmt.Errorf("sqlite: could not decode settings from JSON: %v", err)
|
||||
return fmt.Errorf("sqlite: could not decode settings from JSON: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -535,42 +552,46 @@ func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.H
|
|||
for key, values := range headers {
|
||||
for _, value := range values {
|
||||
if _, err := stmt.ExecContext(ctx, id, key, value); err != nil {
|
||||
return fmt.Errorf("could not execute statement: %v", err)
|
||||
return fmt.Errorf("could not execute statement: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func findHeaders(ctx context.Context, stmt *sql.Stmt, id int64) (http.Header, error) {
|
||||
headers := make(http.Header)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not execute query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var key, value string
|
||||
err := rows.Scan(
|
||||
&key,
|
||||
&value,
|
||||
)
|
||||
|
||||
err := rows.Scan(&key, &value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not scan row: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not scan row: %w", err)
|
||||
}
|
||||
|
||||
headers.Add(key, value)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err)
|
||||
return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err)
|
||||
}
|
||||
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
|
||||
var joinResponse bool
|
||||
var reqHeaderCols, resHeaderCols []string
|
||||
var (
|
||||
joinResponse bool
|
||||
reqHeaderCols, resHeaderCols []string
|
||||
)
|
||||
|
||||
opCtx := graphql.GetOperationContext(ctx)
|
||||
reqFields := graphql.CollectFieldsCtx(ctx, nil)
|
||||
|
@ -580,6 +601,7 @@ func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
|
|||
if col, ok := reqFieldToColumnMap[reqField.Name]; ok {
|
||||
reqCols = append(reqCols, "req."+col)
|
||||
}
|
||||
|
||||
if reqField.Name == "headers" {
|
||||
headerFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
|
||||
for _, headerField := range headerFields {
|
||||
|
@ -588,19 +610,23 @@ func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if reqField.Name == "response" {
|
||||
joinResponse = true
|
||||
resFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
|
||||
|
||||
for _, resField := range resFields {
|
||||
if resField.Name == "headers" {
|
||||
reqCols = append(reqCols, "res.id AS res_id")
|
||||
headerFields := graphql.CollectFields(opCtx, resField.Selections, nil)
|
||||
|
||||
for _, headerField := range headerFields {
|
||||
if col, ok := headerFieldToColumnMap[headerField.Name]; ok {
|
||||
resHeaderCols = append(resHeaderCols, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if col, ok := resFieldToColumnMap[resField.Name]; ok {
|
||||
reqCols = append(reqCols, "res."+col)
|
||||
}
|
||||
|
@ -627,18 +653,21 @@ func (c *Client) queryHeaders(
|
|||
From("http_headers").Where("req_id = ?").
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse request headers query: %v", err)
|
||||
return fmt.Errorf("could not parse request headers query: %w", err)
|
||||
}
|
||||
|
||||
reqHeadersStmt, err := c.db.PrepareContext(ctx, reqHeadersQuery)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not prepare statement: %v", err)
|
||||
return fmt.Errorf("could not prepare statement: %w", err)
|
||||
}
|
||||
defer reqHeadersStmt.Close()
|
||||
|
||||
for i := range reqLogs {
|
||||
headers, err := findHeaders(ctx, reqHeadersStmt, reqLogs[i].ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not query request headers: %v", err)
|
||||
return fmt.Errorf("could not query request headers: %w", err)
|
||||
}
|
||||
|
||||
reqLogs[i].Request.Header = headers
|
||||
}
|
||||
}
|
||||
|
@ -649,21 +678,25 @@ func (c *Client) queryHeaders(
|
|||
From("http_headers").Where("res_id = ?").
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse response headers query: %v", err)
|
||||
return fmt.Errorf("could not parse response headers query: %w", err)
|
||||
}
|
||||
|
||||
resHeadersStmt, err := c.db.PrepareContext(ctx, resHeadersQuery)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not prepare statement: %v", err)
|
||||
return fmt.Errorf("could not prepare statement: %w", err)
|
||||
}
|
||||
defer resHeadersStmt.Close()
|
||||
|
||||
for i := range reqLogs {
|
||||
if reqLogs[i].Response == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
headers, err := findHeaders(ctx, resHeadersStmt, reqLogs[i].Response.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not query response headers: %v", err)
|
||||
return fmt.Errorf("could not query response headers: %w", err)
|
||||
}
|
||||
|
||||
reqLogs[i].Response.Response.Header = headers
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,8 +9,10 @@ import (
|
|||
"sync"
|
||||
)
|
||||
|
||||
type OnProjectOpenFn func(name string) error
|
||||
type OnProjectCloseFn func(name string) error
|
||||
type (
|
||||
OnProjectOpenFn func(name string) error
|
||||
OnProjectCloseFn func(name string) error
|
||||
)
|
||||
|
||||
// Service is used for managing projects.
|
||||
type Service struct {
|
||||
|
@ -47,8 +49,9 @@ func (svc *Service) Close() error {
|
|||
defer svc.mu.Unlock()
|
||||
|
||||
closedProject := svc.activeProject
|
||||
|
||||
if err := svc.repo.Close(); err != nil {
|
||||
return fmt.Errorf("proj: could not close project: %v", err)
|
||||
return fmt.Errorf("proj: could not close project: %w", err)
|
||||
}
|
||||
|
||||
svc.activeProject = ""
|
||||
|
@ -63,12 +66,13 @@ func (svc *Service) Delete(name string) error {
|
|||
if name == "" {
|
||||
return errors.New("proj: name cannot be empty")
|
||||
}
|
||||
|
||||
if svc.activeProject == name {
|
||||
return fmt.Errorf("proj: project (%v) is active", name)
|
||||
}
|
||||
|
||||
if err := svc.repo.DeleteProject(name); err != nil {
|
||||
return fmt.Errorf("proj: could not delete project: %v", err)
|
||||
return fmt.Errorf("proj: could not delete project: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -85,11 +89,11 @@ func (svc *Service) Open(ctx context.Context, name string) (Project, error) {
|
|||
defer svc.mu.Unlock()
|
||||
|
||||
if err := svc.repo.Close(); err != nil {
|
||||
return Project{}, fmt.Errorf("proj: could not close previously open database: %v", err)
|
||||
return Project{}, fmt.Errorf("proj: could not close previously open database: %w", err)
|
||||
}
|
||||
|
||||
if err := svc.repo.OpenProject(name); err != nil {
|
||||
return Project{}, fmt.Errorf("proj: could not open database: %v", err)
|
||||
return Project{}, fmt.Errorf("proj: could not open database: %w", err)
|
||||
}
|
||||
|
||||
svc.activeProject = name
|
||||
|
@ -115,7 +119,7 @@ func (svc *Service) ActiveProject() (Project, error) {
|
|||
func (svc *Service) Projects() ([]Project, error) {
|
||||
projects, err := svc.repo.Projects()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proj: could not get projects: %v", err)
|
||||
return nil, fmt.Errorf("proj: could not get projects: %w", err)
|
||||
}
|
||||
|
||||
return projects, nil
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
var MaxSerialNumber = big.NewInt(0).SetBytes(bytes.Repeat([]byte{255}, 20))
|
||||
|
||||
// CertConfig is a set of configuration values that are used to build TLS configs
|
||||
// capable of MITM
|
||||
// capable of MITM.
|
||||
type CertConfig struct {
|
||||
ca *x509.Certificate
|
||||
caPriv crypto.PrivateKey
|
||||
|
@ -40,6 +40,7 @@ func NewCertConfig(ca *x509.Certificate, caPrivKey crypto.PrivateKey) (*CertConf
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pub := priv.Public()
|
||||
|
||||
// Subject Key Identifier support for end entity certificate.
|
||||
|
@ -48,6 +49,7 @@ func NewCertConfig(ca *x509.Certificate, caPrivKey crypto.PrivateKey) (*CertConf
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h := sha1.New()
|
||||
h.Write(pkixPubKey)
|
||||
keyID := h.Sum(nil)
|
||||
|
@ -67,58 +69,69 @@ func LoadOrCreateCA(caKeyFile, caCertFile string) (*x509.Certificate, *rsa.Priva
|
|||
if err == nil {
|
||||
caCert, err := x509.ParseCertificate(tlsCA.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not parse CA: %v", err)
|
||||
return nil, nil, fmt.Errorf("proxy: could not parse CA: %w", err)
|
||||
}
|
||||
|
||||
caKey, ok := tlsCA.PrivateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("proxy: private key is not RSA")
|
||||
}
|
||||
|
||||
return caCert, caKey, nil
|
||||
}
|
||||
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, nil, fmt.Errorf("proxy: could not load CA key pair: %v", err)
|
||||
return nil, nil, fmt.Errorf("proxy: could not load CA key pair: %w", err)
|
||||
}
|
||||
|
||||
// Create directories for files if they don't exist yet.
|
||||
keyDir, _ := filepath.Split(caKeyFile)
|
||||
if keyDir != "" {
|
||||
if _, err := os.Stat(keyDir); os.IsNotExist(err) {
|
||||
os.MkdirAll(keyDir, 0755)
|
||||
if err := os.MkdirAll(keyDir, 0755); err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not create directory for CA key: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keyDir, _ = filepath.Split(caCertFile)
|
||||
if keyDir != "" {
|
||||
if _, err := os.Stat("keyDir"); os.IsNotExist(err) {
|
||||
os.MkdirAll(keyDir, 0755)
|
||||
if err := os.MkdirAll(keyDir, 0755); err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not create directory for CA cert: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create new CA keypair.
|
||||
caCert, caKey, err := NewCA("Hetty", "Hetty CA", time.Duration(365*24*time.Hour))
|
||||
caCert, caKey, err := NewCA("Hetty", "Hetty CA", 365*24*time.Hour)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not generate new CA keypair: %v", err)
|
||||
return nil, nil, fmt.Errorf("proxy: could not generate new CA keypair: %w", err)
|
||||
}
|
||||
|
||||
// Open CA certificate and key files for writing.
|
||||
certOut, err := os.Create(caCertFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not open cert file for writing: %v", err)
|
||||
return nil, nil, fmt.Errorf("proxy: could not open cert file for writing: %w", err)
|
||||
}
|
||||
|
||||
keyOut, err := os.OpenFile(caKeyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not open key file for writing: %v", err)
|
||||
return nil, nil, fmt.Errorf("proxy: could not open key file for writing: %w", err)
|
||||
}
|
||||
|
||||
// Write PEM blocks to CA certificate and key files.
|
||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw}); err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not write CA certificate to disk: %v", err)
|
||||
return nil, nil, fmt.Errorf("proxy: could not write CA certificate to disk: %w", err)
|
||||
}
|
||||
|
||||
privBytes, err := x509.MarshalPKCS8PrivateKey(caKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not convert private key to DER format: %v", err)
|
||||
return nil, nil, fmt.Errorf("proxy: could not convert private key to DER format: %w", err)
|
||||
}
|
||||
|
||||
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
|
||||
return nil, nil, fmt.Errorf("proxy: could not write CA key to disk: %v", err)
|
||||
return nil, nil, fmt.Errorf("proxy: could not write CA key to disk: %w", err)
|
||||
}
|
||||
|
||||
return caCert, caKey, nil
|
||||
|
@ -130,6 +143,7 @@ func NewCA(name, organization string, validity time.Duration) (*x509.Certificate
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pub := priv.Public()
|
||||
|
||||
// Subject Key Identifier support for end entity certificate.
|
||||
|
@ -138,6 +152,7 @@ func NewCA(name, organization string, validity time.Duration) (*x509.Certificate
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
h := sha1.New()
|
||||
h.Write(pkixpub)
|
||||
keyID := h.Sum(nil)
|
||||
|
@ -187,8 +202,10 @@ func (c *CertConfig) TLSConfig() *tls.Config {
|
|||
if clientHello.ServerName == "" {
|
||||
return nil, errors.New("missing server name (SNI)")
|
||||
}
|
||||
|
||||
return c.cert(clientHello.ServerName)
|
||||
},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,13 +5,12 @@ import (
|
|||
"crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"github.com/dstotijn/hetty/pkg/scope"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
@ -27,8 +26,6 @@ type Proxy struct {
|
|||
// TODO: Add mutex for modifier funcs.
|
||||
reqModifiers []RequestModifyMiddleware
|
||||
resModifiers []ResponseModifyMiddleware
|
||||
|
||||
scope *scope.Scope
|
||||
}
|
||||
|
||||
// NewProxy returns a new Proxy.
|
||||
|
@ -55,7 +52,7 @@ func NewProxy(ca *x509.Certificate, key crypto.PrivateKey) (*Proxy, error) {
|
|||
|
||||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodConnect {
|
||||
p.handleConnect(w, r)
|
||||
p.handleConnect(w)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -103,11 +100,12 @@ func (p *Proxy) modifyResponse(res *http.Response) error {
|
|||
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
|
||||
// During the TLS handshake with the client, we use the proxy's CA config to
|
||||
// create a certificate on-the-fly.
|
||||
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
func (p *Proxy) handleConnect(w http.ResponseWriter) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
log.Printf("[ERROR] handleConnect: ResponseWriter is not a http.Hijacker (type: %T)", w)
|
||||
writeError(w, r, http.StatusServiceUnavailable)
|
||||
writeError(w, http.StatusServiceUnavailable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -116,7 +114,8 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
|
|||
clientConn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Hijacking client connection failed: %v", err)
|
||||
writeError(w, r, http.StatusServiceUnavailable)
|
||||
writeError(w, http.StatusServiceUnavailable)
|
||||
|
||||
return
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
@ -127,14 +126,15 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
|
|||
log.Printf("[ERROR] Securing client connection failed: %v", err)
|
||||
return
|
||||
}
|
||||
clientConnNotify := ConnNotify{clientConn, make(chan struct{})}
|
||||
|
||||
clientConnNotify := ConnNotify{clientConn, make(chan struct{})}
|
||||
l := &OnceAcceptListener{clientConnNotify.Conn}
|
||||
|
||||
err = http.Serve(l, p)
|
||||
if err != nil && err != ErrAlreadyAccepted {
|
||||
if err != nil && !errors.Is(err, ErrAlreadyAccepted) {
|
||||
log.Printf("[ERROR] Serving HTTP request failed: %v", err)
|
||||
}
|
||||
|
||||
<-clientConnNotify.closed
|
||||
}
|
||||
|
||||
|
@ -144,20 +144,22 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
|
|||
tlsConn := tls.Server(conn, tlsConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, fmt.Errorf("handshake error: %v", err)
|
||||
return nil, fmt.Errorf("handshake error: %w", err)
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func errorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if err == context.Canceled {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[ERROR]: Proxy error: %v", err)
|
||||
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, r *http.Request, code int) {
|
||||
func writeError(w http.ResponseWriter, code int) {
|
||||
http.Error(w, http.StatusText(code), code)
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ type Repository interface {
|
|||
FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error)
|
||||
FindRequestLogByID(ctx context.Context, id int64) (Request, error)
|
||||
AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error)
|
||||
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error)
|
||||
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error) // nolint:lll
|
||||
ClearRequestLogs(ctx context.Context) error
|
||||
UpsertSettings(ctx context.Context, module string, settings interface{}) error
|
||||
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
|
||||
|
|
|
@ -24,9 +24,7 @@ const LogBypassedKey contextKey = 0
|
|||
|
||||
const moduleName = "reqlog"
|
||||
|
||||
var (
|
||||
ErrRequestNotFound = errors.New("reqlog: request not found")
|
||||
)
|
||||
var ErrRequestNotFound = errors.New("reqlog: request not found")
|
||||
|
||||
type Request struct {
|
||||
ID int64
|
||||
|
@ -74,12 +72,13 @@ func NewService(cfg Config) *Service {
|
|||
|
||||
cfg.ProjectService.OnProjectOpen(func(_ string) error {
|
||||
err := svc.loadSettings()
|
||||
if err == proj.ErrNoSettings {
|
||||
if errors.Is(err, proj.ErrNoSettings) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("reqlog: could not load settings: %v", err)
|
||||
return fmt.Errorf("reqlog: could not load settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
cfg.ProjectService.OnProjectClose(func(_ string) error {
|
||||
|
@ -126,12 +125,13 @@ func (svc *Service) addResponse(
|
|||
if res.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzipReader, err := gzip.NewReader(bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reqlog: could not create gzip reader: %v", err)
|
||||
return nil, fmt.Errorf("reqlog: could not create gzip reader: %w", err)
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
body, err = ioutil.ReadAll(gzipReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %v", err)
|
||||
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -141,18 +141,23 @@ func (svc *Service) addResponse(
|
|||
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
|
||||
return func(req *http.Request) {
|
||||
now := time.Now()
|
||||
|
||||
next(req)
|
||||
|
||||
clone := req.Clone(req.Context())
|
||||
|
||||
var body []byte
|
||||
|
||||
if req.Body != nil {
|
||||
// TODO: Use io.LimitReader.
|
||||
var err error
|
||||
|
||||
body, err = ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Could not read request body for logging: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
}
|
||||
|
||||
|
@ -161,19 +166,21 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
|||
if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
|
||||
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
|
||||
*req = *req.WithContext(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
reqLog, err := svc.addRequest(req.Context(), *clone, body, now)
|
||||
if err == proj.ErrNoProject {
|
||||
if errors.Is(err, proj.ErrNoProject) {
|
||||
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
|
||||
*req = *req.WithContext(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
} else if err != nil {
|
||||
log.Printf("[ERROR] Could not store request log: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(req.Context(), proxy.ReqIDKey, reqLog.ID)
|
||||
*req = *req.WithContext(ctx)
|
||||
}
|
||||
|
@ -182,6 +189,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
|||
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
|
||||
return func(res *http.Response) error {
|
||||
now := time.Now()
|
||||
|
||||
if err := next(res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -200,8 +208,9 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
|
|||
// TODO: Use io.LimitReader.
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reqlog: could not read response body: %v", err)
|
||||
return fmt.Errorf("reqlog: could not read response body: %w", err)
|
||||
}
|
||||
|
||||
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
go func() {
|
||||
|
@ -220,6 +229,7 @@ func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error {
|
|||
OnlyInScope bool
|
||||
RawSearchExpr string
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(b, &dto); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -234,6 +244,7 @@ func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filter.SearchExpr = expr
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package scope
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
@ -38,12 +39,13 @@ func New(repo Repository, projService *proj.Service) *Scope {
|
|||
|
||||
projService.OnProjectOpen(func(_ string) error {
|
||||
err := s.load(context.Background())
|
||||
if err == proj.ErrNoSettings {
|
||||
if errors.Is(err, proj.ErrNoSettings) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("scope: could not load scope: %v", err)
|
||||
return fmt.Errorf("scope: could not load scope: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
projService.OnProjectClose(func(_ string) error {
|
||||
|
@ -57,6 +59,7 @@ func New(repo Repository, projService *proj.Service) *Scope {
|
|||
func (s *Scope) Rules() []Rule {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.rules
|
||||
}
|
||||
|
||||
|
@ -65,12 +68,12 @@ func (s *Scope) load(ctx context.Context) error {
|
|||
defer s.mu.Unlock()
|
||||
|
||||
var rules []Rule
|
||||
|
||||
err := s.repo.FindSettingsByModule(ctx, moduleName, &rules)
|
||||
if err == proj.ErrNoSettings {
|
||||
if errors.Is(err, proj.ErrNoSettings) {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("scope: could not load scope settings: %v", err)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("scope: could not load scope settings: %w", err)
|
||||
}
|
||||
|
||||
s.rules = rules
|
||||
|
@ -89,7 +92,7 @@ func (s *Scope) SetRules(ctx context.Context, rules []Rule) error {
|
|||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.repo.UpsertSettings(ctx, moduleName, rules); err != nil {
|
||||
return fmt.Errorf("scope: cannot set rules in repository: %v", err)
|
||||
return fmt.Errorf("scope: cannot set rules in repository: %w", err)
|
||||
}
|
||||
|
||||
s.rules = rules
|
||||
|
@ -100,6 +103,7 @@ func (s *Scope) SetRules(ctx context.Context, rules []Rule) error {
|
|||
func (s *Scope) Match(req *http.Request, body []byte) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for _, rule := range s.rules {
|
||||
if matches := rule.Match(req, body); matches {
|
||||
return true
|
||||
|
@ -118,11 +122,13 @@ func (r Rule) Match(req *http.Request, body []byte) bool {
|
|||
|
||||
for key, values := range req.Header {
|
||||
var keyMatches, valueMatches bool
|
||||
|
||||
if r.Header.Key != nil {
|
||||
if matches := r.Header.Key.MatchString(key); matches {
|
||||
keyMatches = true
|
||||
}
|
||||
}
|
||||
|
||||
if r.Header.Value != nil {
|
||||
for _, value := range values {
|
||||
if matches := r.Header.Value.MatchString(value); matches {
|
||||
|
@ -154,15 +160,17 @@ func (r Rule) Match(req *http.Request, body []byte) bool {
|
|||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (r Rule) MarshalJSON() ([]byte, error) {
|
||||
type headerDTO struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
type ruleDTO struct {
|
||||
URL string
|
||||
Header headerDTO
|
||||
Body string
|
||||
}
|
||||
type (
|
||||
headerDTO struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
ruleDTO struct {
|
||||
URL string
|
||||
Header headerDTO
|
||||
Body string
|
||||
}
|
||||
)
|
||||
|
||||
dto := ruleDTO{
|
||||
URL: regexpToString(r.URL),
|
||||
|
@ -178,15 +186,17 @@ func (r Rule) MarshalJSON() ([]byte, error) {
|
|||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (r *Rule) UnmarshalJSON(data []byte) error {
|
||||
type headerDTO struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
type ruleDTO struct {
|
||||
URL string
|
||||
Header headerDTO
|
||||
Body string
|
||||
}
|
||||
type (
|
||||
headerDTO struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
ruleDTO struct {
|
||||
URL string
|
||||
Header headerDTO
|
||||
Body string
|
||||
}
|
||||
)
|
||||
|
||||
var dto ruleDTO
|
||||
if err := json.Unmarshal(data, &dto); err != nil {
|
||||
|
@ -197,14 +207,17 @@ func (r *Rule) UnmarshalJSON(data []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headerKey, err := stringToRegexp(dto.Header.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headerValue, err := stringToRegexp(dto.Header.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := stringToRegexp(dto.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -226,6 +239,7 @@ func regexpToString(r *regexp.Regexp) string {
|
|||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return r.String()
|
||||
}
|
||||
|
||||
|
@ -233,5 +247,6 @@ func stringToRegexp(s string) (*regexp.Regexp, error) {
|
|||
if s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return regexp.Compile(s)
|
||||
}
|
||||
|
|
|
@ -11,7 +11,6 @@ type PrefixExpression struct {
|
|||
Right Expression
|
||||
}
|
||||
|
||||
func (pe *PrefixExpression) expressionNode() {}
|
||||
func (pe *PrefixExpression) String() string {
|
||||
b := strings.Builder{}
|
||||
b.WriteString("(")
|
||||
|
@ -29,7 +28,6 @@ type InfixExpression struct {
|
|||
Right Expression
|
||||
}
|
||||
|
||||
func (ie *InfixExpression) expressionNode() {}
|
||||
func (ie *InfixExpression) String() string {
|
||||
b := strings.Builder{}
|
||||
b.WriteString("(")
|
||||
|
@ -47,7 +45,6 @@ type StringLiteral struct {
|
|||
Value string
|
||||
}
|
||||
|
||||
func (sl *StringLiteral) expressionNode() {}
|
||||
func (sl *StringLiteral) String() string {
|
||||
return sl.Value
|
||||
}
|
||||
|
|
|
@ -17,21 +17,21 @@ const eof = 0
|
|||
|
||||
// Token types.
|
||||
const (
|
||||
// Flow
|
||||
// Flow.
|
||||
TokInvalid TokenType = iota
|
||||
TokEOF
|
||||
TokParenOpen
|
||||
TokParenClose
|
||||
|
||||
// Literals
|
||||
// Literals.
|
||||
TokString
|
||||
|
||||
// Boolean operators
|
||||
// Boolean operators.
|
||||
TokOpNot
|
||||
TokOpAnd
|
||||
TokOpOr
|
||||
|
||||
// Comparison operators
|
||||
// Comparison operators.
|
||||
TokOpEq
|
||||
TokOpNotEq
|
||||
TokOpGt
|
||||
|
@ -98,6 +98,7 @@ func (tt TokenType) String() string {
|
|||
if typeString, ok := tokenTypeStrings[tt]; ok {
|
||||
return typeString
|
||||
}
|
||||
|
||||
return "<unknown>"
|
||||
}
|
||||
|
||||
|
@ -113,6 +114,7 @@ func (l *Lexer) read() (r rune) {
|
|||
l.width = 0
|
||||
return eof
|
||||
}
|
||||
|
||||
r, l.width = utf8.DecodeRuneInString(l.input[l.pos:])
|
||||
l.pos += l.width
|
||||
|
||||
|
@ -124,6 +126,7 @@ func (l *Lexer) emit(tokenType TokenType) {
|
|||
Type: tokenType,
|
||||
Literal: l.input[l.start:l.pos],
|
||||
}
|
||||
|
||||
l.start = l.pos
|
||||
}
|
||||
|
||||
|
@ -159,6 +162,7 @@ func begin(l *Lexer) stateFn {
|
|||
l.backup()
|
||||
l.emit(TokOpEq)
|
||||
}
|
||||
|
||||
return begin
|
||||
case '!':
|
||||
switch next := l.read(); next {
|
||||
|
@ -169,6 +173,7 @@ func begin(l *Lexer) stateFn {
|
|||
default:
|
||||
return l.errorf("invalid rune %v", r)
|
||||
}
|
||||
|
||||
return begin
|
||||
case '<':
|
||||
if next := l.read(); next == '=' {
|
||||
|
@ -177,6 +182,7 @@ func begin(l *Lexer) stateFn {
|
|||
l.backup()
|
||||
l.emit(TokOpLt)
|
||||
}
|
||||
|
||||
return begin
|
||||
case '>':
|
||||
if next := l.read(); next == '=' {
|
||||
|
@ -185,6 +191,7 @@ func begin(l *Lexer) stateFn {
|
|||
l.backup()
|
||||
l.emit(TokOpGt)
|
||||
}
|
||||
|
||||
return begin
|
||||
case '(':
|
||||
l.emit(TokParenOpen)
|
||||
|
@ -231,15 +238,18 @@ func unquotedString(l *Lexer) stateFn {
|
|||
case r == eof:
|
||||
l.backup()
|
||||
l.emitUnquotedString()
|
||||
|
||||
return begin
|
||||
case unicode.IsSpace(r):
|
||||
l.backup()
|
||||
l.emitUnquotedString()
|
||||
l.skip()
|
||||
|
||||
return begin
|
||||
case isReserved(r):
|
||||
l.backup()
|
||||
l.emitUnquotedString()
|
||||
|
||||
return begin
|
||||
}
|
||||
}
|
||||
|
@ -251,6 +261,7 @@ func (l *Lexer) emitUnquotedString() {
|
|||
l.emit(tokType)
|
||||
return
|
||||
}
|
||||
|
||||
l.emit(TokString)
|
||||
}
|
||||
|
||||
|
@ -260,5 +271,6 @@ func isReserved(r rune) bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package search
|
|||
import "testing"
|
||||
|
||||
func TestNextToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
|
|
|
@ -18,8 +18,10 @@ const (
|
|||
precGroup
|
||||
)
|
||||
|
||||
type prefixParser func(*Parser) (Expression, error)
|
||||
type infixParser func(*Parser, Expression) (Expression, error)
|
||||
type (
|
||||
prefixParser func(*Parser) (Expression, error)
|
||||
infixParser func(*Parser, Expression) (Expression, error)
|
||||
)
|
||||
|
||||
var (
|
||||
prefixParsers = map[TokenType]prefixParser{}
|
||||
|
@ -77,7 +79,6 @@ func NewParser(l *Lexer) *Parser {
|
|||
p.nextToken()
|
||||
|
||||
return p
|
||||
|
||||
}
|
||||
|
||||
func ParseQuery(input string) (expr Expression, err error) {
|
||||
|
@ -91,18 +92,20 @@ func ParseQuery(input string) (expr Expression, err error) {
|
|||
|
||||
for !p.curTokenIs(TokEOF) {
|
||||
right, err := p.parseExpression(precLowest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search: could not parse expression: %v", err)
|
||||
}
|
||||
if expr == nil {
|
||||
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("search: could not parse expression: %w", err)
|
||||
case expr == nil:
|
||||
expr = right
|
||||
} else {
|
||||
default:
|
||||
expr = &InfixExpression{
|
||||
Operator: TokOpAnd,
|
||||
Left: expr,
|
||||
Right: right,
|
||||
}
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
|
@ -122,18 +125,11 @@ func (p *Parser) peekTokenIs(t TokenType) bool {
|
|||
return p.peek.Type == t
|
||||
}
|
||||
|
||||
func (p *Parser) expectPeek(t TokenType) error {
|
||||
if !p.peekTokenIs(t) {
|
||||
return fmt.Errorf("expected next token to be %v, got %v", t, p.peek.Type)
|
||||
}
|
||||
p.nextToken()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Parser) curPrecedence() precedence {
|
||||
if p, ok := tokenPrecedences[p.cur.Type]; ok {
|
||||
return p
|
||||
}
|
||||
|
||||
return precLowest
|
||||
}
|
||||
|
||||
|
@ -141,6 +137,7 @@ func (p *Parser) peekPrecedence() precedence {
|
|||
if p, ok := tokenPrecedences[p.peek.Type]; ok {
|
||||
return p
|
||||
}
|
||||
|
||||
return precLowest
|
||||
}
|
||||
|
||||
|
@ -152,7 +149,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) {
|
|||
|
||||
expr, err := prefixParser(p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse expression prefix: %v", err)
|
||||
return nil, fmt.Errorf("could not parse expression prefix: %w", err)
|
||||
}
|
||||
|
||||
for !p.peekTokenIs(eof) && prec < p.peekPrecedence() {
|
||||
|
@ -165,7 +162,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) {
|
|||
|
||||
expr, err = infixParser(p, expr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse infix expression: %v", err)
|
||||
return nil, fmt.Errorf("could not parse infix expression: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,8 +178,9 @@ func parsePrefixExpression(p *Parser) (Expression, error) {
|
|||
|
||||
right, err := p.parseExpression(precPrefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse expression for right operand: %v", err)
|
||||
return nil, fmt.Errorf("could not parse expression for right operand: %w", err)
|
||||
}
|
||||
|
||||
expr.Right = right
|
||||
|
||||
return expr, nil
|
||||
|
@ -199,8 +197,9 @@ func parseInfixExpression(p *Parser, left Expression) (Expression, error) {
|
|||
|
||||
right, err := p.parseExpression(prec)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse expression for right operand: %v", err)
|
||||
return nil, fmt.Errorf("could not parse expression for right operand: %w", err)
|
||||
}
|
||||
|
||||
expr.Right = right
|
||||
|
||||
return expr, nil
|
||||
|
@ -215,17 +214,19 @@ func parseGroupedExpression(p *Parser) (Expression, error) {
|
|||
|
||||
expr, err := p.parseExpression(precLowest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse grouped expression: %v", err)
|
||||
return nil, fmt.Errorf("could not parse grouped expression: %w", err)
|
||||
}
|
||||
|
||||
for p.nextToken(); !p.curTokenIs(TokParenClose); p.nextToken() {
|
||||
if p.curTokenIs(TokEOF) {
|
||||
return nil, fmt.Errorf("unexpected EOF: unmatched parentheses")
|
||||
}
|
||||
|
||||
right, err := p.parseExpression(precLowest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse expression: %v", err)
|
||||
return nil, fmt.Errorf("could not parse expression: %w", err)
|
||||
}
|
||||
|
||||
expr = &InfixExpression{
|
||||
Operator: TokOpAnd,
|
||||
Left: expr,
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
)
|
||||
|
||||
func TestParseQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
|
@ -233,6 +235,8 @@ func TestParseQuery(t *testing.T) {
|
|||
}
|
||||
|
||||
func assertError(t *testing.T, exp, got error) {
|
||||
t.Helper()
|
||||
|
||||
switch {
|
||||
case exp == nil && got != nil:
|
||||
t.Fatalf("expected: nil, got: %v", got)
|
||||
|
|
Loading…
Reference in a new issue