mirror of
https://github.com/dstotijn/hetty
synced 2024-09-20 05:51:55 +00:00
Add linter, fix linting issue
This commit is contained in:
parent
ad3dc0da70
commit
ca3a729c36
18 changed files with 442 additions and 231 deletions
55
.golangci.yml
Normal file
55
.golangci.yml
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
linters:
|
||||||
|
presets:
|
||||||
|
- bugs
|
||||||
|
- comment
|
||||||
|
- error
|
||||||
|
- format
|
||||||
|
- import
|
||||||
|
- metalinter
|
||||||
|
- module
|
||||||
|
- performance
|
||||||
|
- sql
|
||||||
|
- style
|
||||||
|
- test
|
||||||
|
- unused
|
||||||
|
disable:
|
||||||
|
- exhaustive
|
||||||
|
- exhaustivestruct
|
||||||
|
- gochecknoglobals
|
||||||
|
- gochecknoinits
|
||||||
|
- godox
|
||||||
|
- goerr113
|
||||||
|
- gomnd
|
||||||
|
- interfacer
|
||||||
|
- maligned
|
||||||
|
- nlreturn
|
||||||
|
- scopelint
|
||||||
|
- testpackage
|
||||||
|
- wrapcheck
|
||||||
|
|
||||||
|
linters-settings:
|
||||||
|
errcheck:
|
||||||
|
ignore: "database/sql:Rollback"
|
||||||
|
gci:
|
||||||
|
local-prefixes: github.com/dstotijn/hetty
|
||||||
|
godot:
|
||||||
|
capital: true
|
||||||
|
|
||||||
|
issues:
|
||||||
|
exclude-rules:
|
||||||
|
- linters:
|
||||||
|
- gosec
|
||||||
|
# Ignore SHA1 usage.
|
||||||
|
text: "G(401|505):"
|
||||||
|
- linters:
|
||||||
|
- wsl
|
||||||
|
# Ignore cuddled defer statements.
|
||||||
|
text: "only one cuddle assignment allowed before defer statement"
|
||||||
|
- linters:
|
||||||
|
- nlreturn
|
||||||
|
# Ignore `break` without leading blank line.
|
||||||
|
text: "break with no blank line before"
|
||||||
|
|
||||||
|
run:
|
||||||
|
skip-files:
|
||||||
|
- cmd/hetty/rice-box.go
|
|
@ -2,25 +2,27 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/99designs/gqlgen/graphql/handler"
|
||||||
|
"github.com/99designs/gqlgen/graphql/playground"
|
||||||
rice "github.com/GeertJohan/go.rice"
|
rice "github.com/GeertJohan/go.rice"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/mitchellh/go-homedir"
|
||||||
|
|
||||||
"github.com/dstotijn/hetty/pkg/api"
|
"github.com/dstotijn/hetty/pkg/api"
|
||||||
"github.com/dstotijn/hetty/pkg/db/sqlite"
|
"github.com/dstotijn/hetty/pkg/db/sqlite"
|
||||||
"github.com/dstotijn/hetty/pkg/proj"
|
"github.com/dstotijn/hetty/pkg/proj"
|
||||||
"github.com/dstotijn/hetty/pkg/proxy"
|
"github.com/dstotijn/hetty/pkg/proxy"
|
||||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||||
"github.com/dstotijn/hetty/pkg/scope"
|
"github.com/dstotijn/hetty/pkg/scope"
|
||||||
|
|
||||||
"github.com/99designs/gqlgen/graphql/handler"
|
|
||||||
"github.com/99designs/gqlgen/graphql/playground"
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
"github.com/mitchellh/go-homedir"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -32,8 +34,16 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.StringVar(&caCertFile, "cert", "~/.hetty/hetty_cert.pem", "CA certificate filepath. Creates a new CA certificate if file doesn't exist")
|
if err := run(); err != nil {
|
||||||
flag.StringVar(&caKeyFile, "key", "~/.hetty/hetty_key.pem", "CA private key filepath. Creates a new CA private key if file doesn't exist")
|
log.Fatalf("[ERROR]: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func run() error {
|
||||||
|
flag.StringVar(&caCertFile, "cert", "~/.hetty/hetty_cert.pem",
|
||||||
|
"CA certificate filepath. Creates a new CA certificate if file doesn't exist")
|
||||||
|
flag.StringVar(&caKeyFile, "key", "~/.hetty/hetty_key.pem",
|
||||||
|
"CA private key filepath. Creates a new CA private key if file doesn't exist")
|
||||||
flag.StringVar(&projPath, "projects", "~/.hetty/projects", "Projects directory path")
|
flag.StringVar(&projPath, "projects", "~/.hetty/projects", "Projects directory path")
|
||||||
flag.StringVar(&addr, "addr", ":8080", "TCP address to listen on, in the form \"host:port\"")
|
flag.StringVar(&addr, "addr", ":8080", "TCP address to listen on, in the form \"host:port\"")
|
||||||
flag.StringVar(&adminPath, "adminPath", "", "File path to admin build")
|
flag.StringVar(&adminPath, "adminPath", "", "File path to admin build")
|
||||||
|
@ -42,32 +52,34 @@ func main() {
|
||||||
// Expand `~` in filepaths.
|
// Expand `~` in filepaths.
|
||||||
caCertFile, err := homedir.Expand(caCertFile)
|
caCertFile, err := homedir.Expand(caCertFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[FATAL] Could not parse CA certificate filepath: %v", err)
|
return fmt.Errorf("could not parse CA certificate filepath: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
caKeyFile, err := homedir.Expand(caKeyFile)
|
caKeyFile, err := homedir.Expand(caKeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[FATAL] Could not parse CA private key filepath: %v", err)
|
return fmt.Errorf("could not parse CA private key filepath: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
projPath, err := homedir.Expand(projPath)
|
projPath, err := homedir.Expand(projPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[FATAL] Could not parse projects filepath: %v", err)
|
return fmt.Errorf("could not parse projects filepath: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load existing CA certificate and key from disk, or generate and write
|
// Load existing CA certificate and key from disk, or generate and write
|
||||||
// to disk if no files exist yet.
|
// to disk if no files exist yet.
|
||||||
caCert, caKey, err := proxy.LoadOrCreateCA(caKeyFile, caCertFile)
|
caCert, caKey, err := proxy.LoadOrCreateCA(caKeyFile, caCertFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[FATAL] Could not create/load CA key pair: %v", err)
|
return fmt.Errorf("could not create/load CA key pair: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sqlite.New(projPath)
|
db, err := sqlite.New(projPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[FATAL] Could not initialize database client: %v", err)
|
return fmt.Errorf("could not initialize database client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
projService, err := proj.NewService(db)
|
projService, err := proj.NewService(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[FATAL] Could not create new project service: %v", err)
|
return fmt.Errorf("could not create new project service: %w", err)
|
||||||
}
|
}
|
||||||
defer projService.Close()
|
defer projService.Close()
|
||||||
|
|
||||||
|
@ -81,19 +93,21 @@ func main() {
|
||||||
|
|
||||||
p, err := proxy.NewProxy(caCert, caKey)
|
p, err := proxy.NewProxy(caCert, caKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[FATAL] Could not create Proxy: %v", err)
|
return fmt.Errorf("could not create proxy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.UseRequestModifier(reqLogService.RequestModifier)
|
p.UseRequestModifier(reqLogService.RequestModifier)
|
||||||
p.UseResponseModifier(reqLogService.ResponseModifier)
|
p.UseResponseModifier(reqLogService.ResponseModifier)
|
||||||
|
|
||||||
var adminHandler http.Handler
|
var adminHandler http.Handler
|
||||||
|
|
||||||
if adminPath == "" {
|
if adminPath == "" {
|
||||||
// Used for embedding with `rice`.
|
// Used for embedding with `rice`.
|
||||||
box, err := rice.FindBox("../../admin/dist")
|
box, err := rice.FindBox("../../admin/dist")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("[FATAL] Could not find embedded admin resources: %v", err)
|
return fmt.Errorf("could not find embedded admin resources: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
adminHandler = http.FileServer(box.HTTPBox())
|
adminHandler = http.FileServer(box.HTTPBox())
|
||||||
} else {
|
} else {
|
||||||
adminHandler = http.FileServer(http.Dir(adminPath))
|
adminHandler = http.FileServer(http.Dir(adminPath))
|
||||||
|
@ -109,7 +123,8 @@ func main() {
|
||||||
|
|
||||||
// GraphQL server.
|
// GraphQL server.
|
||||||
adminRouter.Path("/api/playground/").Handler(playground.Handler("GraphQL Playground", "/api/graphql/"))
|
adminRouter.Path("/api/playground/").Handler(playground.Handler("GraphQL Playground", "/api/graphql/"))
|
||||||
adminRouter.Path("/api/graphql/").Handler(handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{
|
adminRouter.Path("/api/graphql/").Handler(
|
||||||
|
handler.NewDefaultServer(api.NewExecutableSchema(api.Config{Resolvers: &api.Resolver{
|
||||||
RequestLogService: reqLogService,
|
RequestLogService: reqLogService,
|
||||||
ProjectService: projService,
|
ProjectService: projService,
|
||||||
ScopeService: scope,
|
ScopeService: scope,
|
||||||
|
@ -128,8 +143,11 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[INFO] Running server on %v ...", addr)
|
log.Printf("[INFO] Running server on %v ...", addr)
|
||||||
|
|
||||||
err = s.ListenAndServe()
|
err = s.ListenAndServe()
|
||||||
if err != nil && err != http.ErrServerClosed {
|
if err != nil && errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatalf("[FATAL] HTTP server closed: %v", err)
|
return fmt.Errorf("http server closed unexpected: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,16 +4,18 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/99designs/gqlgen/graphql"
|
"github.com/99designs/gqlgen/graphql"
|
||||||
|
"github.com/vektah/gqlparser/v2/gqlerror"
|
||||||
|
|
||||||
"github.com/dstotijn/hetty/pkg/proj"
|
"github.com/dstotijn/hetty/pkg/proj"
|
||||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||||
"github.com/dstotijn/hetty/pkg/scope"
|
"github.com/dstotijn/hetty/pkg/scope"
|
||||||
"github.com/dstotijn/hetty/pkg/search"
|
"github.com/dstotijn/hetty/pkg/search"
|
||||||
"github.com/vektah/gqlparser/v2/gqlerror"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
|
@ -22,20 +24,22 @@ type Resolver struct {
|
||||||
ScopeService *scope.Scope
|
ScopeService *scope.Scope
|
||||||
}
|
}
|
||||||
|
|
||||||
type queryResolver struct{ *Resolver }
|
type (
|
||||||
type mutationResolver struct{ *Resolver }
|
queryResolver struct{ *Resolver }
|
||||||
|
mutationResolver struct{ *Resolver }
|
||||||
|
)
|
||||||
|
|
||||||
func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }
|
func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }
|
||||||
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
|
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
|
||||||
|
|
||||||
func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error) {
|
func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog, error) {
|
||||||
reqs, err := r.RequestLogService.FindRequests(ctx)
|
reqs, err := r.RequestLogService.FindRequests(ctx)
|
||||||
if err == proj.ErrNoProject {
|
if errors.Is(err, proj.ErrNoProject) {
|
||||||
return nil, noActiveProjectErr(ctx)
|
return nil, noActiveProjectErr(ctx)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not query repository for requests: %w", err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not query repository for requests: %v", err)
|
|
||||||
}
|
|
||||||
logs := make([]HTTPRequestLog, len(reqs))
|
logs := make([]HTTPRequestLog, len(reqs))
|
||||||
|
|
||||||
for i, req := range reqs {
|
for i, req := range reqs {
|
||||||
|
@ -43,6 +47,7 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logs[i] = req
|
logs[i] = req
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,12 +56,12 @@ func (r *queryResolver) HTTPRequestLogs(ctx context.Context) ([]HTTPRequestLog,
|
||||||
|
|
||||||
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error) {
|
func (r *queryResolver) HTTPRequestLog(ctx context.Context, id int64) (*HTTPRequestLog, error) {
|
||||||
log, err := r.RequestLogService.FindRequestLogByID(ctx, id)
|
log, err := r.RequestLogService.FindRequestLogByID(ctx, id)
|
||||||
if err == reqlog.ErrRequestNotFound {
|
if errors.Is(err, reqlog.ErrRequestNotFound) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not get request by ID: %w", err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not get request by ID: %v", err)
|
|
||||||
}
|
|
||||||
req, err := parseRequestLog(log)
|
req, err := parseRequestLog(log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -89,6 +94,7 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
|
||||||
|
|
||||||
if req.Request.Header != nil {
|
if req.Request.Header != nil {
|
||||||
log.Headers = make([]HTTPHeader, 0)
|
log.Headers = make([]HTTPHeader, 0)
|
||||||
|
|
||||||
for key, values := range req.Request.Header {
|
for key, values := range req.Request.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
log.Headers = append(log.Headers, HTTPHeader{
|
log.Headers = append(log.Headers, HTTPHeader{
|
||||||
|
@ -106,15 +112,19 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
|
||||||
StatusCode: req.Response.Response.StatusCode,
|
StatusCode: req.Response.Response.StatusCode,
|
||||||
}
|
}
|
||||||
statusReasonSubs := strings.SplitN(req.Response.Response.Status, " ", 2)
|
statusReasonSubs := strings.SplitN(req.Response.Response.Status, " ", 2)
|
||||||
|
|
||||||
if len(statusReasonSubs) == 2 {
|
if len(statusReasonSubs) == 2 {
|
||||||
log.Response.StatusReason = statusReasonSubs[1]
|
log.Response.StatusReason = statusReasonSubs[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req.Response.Body) > 0 {
|
if len(req.Response.Body) > 0 {
|
||||||
resBody := string(req.Response.Body)
|
resBody := string(req.Response.Body)
|
||||||
log.Response.Body = &resBody
|
log.Response.Body = &resBody
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Response.Response.Header != nil {
|
if req.Response.Response.Header != nil {
|
||||||
log.Response.Headers = make([]HTTPHeader, 0)
|
log.Response.Headers = make([]HTTPHeader, 0)
|
||||||
|
|
||||||
for key, values := range req.Response.Response.Header {
|
for key, values := range req.Response.Response.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
log.Response.Headers = append(log.Response.Headers, HTTPHeader{
|
log.Response.Headers = append(log.Response.Headers, HTTPHeader{
|
||||||
|
@ -131,12 +141,12 @@ func parseRequestLog(req reqlog.Request) (HTTPRequestLog, error) {
|
||||||
|
|
||||||
func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Project, error) {
|
func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Project, error) {
|
||||||
p, err := r.ProjectService.Open(ctx, name)
|
p, err := r.ProjectService.Open(ctx, name)
|
||||||
if err == proj.ErrInvalidName {
|
if errors.Is(err, proj.ErrInvalidName) {
|
||||||
return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.")
|
return nil, gqlerror.Errorf("Project name must only contain alphanumeric or space chars.")
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not open project: %w", err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not open project: %v", err)
|
|
||||||
}
|
|
||||||
return &Project{
|
return &Project{
|
||||||
Name: p.Name,
|
Name: p.Name,
|
||||||
IsActive: p.IsActive,
|
IsActive: p.IsActive,
|
||||||
|
@ -145,11 +155,10 @@ func (r *mutationResolver) OpenProject(ctx context.Context, name string) (*Proje
|
||||||
|
|
||||||
func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
|
func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
|
||||||
p, err := r.ProjectService.ActiveProject()
|
p, err := r.ProjectService.ActiveProject()
|
||||||
if err == proj.ErrNoProject {
|
if errors.Is(err, proj.ErrNoProject) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
} else if err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("could not open project: %w", err)
|
||||||
return nil, fmt.Errorf("could not open project: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Project{
|
return &Project{
|
||||||
|
@ -161,7 +170,7 @@ func (r *queryResolver) ActiveProject(ctx context.Context) (*Project, error) {
|
||||||
func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
|
func (r *queryResolver) Projects(ctx context.Context) ([]Project, error) {
|
||||||
p, err := r.ProjectService.Projects()
|
p, err := r.ProjectService.Projects()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not get projects: %v", err)
|
return nil, fmt.Errorf("could not get projects: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
projects := make([]Project, len(p))
|
projects := make([]Project, len(p))
|
||||||
|
@ -184,21 +193,25 @@ func regexpToStringPtr(r *regexp.Regexp) *string {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s := r.String()
|
s := r.String()
|
||||||
|
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResult, error) {
|
func (r *mutationResolver) CloseProject(ctx context.Context) (*CloseProjectResult, error) {
|
||||||
if err := r.ProjectService.Close(); err != nil {
|
if err := r.ProjectService.Close(); err != nil {
|
||||||
return nil, fmt.Errorf("could not close project: %v", err)
|
return nil, fmt.Errorf("could not close project: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &CloseProjectResult{true}, nil
|
return &CloseProjectResult{true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*DeleteProjectResult, error) {
|
func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*DeleteProjectResult, error) {
|
||||||
if err := r.ProjectService.Delete(name); err != nil {
|
if err := r.ProjectService.Delete(name); err != nil {
|
||||||
return nil, fmt.Errorf("could not delete project: %v", err)
|
return nil, fmt.Errorf("could not delete project: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DeleteProjectResult{
|
return &DeleteProjectResult{
|
||||||
Success: true,
|
Success: true,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -206,33 +219,40 @@ func (r *mutationResolver) DeleteProject(ctx context.Context, name string) (*Del
|
||||||
|
|
||||||
func (r *mutationResolver) ClearHTTPRequestLog(ctx context.Context) (*ClearHTTPRequestLogResult, error) {
|
func (r *mutationResolver) ClearHTTPRequestLog(ctx context.Context) (*ClearHTTPRequestLogResult, error) {
|
||||||
if err := r.RequestLogService.ClearRequests(ctx); err != nil {
|
if err := r.RequestLogService.ClearRequests(ctx); err != nil {
|
||||||
return nil, fmt.Errorf("could not clear request log: %v", err)
|
return nil, fmt.Errorf("could not clear request log: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ClearHTTPRequestLogResult{true}, nil
|
return &ClearHTTPRequestLogResult{true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput) ([]ScopeRule, error) {
|
func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput) ([]ScopeRule, error) {
|
||||||
rules := make([]scope.Rule, len(input))
|
rules := make([]scope.Rule, len(input))
|
||||||
|
|
||||||
for i, rule := range input {
|
for i, rule := range input {
|
||||||
u, err := stringPtrToRegexp(rule.URL)
|
u, err := stringPtrToRegexp(rule.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid URL in scope rule: %v", err)
|
return nil, fmt.Errorf("invalid URL in scope rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var headerKey, headerValue *regexp.Regexp
|
var headerKey, headerValue *regexp.Regexp
|
||||||
|
|
||||||
if rule.Header != nil {
|
if rule.Header != nil {
|
||||||
headerKey, err = stringPtrToRegexp(rule.Header.Key)
|
headerKey, err = stringPtrToRegexp(rule.Header.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid header key in scope rule: %v", err)
|
return nil, fmt.Errorf("invalid header key in scope rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
headerValue, err = stringPtrToRegexp(rule.Header.Key)
|
headerValue, err = stringPtrToRegexp(rule.Header.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid header value in scope rule: %v", err)
|
return nil, fmt.Errorf("invalid header value in scope rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := stringPtrToRegexp(rule.Body)
|
body, err := stringPtrToRegexp(rule.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid body in scope rule: %v", err)
|
return nil, fmt.Errorf("invalid body in scope rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rules[i] = scope.Rule{
|
rules[i] = scope.Rule{
|
||||||
URL: u,
|
URL: u,
|
||||||
Header: scope.Header{
|
Header: scope.Header{
|
||||||
|
@ -244,7 +264,7 @@ func (r *mutationResolver) SetScope(ctx context.Context, input []ScopeRuleInput)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.ScopeService.SetRules(ctx, rules); err != nil {
|
if err := r.ScopeService.SetRules(ctx, rules); err != nil {
|
||||||
return nil, fmt.Errorf("could not set scope: %v", err)
|
return nil, fmt.Errorf("could not set scope: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return scopeToScopeRules(rules), nil
|
return scopeToScopeRules(rules), nil
|
||||||
|
@ -260,14 +280,14 @@ func (r *mutationResolver) SetHTTPRequestLogFilter(
|
||||||
) (*HTTPRequestLogFilter, error) {
|
) (*HTTPRequestLogFilter, error) {
|
||||||
filter, err := findRequestsFilterFromInput(input)
|
filter, err := findRequestsFilterFromInput(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse request log filter: %v", err)
|
return nil, fmt.Errorf("could not parse request log filter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.RequestLogService.SetRequestLogFilter(ctx, filter)
|
err = r.RequestLogService.SetRequestLogFilter(ctx, filter)
|
||||||
if err == proj.ErrNoProject {
|
if errors.Is(err, proj.ErrNoProject) {
|
||||||
return nil, noActiveProjectErr(ctx)
|
return nil, noActiveProjectErr(ctx)
|
||||||
}
|
} else if err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("could not set request log filter: %w", err)
|
||||||
return nil, fmt.Errorf("could not set request log filter: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return findReqFilterToHTTPReqLogFilter(filter), nil
|
return findReqFilterToHTTPReqLogFilter(filter), nil
|
||||||
|
@ -277,6 +297,7 @@ func stringPtrToRegexp(s *string) (*regexp.Regexp, error) {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return regexp.Compile(*s)
|
return regexp.Compile(*s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,8 +311,10 @@ func scopeToScopeRules(rules []scope.Rule) []ScopeRule {
|
||||||
Value: regexpToStringPtr(rule.Header.Value),
|
Value: regexpToStringPtr(rule.Header.Value),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
scopeRules[i].Body = regexpToStringPtr(rule.Body)
|
scopeRules[i].Body = regexpToStringPtr(rule.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
return scopeRules
|
return scopeRules
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,14 +322,17 @@ func findRequestsFilterFromInput(input *HTTPRequestLogFilterInput) (filter reqlo
|
||||||
if input == nil {
|
if input == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.OnlyInScope != nil {
|
if input.OnlyInScope != nil {
|
||||||
filter.OnlyInScope = *input.OnlyInScope
|
filter.OnlyInScope = *input.OnlyInScope
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.SearchExpression != nil && *input.SearchExpression != "" {
|
if input.SearchExpression != nil && *input.SearchExpression != "" {
|
||||||
expr, err := search.ParseQuery(*input.SearchExpression)
|
expr, err := search.ParseQuery(*input.SearchExpression)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %v", err)
|
return reqlog.FindRequestsFilter{}, fmt.Errorf("could not parse search query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
filter.RawSearchExpr = *input.SearchExpression
|
filter.RawSearchExpr = *input.SearchExpression
|
||||||
filter.SearchExpr = expr
|
filter.SearchExpr = expr
|
||||||
}
|
}
|
||||||
|
@ -319,6 +345,7 @@ func findReqFilterToHTTPReqLogFilter(findReqFilter reqlog.FindRequestsFilter) *H
|
||||||
if findReqFilter == empty {
|
if findReqFilter == empty {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReqLogFilter := &HTTPRequestLogFilter{
|
httpReqLogFilter := &HTTPRequestLogFilter{
|
||||||
OnlyInScope: findReqFilter.OnlyInScope,
|
OnlyInScope: findReqFilter.OnlyInScope,
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ func (u *reqURL) Scan(value interface{}) error {
|
||||||
|
|
||||||
parsed, err := url.Parse(rawURL)
|
parsed, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sqlite: could not parse URL: %v", err)
|
return fmt.Errorf("sqlite: could not parse URL: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
*u = reqURL(*parsed)
|
*u = reqURL(*parsed)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
sq "github.com/Masterminds/squirrel"
|
sq "github.com/Masterminds/squirrel"
|
||||||
|
|
||||||
"github.com/dstotijn/hetty/pkg/search"
|
"github.com/dstotijn/hetty/pkg/search"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -57,20 +58,24 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
right, err := parseSearchExpr(expr.Right)
|
right, err := parseSearchExpr(expr.Right)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return sq.And{left, right}, nil
|
return sq.And{left, right}, nil
|
||||||
case search.TokOpOr:
|
case search.TokOpOr:
|
||||||
left, err := parseSearchExpr(expr.Left)
|
left, err := parseSearchExpr(expr.Left)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
right, err := parseSearchExpr(expr.Right)
|
right, err := parseSearchExpr(expr.Right)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return sq.Or{left, right}, nil
|
return sq.Or{left, right}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +83,7 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("left operand must be a string literal")
|
return nil, errors.New("left operand must be a string literal")
|
||||||
}
|
}
|
||||||
|
|
||||||
right, ok := expr.Right.(*search.StringLiteral)
|
right, ok := expr.Right.(*search.StringLiteral)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("right operand must be a string literal")
|
return nil, errors.New("right operand must be a string literal")
|
||||||
|
@ -113,14 +119,17 @@ func parseInfixExpr(expr *search.InfixExpression) (sq.Sqlizer, error) {
|
||||||
func parseStringLiteral(strLiteral *search.StringLiteral) (sq.Sqlizer, error) {
|
func parseStringLiteral(strLiteral *search.StringLiteral) (sq.Sqlizer, error) {
|
||||||
// Sorting is not necessary, but makes it easier to do assertions in tests.
|
// Sorting is not necessary, but makes it easier to do assertions in tests.
|
||||||
sortedKeys := make([]string, 0, len(stringLiteralMap))
|
sortedKeys := make([]string, 0, len(stringLiteralMap))
|
||||||
|
|
||||||
for _, v := range stringLiteralMap {
|
for _, v := range stringLiteralMap {
|
||||||
sortedKeys = append(sortedKeys, v)
|
sortedKeys = append(sortedKeys, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(sortedKeys)
|
sort.Strings(sortedKeys)
|
||||||
|
|
||||||
or := make(sq.Or, len(stringLiteralMap))
|
or := make(sq.Or, len(stringLiteralMap))
|
||||||
for i, value := range sortedKeys {
|
for i, value := range sortedKeys {
|
||||||
or[i] = sq.Like{value: "%" + strLiteral.Value + "%"}
|
or[i] = sq.Like{value: "%" + strLiteral.Value + "%"}
|
||||||
}
|
}
|
||||||
|
|
||||||
return or, nil
|
return or, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseSearchExpr(t *testing.T) {
|
func TestParseSearchExpr(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
searchExpr search.Expression
|
searchExpr search.Expression
|
||||||
|
@ -206,6 +208,8 @@ func TestParseSearchExpr(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertError(t *testing.T, exp, got error) {
|
func assertError(t *testing.T, exp, got error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case exp == nil && got != nil:
|
case exp == nil && got != nil:
|
||||||
t.Fatalf("expected: nil, got: %v", got)
|
t.Fatalf("expected: nil, got: %v", got)
|
||||||
|
|
|
@ -15,14 +15,14 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dstotijn/hetty/pkg/proj"
|
|
||||||
"github.com/dstotijn/hetty/pkg/reqlog"
|
|
||||||
"github.com/dstotijn/hetty/pkg/scope"
|
|
||||||
|
|
||||||
"github.com/99designs/gqlgen/graphql"
|
"github.com/99designs/gqlgen/graphql"
|
||||||
sq "github.com/Masterminds/squirrel"
|
sq "github.com/Masterminds/squirrel"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
"github.com/mattn/go-sqlite3"
|
"github.com/mattn/go-sqlite3"
|
||||||
|
|
||||||
|
"github.com/dstotijn/hetty/pkg/proj"
|
||||||
|
"github.com/dstotijn/hetty/pkg/reqlog"
|
||||||
|
"github.com/dstotijn/hetty/pkg/scope"
|
||||||
)
|
)
|
||||||
|
|
||||||
var regexpFn = func(pattern string, value interface{}) (bool, error) {
|
var regexpFn = func(pattern string, value interface{}) (bool, error) {
|
||||||
|
@ -55,10 +55,7 @@ type httpRequestLogsQuery struct {
|
||||||
func init() {
|
func init() {
|
||||||
sql.Register("sqlite3_with_regexp", &sqlite3.SQLiteDriver{
|
sql.Register("sqlite3_with_regexp", &sqlite3.SQLiteDriver{
|
||||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||||
if err := conn.RegisterFunc("regexp", regexpFn, false); err != nil {
|
return conn.RegisterFunc("regexp", regexpFn, false)
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -66,9 +63,10 @@ func init() {
|
||||||
func New(dbPath string) (*Client, error) {
|
func New(dbPath string) (*Client, error) {
|
||||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||||
if err := os.MkdirAll(dbPath, 0755); err != nil {
|
if err := os.MkdirAll(dbPath, 0755); err != nil {
|
||||||
return nil, fmt.Errorf("proj: could not create project directory: %v", err)
|
return nil, fmt.Errorf("proj: could not create project directory: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
dbPath: dbPath,
|
dbPath: dbPath,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -85,17 +83,18 @@ func (c *Client) OpenProject(name string) error {
|
||||||
|
|
||||||
dbPath := filepath.Join(c.dbPath, name+".db")
|
dbPath := filepath.Join(c.dbPath, name+".db")
|
||||||
dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode())
|
dsn := fmt.Sprintf("file:%v?%v", dbPath, opts.Encode())
|
||||||
|
|
||||||
db, err := sqlx.Open("sqlite3_with_regexp", dsn)
|
db, err := sqlx.Open("sqlite3_with_regexp", dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sqlite: could not open database: %v", err)
|
return fmt.Errorf("sqlite: could not open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.Ping(); err != nil {
|
if err := db.Ping(); err != nil {
|
||||||
return fmt.Errorf("sqlite: could not ping database: %v", err)
|
return fmt.Errorf("sqlite: could not ping database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := prepareSchema(db); err != nil {
|
if err := prepareSchema(db); err != nil {
|
||||||
return fmt.Errorf("sqlite: could not prepare schema: %v", err)
|
return fmt.Errorf("sqlite: could not prepare schema: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.db = db
|
c.db = db
|
||||||
|
@ -107,10 +106,11 @@ func (c *Client) OpenProject(name string) error {
|
||||||
func (c *Client) Projects() ([]proj.Project, error) {
|
func (c *Client) Projects() ([]proj.Project, error) {
|
||||||
files, err := ioutil.ReadDir(c.dbPath)
|
files, err := ioutil.ReadDir(c.dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not read projects directory: %v", err)
|
return nil, fmt.Errorf("sqlite: could not read projects directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
projects := make([]proj.Project, len(files))
|
projects := make([]proj.Project, len(files))
|
||||||
|
|
||||||
for i, file := range files {
|
for i, file := range files {
|
||||||
projName := strings.TrimSuffix(file.Name(), ".db")
|
projName := strings.TrimSuffix(file.Name(), ".db")
|
||||||
projects[i] = proj.Project{
|
projects[i] = proj.Project{
|
||||||
|
@ -132,7 +132,7 @@ func prepareSchema(db *sqlx.DB) error {
|
||||||
timestamp DATETIME
|
timestamp DATETIME
|
||||||
)`)
|
)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not create http_requests table: %v", err)
|
return fmt.Errorf("could not create http_requests table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_responses (
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_responses (
|
||||||
|
@ -145,7 +145,7 @@ func prepareSchema(db *sqlx.DB) error {
|
||||||
timestamp DATETIME
|
timestamp DATETIME
|
||||||
)`)
|
)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not create http_responses table: %v", err)
|
return fmt.Errorf("could not create http_responses table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_headers (
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS http_headers (
|
||||||
|
@ -156,7 +156,7 @@ func prepareSchema(db *sqlx.DB) error {
|
||||||
value TEXT
|
value TEXT
|
||||||
)`)
|
)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not create http_headers table: %v", err)
|
return fmt.Errorf("could not create http_headers table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings (
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS settings (
|
||||||
|
@ -164,7 +164,7 @@ func prepareSchema(db *sqlx.DB) error {
|
||||||
settings TEXT
|
settings TEXT
|
||||||
)`)
|
)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not create settings table: %v", err)
|
return fmt.Errorf("could not create settings table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -175,8 +175,9 @@ func (c *Client) Close() error {
|
||||||
if c.db == nil {
|
if c.db == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.db.Close(); err != nil {
|
if err := c.db.Close(); err != nil {
|
||||||
return fmt.Errorf("sqlite: could not close database: %v", err)
|
return fmt.Errorf("sqlite: could not close database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.db = nil
|
c.db = nil
|
||||||
|
@ -187,7 +188,7 @@ func (c *Client) Close() error {
|
||||||
|
|
||||||
func (c *Client) DeleteProject(name string) error {
|
func (c *Client) DeleteProject(name string) error {
|
||||||
if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil {
|
if err := os.Remove(filepath.Join(c.dbPath, name+".db")); err != nil {
|
||||||
return fmt.Errorf("sqlite: could not remove database file: %v", err)
|
return fmt.Errorf("sqlite: could not remove database file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -219,10 +220,12 @@ func (c *Client) ClearRequestLogs(ctx context.Context) error {
|
||||||
if c.db == nil {
|
if c.db == nil {
|
||||||
return proj.ErrNoProject
|
return proj.ErrNoProject
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := c.db.Exec("DELETE FROM http_requests")
|
_, err := c.db.Exec("DELETE FROM http_requests")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sqlite: could not delete requests: %v", err)
|
return fmt.Errorf("sqlite: could not delete requests: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,11 +250,13 @@ func (c *Client) FindRequestLogs(
|
||||||
|
|
||||||
if filter.OnlyInScope && scope != nil {
|
if filter.OnlyInScope && scope != nil {
|
||||||
var ruleExpr []sq.Sqlizer
|
var ruleExpr []sq.Sqlizer
|
||||||
|
|
||||||
for _, rule := range scope.Rules() {
|
for _, rule := range scope.Rules() {
|
||||||
if rule.URL != nil {
|
if rule.URL != nil {
|
||||||
ruleExpr = append(ruleExpr, sq.Expr("regexp(?, req.url)", rule.URL.String()))
|
ruleExpr = append(ruleExpr, sq.Expr("regexp(?, req.url)", rule.URL.String()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ruleExpr) > 0 {
|
if len(ruleExpr) > 0 {
|
||||||
reqQuery = reqQuery.Where(sq.Or(ruleExpr))
|
reqQuery = reqQuery.Where(sq.Or(ruleExpr))
|
||||||
}
|
}
|
||||||
|
@ -260,37 +265,41 @@ func (c *Client) FindRequestLogs(
|
||||||
if filter.SearchExpr != nil {
|
if filter.SearchExpr != nil {
|
||||||
sqlizer, err := parseSearchExpr(filter.SearchExpr)
|
sqlizer, err := parseSearchExpr(filter.SearchExpr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not parse search expression: %v", err)
|
return nil, fmt.Errorf("sqlite: could not parse search expression: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqQuery = reqQuery.Where(sqlizer)
|
reqQuery = reqQuery.Where(sqlizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
sql, args, err := reqQuery.ToSql()
|
sql, args, err := reqQuery.ToSql()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not parse query: %v", err)
|
return nil, fmt.Errorf("sqlite: could not parse query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := c.db.QueryxContext(ctx, sql, args...)
|
rows, err := c.db.QueryxContext(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
|
return nil, fmt.Errorf("sqlite: could not execute query: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dto httpRequest
|
var dto httpRequest
|
||||||
|
|
||||||
err = rows.StructScan(&dto)
|
err = rows.StructScan(&dto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not scan row: %v", err)
|
return nil, fmt.Errorf("sqlite: could not scan row: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqLogs = append(reqLogs, dto.toRequestLog())
|
reqLogs = append(reqLogs, dto.toRequestLog())
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err)
|
return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err)
|
||||||
}
|
}
|
||||||
rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
|
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not query headers: %v", err)
|
return nil, fmt.Errorf("sqlite: could not query headers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return reqLogs, nil
|
return reqLogs, nil
|
||||||
|
@ -300,35 +309,38 @@ func (c *Client) FindRequestLogByID(ctx context.Context, id int64) (reqlog.Reque
|
||||||
if c.db == nil {
|
if c.db == nil {
|
||||||
return reqlog.Request{}, proj.ErrNoProject
|
return reqlog.Request{}, proj.ErrNoProject
|
||||||
}
|
}
|
||||||
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
|
|
||||||
|
|
||||||
|
httpReqLogsQuery := parseHTTPRequestLogsQuery(ctx)
|
||||||
reqQuery := sq.
|
reqQuery := sq.
|
||||||
Select(httpReqLogsQuery.requestCols...).
|
Select(httpReqLogsQuery.requestCols...).
|
||||||
From("http_requests req").
|
From("http_requests req").
|
||||||
Where("req.id = ?")
|
Where("req.id = ?")
|
||||||
|
|
||||||
if httpReqLogsQuery.joinResponse {
|
if httpReqLogsQuery.joinResponse {
|
||||||
reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id")
|
reqQuery = reqQuery.LeftJoin("http_responses res ON req.id = res.req_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
reqSQL, _, err := reqQuery.ToSql()
|
reqSQL, _, err := reqQuery.ToSql()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %v", err)
|
return reqlog.Request{}, fmt.Errorf("sqlite: could not parse query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
row := c.db.QueryRowxContext(ctx, reqSQL, id)
|
row := c.db.QueryRowxContext(ctx, reqSQL, id)
|
||||||
var dto httpRequest
|
|
||||||
err = row.StructScan(&dto)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return reqlog.Request{}, reqlog.ErrRequestNotFound
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %v", err)
|
|
||||||
}
|
|
||||||
reqLog := dto.toRequestLog()
|
|
||||||
|
|
||||||
|
var dto httpRequest
|
||||||
|
|
||||||
|
err = row.StructScan(&dto)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return reqlog.Request{}, reqlog.ErrRequestNotFound
|
||||||
|
} else if err != nil {
|
||||||
|
return reqlog.Request{}, fmt.Errorf("sqlite: could not scan row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reqLog := dto.toRequestLog()
|
||||||
reqLogs := []reqlog.Request{reqLog}
|
reqLogs := []reqlog.Request{reqLog}
|
||||||
|
|
||||||
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
|
if err := c.queryHeaders(ctx, httpReqLogsQuery, reqLogs); err != nil {
|
||||||
return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %v", err)
|
return reqlog.Request{}, fmt.Errorf("sqlite: could not query headers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return reqLogs[0], nil
|
return reqLogs[0], nil
|
||||||
|
@ -352,8 +364,9 @@ func (c *Client) AddRequestLog(
|
||||||
|
|
||||||
tx, err := c.db.BeginTxx(ctx, nil)
|
tx, err := c.db.BeginTxx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not start transaction: %v", err)
|
return nil, fmt.Errorf("sqlite: could not start transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
reqStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_requests (
|
reqStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_requests (
|
||||||
|
@ -364,7 +377,7 @@ func (c *Client) AddRequestLog(
|
||||||
timestamp
|
timestamp
|
||||||
) VALUES (?, ?, ?, ?, ?)`)
|
) VALUES (?, ?, ?, ?, ?)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
|
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
|
||||||
}
|
}
|
||||||
defer reqStmt.Close()
|
defer reqStmt.Close()
|
||||||
|
|
||||||
|
@ -376,13 +389,14 @@ func (c *Client) AddRequestLog(
|
||||||
reqLog.Timestamp,
|
reqLog.Timestamp,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not execute statement: %v", err)
|
return nil, fmt.Errorf("sqlite: could not execute statement: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqID, err := result.LastInsertId()
|
reqID, err := result.LastInsertId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err)
|
return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqLog.ID = reqID
|
reqLog.ID = reqID
|
||||||
|
|
||||||
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
|
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
|
||||||
|
@ -391,17 +405,17 @@ func (c *Client) AddRequestLog(
|
||||||
value
|
value
|
||||||
) VALUES (?, ?, ?)`)
|
) VALUES (?, ?, ?)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
|
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
|
||||||
}
|
}
|
||||||
defer headerStmt.Close()
|
defer headerStmt.Close()
|
||||||
|
|
||||||
err = insertHeaders(ctx, headerStmt, reqID, reqLog.Request.Header)
|
err = insertHeaders(ctx, headerStmt, reqID, reqLog.Request.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err)
|
return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err)
|
return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return reqLog, nil
|
return reqLog, nil
|
||||||
|
@ -424,9 +438,10 @@ func (c *Client) AddResponseLog(
|
||||||
Body: body,
|
Body: body,
|
||||||
Timestamp: timestamp,
|
Timestamp: timestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := c.db.BeginTx(ctx, nil)
|
tx, err := c.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not start transaction: %v", err)
|
return nil, fmt.Errorf("sqlite: could not start transaction: %w", err)
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
@ -439,7 +454,7 @@ func (c *Client) AddResponseLog(
|
||||||
timestamp
|
timestamp
|
||||||
) VALUES (?, ?, ?, ?, ?, ?)`)
|
) VALUES (?, ?, ?, ?, ?, ?)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
|
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
|
||||||
}
|
}
|
||||||
defer resStmt.Close()
|
defer resStmt.Close()
|
||||||
|
|
||||||
|
@ -457,13 +472,14 @@ func (c *Client) AddResponseLog(
|
||||||
resLog.Timestamp,
|
resLog.Timestamp,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not execute statement: %v", err)
|
return nil, fmt.Errorf("sqlite: could not execute statement: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resID, err := result.LastInsertId()
|
resID, err := result.LastInsertId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not get last insert ID: %v", err)
|
return nil, fmt.Errorf("sqlite: could not get last insert ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resLog.ID = resID
|
resLog.ID = resID
|
||||||
|
|
||||||
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
|
headerStmt, err := tx.PrepareContext(ctx, `INSERT INTO http_headers (
|
||||||
|
@ -472,17 +488,17 @@ func (c *Client) AddResponseLog(
|
||||||
value
|
value
|
||||||
) VALUES (?, ?, ?)`)
|
) VALUES (?, ?, ?)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not prepare statement: %v", err)
|
return nil, fmt.Errorf("sqlite: could not prepare statement: %w", err)
|
||||||
}
|
}
|
||||||
defer headerStmt.Close()
|
defer headerStmt.Close()
|
||||||
|
|
||||||
err = insertHeaders(ctx, headerStmt, resID, resLog.Response.Header)
|
err = insertHeaders(ctx, headerStmt, resID, resLog.Response.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not insert http headers: %v", err)
|
return nil, fmt.Errorf("sqlite: could not insert http headers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not commit transaction: %v", err)
|
return nil, fmt.Errorf("sqlite: could not commit transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return resLog, nil
|
return resLog, nil
|
||||||
|
@ -496,14 +512,14 @@ func (c *Client) UpsertSettings(ctx context.Context, module string, settings int
|
||||||
|
|
||||||
jsonSettings, err := json.Marshal(settings)
|
jsonSettings, err := json.Marshal(settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sqlite: could not encode settings as JSON: %v", err)
|
return fmt.Errorf("sqlite: could not encode settings as JSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = c.db.ExecContext(ctx,
|
_, err = c.db.ExecContext(ctx,
|
||||||
`INSERT INTO settings (module, settings) VALUES (?, ?)
|
`INSERT INTO settings (module, settings) VALUES (?, ?)
|
||||||
ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings)
|
ON CONFLICT(module) DO UPDATE SET settings = ?`, module, jsonSettings, jsonSettings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sqlite: could not insert scope settings: %v", err)
|
return fmt.Errorf("sqlite: could not insert scope settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -515,17 +531,18 @@ func (c *Client) FindSettingsByModule(ctx context.Context, module string, settin
|
||||||
}
|
}
|
||||||
|
|
||||||
var jsonSettings []byte
|
var jsonSettings []byte
|
||||||
|
|
||||||
row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module)
|
row := c.db.QueryRowContext(ctx, `SELECT settings FROM settings WHERE module = ?`, module)
|
||||||
|
|
||||||
err := row.Scan(&jsonSettings)
|
err := row.Scan(&jsonSettings)
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return proj.ErrNoSettings
|
return proj.ErrNoSettings
|
||||||
}
|
} else if err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("sqlite: could not scan row: %w", err)
|
||||||
return fmt.Errorf("sqlite: could not scan row: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonSettings, &settings); err != nil {
|
if err := json.Unmarshal(jsonSettings, &settings); err != nil {
|
||||||
return fmt.Errorf("sqlite: could not decode settings from JSON: %v", err)
|
return fmt.Errorf("sqlite: could not decode settings from JSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -535,42 +552,46 @@ func insertHeaders(ctx context.Context, stmt *sql.Stmt, id int64, headers http.H
|
||||||
for key, values := range headers {
|
for key, values := range headers {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if _, err := stmt.ExecContext(ctx, id, key, value); err != nil {
|
if _, err := stmt.ExecContext(ctx, id, key, value); err != nil {
|
||||||
return fmt.Errorf("could not execute statement: %v", err)
|
return fmt.Errorf("could not execute statement: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func findHeaders(ctx context.Context, stmt *sql.Stmt, id int64) (http.Header, error) {
|
func findHeaders(ctx context.Context, stmt *sql.Stmt, id int64) (http.Header, error) {
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
|
|
||||||
rows, err := stmt.QueryContext(ctx, id)
|
rows, err := stmt.QueryContext(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not execute query: %v", err)
|
return nil, fmt.Errorf("sqlite: could not execute query: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var key, value string
|
var key, value string
|
||||||
err := rows.Scan(
|
|
||||||
&key,
|
err := rows.Scan(&key, &value)
|
||||||
&value,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not scan row: %v", err)
|
return nil, fmt.Errorf("sqlite: could not scan row: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
headers.Add(key, value)
|
headers.Add(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, fmt.Errorf("sqlite: could not iterate over rows: %v", err)
|
return nil, fmt.Errorf("sqlite: could not iterate over rows: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return headers, nil
|
return headers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
|
func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
|
||||||
var joinResponse bool
|
var (
|
||||||
var reqHeaderCols, resHeaderCols []string
|
joinResponse bool
|
||||||
|
reqHeaderCols, resHeaderCols []string
|
||||||
|
)
|
||||||
|
|
||||||
opCtx := graphql.GetOperationContext(ctx)
|
opCtx := graphql.GetOperationContext(ctx)
|
||||||
reqFields := graphql.CollectFieldsCtx(ctx, nil)
|
reqFields := graphql.CollectFieldsCtx(ctx, nil)
|
||||||
|
@ -580,6 +601,7 @@ func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
|
||||||
if col, ok := reqFieldToColumnMap[reqField.Name]; ok {
|
if col, ok := reqFieldToColumnMap[reqField.Name]; ok {
|
||||||
reqCols = append(reqCols, "req."+col)
|
reqCols = append(reqCols, "req."+col)
|
||||||
}
|
}
|
||||||
|
|
||||||
if reqField.Name == "headers" {
|
if reqField.Name == "headers" {
|
||||||
headerFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
|
headerFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
|
||||||
for _, headerField := range headerFields {
|
for _, headerField := range headerFields {
|
||||||
|
@ -588,19 +610,23 @@ func parseHTTPRequestLogsQuery(ctx context.Context) httpRequestLogsQuery {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if reqField.Name == "response" {
|
if reqField.Name == "response" {
|
||||||
joinResponse = true
|
joinResponse = true
|
||||||
resFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
|
resFields := graphql.CollectFields(opCtx, reqField.Selections, nil)
|
||||||
|
|
||||||
for _, resField := range resFields {
|
for _, resField := range resFields {
|
||||||
if resField.Name == "headers" {
|
if resField.Name == "headers" {
|
||||||
reqCols = append(reqCols, "res.id AS res_id")
|
reqCols = append(reqCols, "res.id AS res_id")
|
||||||
headerFields := graphql.CollectFields(opCtx, resField.Selections, nil)
|
headerFields := graphql.CollectFields(opCtx, resField.Selections, nil)
|
||||||
|
|
||||||
for _, headerField := range headerFields {
|
for _, headerField := range headerFields {
|
||||||
if col, ok := headerFieldToColumnMap[headerField.Name]; ok {
|
if col, ok := headerFieldToColumnMap[headerField.Name]; ok {
|
||||||
resHeaderCols = append(resHeaderCols, col)
|
resHeaderCols = append(resHeaderCols, col)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if col, ok := resFieldToColumnMap[resField.Name]; ok {
|
if col, ok := resFieldToColumnMap[resField.Name]; ok {
|
||||||
reqCols = append(reqCols, "res."+col)
|
reqCols = append(reqCols, "res."+col)
|
||||||
}
|
}
|
||||||
|
@ -627,18 +653,21 @@ func (c *Client) queryHeaders(
|
||||||
From("http_headers").Where("req_id = ?").
|
From("http_headers").Where("req_id = ?").
|
||||||
ToSql()
|
ToSql()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not parse request headers query: %v", err)
|
return fmt.Errorf("could not parse request headers query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqHeadersStmt, err := c.db.PrepareContext(ctx, reqHeadersQuery)
|
reqHeadersStmt, err := c.db.PrepareContext(ctx, reqHeadersQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not prepare statement: %v", err)
|
return fmt.Errorf("could not prepare statement: %w", err)
|
||||||
}
|
}
|
||||||
defer reqHeadersStmt.Close()
|
defer reqHeadersStmt.Close()
|
||||||
|
|
||||||
for i := range reqLogs {
|
for i := range reqLogs {
|
||||||
headers, err := findHeaders(ctx, reqHeadersStmt, reqLogs[i].ID)
|
headers, err := findHeaders(ctx, reqHeadersStmt, reqLogs[i].ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not query request headers: %v", err)
|
return fmt.Errorf("could not query request headers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqLogs[i].Request.Header = headers
|
reqLogs[i].Request.Header = headers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -649,21 +678,25 @@ func (c *Client) queryHeaders(
|
||||||
From("http_headers").Where("res_id = ?").
|
From("http_headers").Where("res_id = ?").
|
||||||
ToSql()
|
ToSql()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not parse response headers query: %v", err)
|
return fmt.Errorf("could not parse response headers query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resHeadersStmt, err := c.db.PrepareContext(ctx, resHeadersQuery)
|
resHeadersStmt, err := c.db.PrepareContext(ctx, resHeadersQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not prepare statement: %v", err)
|
return fmt.Errorf("could not prepare statement: %w", err)
|
||||||
}
|
}
|
||||||
defer resHeadersStmt.Close()
|
defer resHeadersStmt.Close()
|
||||||
|
|
||||||
for i := range reqLogs {
|
for i := range reqLogs {
|
||||||
if reqLogs[i].Response == nil {
|
if reqLogs[i].Response == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
headers, err := findHeaders(ctx, resHeadersStmt, reqLogs[i].Response.ID)
|
headers, err := findHeaders(ctx, resHeadersStmt, reqLogs[i].Response.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not query response headers: %v", err)
|
return fmt.Errorf("could not query response headers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqLogs[i].Response.Response.Header = headers
|
reqLogs[i].Response.Response.Header = headers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,8 +9,10 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OnProjectOpenFn func(name string) error
|
type (
|
||||||
type OnProjectCloseFn func(name string) error
|
OnProjectOpenFn func(name string) error
|
||||||
|
OnProjectCloseFn func(name string) error
|
||||||
|
)
|
||||||
|
|
||||||
// Service is used for managing projects.
|
// Service is used for managing projects.
|
||||||
type Service struct {
|
type Service struct {
|
||||||
|
@ -47,8 +49,9 @@ func (svc *Service) Close() error {
|
||||||
defer svc.mu.Unlock()
|
defer svc.mu.Unlock()
|
||||||
|
|
||||||
closedProject := svc.activeProject
|
closedProject := svc.activeProject
|
||||||
|
|
||||||
if err := svc.repo.Close(); err != nil {
|
if err := svc.repo.Close(); err != nil {
|
||||||
return fmt.Errorf("proj: could not close project: %v", err)
|
return fmt.Errorf("proj: could not close project: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.activeProject = ""
|
svc.activeProject = ""
|
||||||
|
@ -63,12 +66,13 @@ func (svc *Service) Delete(name string) error {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return errors.New("proj: name cannot be empty")
|
return errors.New("proj: name cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
if svc.activeProject == name {
|
if svc.activeProject == name {
|
||||||
return fmt.Errorf("proj: project (%v) is active", name)
|
return fmt.Errorf("proj: project (%v) is active", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := svc.repo.DeleteProject(name); err != nil {
|
if err := svc.repo.DeleteProject(name); err != nil {
|
||||||
return fmt.Errorf("proj: could not delete project: %v", err)
|
return fmt.Errorf("proj: could not delete project: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -85,11 +89,11 @@ func (svc *Service) Open(ctx context.Context, name string) (Project, error) {
|
||||||
defer svc.mu.Unlock()
|
defer svc.mu.Unlock()
|
||||||
|
|
||||||
if err := svc.repo.Close(); err != nil {
|
if err := svc.repo.Close(); err != nil {
|
||||||
return Project{}, fmt.Errorf("proj: could not close previously open database: %v", err)
|
return Project{}, fmt.Errorf("proj: could not close previously open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := svc.repo.OpenProject(name); err != nil {
|
if err := svc.repo.OpenProject(name); err != nil {
|
||||||
return Project{}, fmt.Errorf("proj: could not open database: %v", err)
|
return Project{}, fmt.Errorf("proj: could not open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.activeProject = name
|
svc.activeProject = name
|
||||||
|
@ -115,7 +119,7 @@ func (svc *Service) ActiveProject() (Project, error) {
|
||||||
func (svc *Service) Projects() ([]Project, error) {
|
func (svc *Service) Projects() ([]Project, error) {
|
||||||
projects, err := svc.repo.Projects()
|
projects, err := svc.repo.Projects()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("proj: could not get projects: %v", err)
|
return nil, fmt.Errorf("proj: could not get projects: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return projects, nil
|
return projects, nil
|
||||||
|
|
|
@ -25,7 +25,7 @@ import (
|
||||||
var MaxSerialNumber = big.NewInt(0).SetBytes(bytes.Repeat([]byte{255}, 20))
|
var MaxSerialNumber = big.NewInt(0).SetBytes(bytes.Repeat([]byte{255}, 20))
|
||||||
|
|
||||||
// CertConfig is a set of configuration values that are used to build TLS configs
|
// CertConfig is a set of configuration values that are used to build TLS configs
|
||||||
// capable of MITM
|
// capable of MITM.
|
||||||
type CertConfig struct {
|
type CertConfig struct {
|
||||||
ca *x509.Certificate
|
ca *x509.Certificate
|
||||||
caPriv crypto.PrivateKey
|
caPriv crypto.PrivateKey
|
||||||
|
@ -40,6 +40,7 @@ func NewCertConfig(ca *x509.Certificate, caPrivKey crypto.PrivateKey) (*CertConf
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pub := priv.Public()
|
pub := priv.Public()
|
||||||
|
|
||||||
// Subject Key Identifier support for end entity certificate.
|
// Subject Key Identifier support for end entity certificate.
|
||||||
|
@ -48,6 +49,7 @@ func NewCertConfig(ca *x509.Certificate, caPrivKey crypto.PrivateKey) (*CertConf
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
h := sha1.New()
|
h := sha1.New()
|
||||||
h.Write(pkixPubKey)
|
h.Write(pkixPubKey)
|
||||||
keyID := h.Sum(nil)
|
keyID := h.Sum(nil)
|
||||||
|
@ -67,58 +69,69 @@ func LoadOrCreateCA(caKeyFile, caCertFile string) (*x509.Certificate, *rsa.Priva
|
||||||
if err == nil {
|
if err == nil {
|
||||||
caCert, err := x509.ParseCertificate(tlsCA.Certificate[0])
|
caCert, err := x509.ParseCertificate(tlsCA.Certificate[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("proxy: could not parse CA: %v", err)
|
return nil, nil, fmt.Errorf("proxy: could not parse CA: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
caKey, ok := tlsCA.PrivateKey.(*rsa.PrivateKey)
|
caKey, ok := tlsCA.PrivateKey.(*rsa.PrivateKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errors.New("proxy: private key is not RSA")
|
return nil, nil, errors.New("proxy: private key is not RSA")
|
||||||
}
|
}
|
||||||
|
|
||||||
return caCert, caKey, nil
|
return caCert, caKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
return nil, nil, fmt.Errorf("proxy: could not load CA key pair: %v", err)
|
return nil, nil, fmt.Errorf("proxy: could not load CA key pair: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create directories for files if they don't exist yet.
|
// Create directories for files if they don't exist yet.
|
||||||
keyDir, _ := filepath.Split(caKeyFile)
|
keyDir, _ := filepath.Split(caKeyFile)
|
||||||
if keyDir != "" {
|
if keyDir != "" {
|
||||||
if _, err := os.Stat(keyDir); os.IsNotExist(err) {
|
if _, err := os.Stat(keyDir); os.IsNotExist(err) {
|
||||||
os.MkdirAll(keyDir, 0755)
|
if err := os.MkdirAll(keyDir, 0755); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("proxy: could not create directory for CA key: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
keyDir, _ = filepath.Split(caCertFile)
|
keyDir, _ = filepath.Split(caCertFile)
|
||||||
if keyDir != "" {
|
if keyDir != "" {
|
||||||
if _, err := os.Stat("keyDir"); os.IsNotExist(err) {
|
if _, err := os.Stat("keyDir"); os.IsNotExist(err) {
|
||||||
os.MkdirAll(keyDir, 0755)
|
if err := os.MkdirAll(keyDir, 0755); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("proxy: could not create directory for CA cert: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new CA keypair.
|
// Create new CA keypair.
|
||||||
caCert, caKey, err := NewCA("Hetty", "Hetty CA", time.Duration(365*24*time.Hour))
|
caCert, caKey, err := NewCA("Hetty", "Hetty CA", 365*24*time.Hour)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("proxy: could not generate new CA keypair: %v", err)
|
return nil, nil, fmt.Errorf("proxy: could not generate new CA keypair: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open CA certificate and key files for writing.
|
// Open CA certificate and key files for writing.
|
||||||
certOut, err := os.Create(caCertFile)
|
certOut, err := os.Create(caCertFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("proxy: could not open cert file for writing: %v", err)
|
return nil, nil, fmt.Errorf("proxy: could not open cert file for writing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
keyOut, err := os.OpenFile(caKeyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
keyOut, err := os.OpenFile(caKeyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("proxy: could not open key file for writing: %v", err)
|
return nil, nil, fmt.Errorf("proxy: could not open key file for writing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write PEM blocks to CA certificate and key files.
|
// Write PEM blocks to CA certificate and key files.
|
||||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw}); err != nil {
|
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw}); err != nil {
|
||||||
return nil, nil, fmt.Errorf("proxy: could not write CA certificate to disk: %v", err)
|
return nil, nil, fmt.Errorf("proxy: could not write CA certificate to disk: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
privBytes, err := x509.MarshalPKCS8PrivateKey(caKey)
|
privBytes, err := x509.MarshalPKCS8PrivateKey(caKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("proxy: could not convert private key to DER format: %v", err)
|
return nil, nil, fmt.Errorf("proxy: could not convert private key to DER format: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
|
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
|
||||||
return nil, nil, fmt.Errorf("proxy: could not write CA key to disk: %v", err)
|
return nil, nil, fmt.Errorf("proxy: could not write CA key to disk: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return caCert, caKey, nil
|
return caCert, caKey, nil
|
||||||
|
@ -130,6 +143,7 @@ func NewCA(name, organization string, validity time.Duration) (*x509.Certificate
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pub := priv.Public()
|
pub := priv.Public()
|
||||||
|
|
||||||
// Subject Key Identifier support for end entity certificate.
|
// Subject Key Identifier support for end entity certificate.
|
||||||
|
@ -138,6 +152,7 @@ func NewCA(name, organization string, validity time.Duration) (*x509.Certificate
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
h := sha1.New()
|
h := sha1.New()
|
||||||
h.Write(pkixpub)
|
h.Write(pkixpub)
|
||||||
keyID := h.Sum(nil)
|
keyID := h.Sum(nil)
|
||||||
|
@ -187,8 +202,10 @@ func (c *CertConfig) TLSConfig() *tls.Config {
|
||||||
if clientHello.ServerName == "" {
|
if clientHello.ServerName == "" {
|
||||||
return nil, errors.New("missing server name (SNI)")
|
return nil, errors.New("missing server name (SNI)")
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.cert(clientHello.ServerName)
|
return c.cert(clientHello.ServerName)
|
||||||
},
|
},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
NextProtos: []string{"http/1.1"},
|
NextProtos: []string{"http/1.1"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,13 +5,12 @@ import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
|
||||||
"github.com/dstotijn/hetty/pkg/scope"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey int
|
type contextKey int
|
||||||
|
@ -27,8 +26,6 @@ type Proxy struct {
|
||||||
// TODO: Add mutex for modifier funcs.
|
// TODO: Add mutex for modifier funcs.
|
||||||
reqModifiers []RequestModifyMiddleware
|
reqModifiers []RequestModifyMiddleware
|
||||||
resModifiers []ResponseModifyMiddleware
|
resModifiers []ResponseModifyMiddleware
|
||||||
|
|
||||||
scope *scope.Scope
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProxy returns a new Proxy.
|
// NewProxy returns a new Proxy.
|
||||||
|
@ -55,7 +52,7 @@ func NewProxy(ca *x509.Certificate, key crypto.PrivateKey) (*Proxy, error) {
|
||||||
|
|
||||||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == http.MethodConnect {
|
if r.Method == http.MethodConnect {
|
||||||
p.handleConnect(w, r)
|
p.handleConnect(w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,11 +100,12 @@ func (p *Proxy) modifyResponse(res *http.Response) error {
|
||||||
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
|
// handleConnect hijacks the incoming HTTP request and sets up an HTTP tunnel.
|
||||||
// During the TLS handshake with the client, we use the proxy's CA config to
|
// During the TLS handshake with the client, we use the proxy's CA config to
|
||||||
// create a certificate on-the-fly.
|
// create a certificate on-the-fly.
|
||||||
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) handleConnect(w http.ResponseWriter) {
|
||||||
hj, ok := w.(http.Hijacker)
|
hj, ok := w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("[ERROR] handleConnect: ResponseWriter is not a http.Hijacker (type: %T)", w)
|
log.Printf("[ERROR] handleConnect: ResponseWriter is not a http.Hijacker (type: %T)", w)
|
||||||
writeError(w, r, http.StatusServiceUnavailable)
|
writeError(w, http.StatusServiceUnavailable)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +114,8 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||||
clientConn, _, err := hj.Hijack()
|
clientConn, _, err := hj.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERROR] Hijacking client connection failed: %v", err)
|
log.Printf("[ERROR] Hijacking client connection failed: %v", err)
|
||||||
writeError(w, r, http.StatusServiceUnavailable)
|
writeError(w, http.StatusServiceUnavailable)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
@ -127,14 +126,15 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Printf("[ERROR] Securing client connection failed: %v", err)
|
log.Printf("[ERROR] Securing client connection failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
clientConnNotify := ConnNotify{clientConn, make(chan struct{})}
|
|
||||||
|
|
||||||
|
clientConnNotify := ConnNotify{clientConn, make(chan struct{})}
|
||||||
l := &OnceAcceptListener{clientConnNotify.Conn}
|
l := &OnceAcceptListener{clientConnNotify.Conn}
|
||||||
|
|
||||||
err = http.Serve(l, p)
|
err = http.Serve(l, p)
|
||||||
if err != nil && err != ErrAlreadyAccepted {
|
if err != nil && !errors.Is(err, ErrAlreadyAccepted) {
|
||||||
log.Printf("[ERROR] Serving HTTP request failed: %v", err)
|
log.Printf("[ERROR] Serving HTTP request failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
<-clientConnNotify.closed
|
<-clientConnNotify.closed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,20 +144,22 @@ func (p *Proxy) clientTLSConn(conn net.Conn) (*tls.Conn, error) {
|
||||||
tlsConn := tls.Server(conn, tlsConfig)
|
tlsConn := tls.Server(conn, tlsConfig)
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
tlsConn.Close()
|
tlsConn.Close()
|
||||||
return nil, fmt.Errorf("handshake error: %v", err)
|
return nil, fmt.Errorf("handshake error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tlsConn, nil
|
return tlsConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func errorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
func errorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
if err == context.Canceled {
|
if errors.Is(err, context.Canceled) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[ERROR]: Proxy error: %v", err)
|
log.Printf("[ERROR]: Proxy error: %v", err)
|
||||||
|
|
||||||
w.WriteHeader(http.StatusBadGateway)
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeError(w http.ResponseWriter, r *http.Request, code int) {
|
func writeError(w http.ResponseWriter, code int) {
|
||||||
http.Error(w, http.StatusText(code), code)
|
http.Error(w, http.StatusText(code), code)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ type Repository interface {
|
||||||
FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error)
|
FindRequestLogs(ctx context.Context, filter FindRequestsFilter, scope *scope.Scope) ([]Request, error)
|
||||||
FindRequestLogByID(ctx context.Context, id int64) (Request, error)
|
FindRequestLogByID(ctx context.Context, id int64) (Request, error)
|
||||||
AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error)
|
AddRequestLog(ctx context.Context, req http.Request, body []byte, timestamp time.Time) (*Request, error)
|
||||||
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error)
|
AddResponseLog(ctx context.Context, reqID int64, res http.Response, body []byte, timestamp time.Time) (*Response, error) // nolint:lll
|
||||||
ClearRequestLogs(ctx context.Context) error
|
ClearRequestLogs(ctx context.Context) error
|
||||||
UpsertSettings(ctx context.Context, module string, settings interface{}) error
|
UpsertSettings(ctx context.Context, module string, settings interface{}) error
|
||||||
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
|
FindSettingsByModule(ctx context.Context, module string, settings interface{}) error
|
||||||
|
|
|
@ -24,9 +24,7 @@ const LogBypassedKey contextKey = 0
|
||||||
|
|
||||||
const moduleName = "reqlog"
|
const moduleName = "reqlog"
|
||||||
|
|
||||||
var (
|
var ErrRequestNotFound = errors.New("reqlog: request not found")
|
||||||
ErrRequestNotFound = errors.New("reqlog: request not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
ID int64
|
ID int64
|
||||||
|
@ -74,12 +72,13 @@ func NewService(cfg Config) *Service {
|
||||||
|
|
||||||
cfg.ProjectService.OnProjectOpen(func(_ string) error {
|
cfg.ProjectService.OnProjectOpen(func(_ string) error {
|
||||||
err := svc.loadSettings()
|
err := svc.loadSettings()
|
||||||
if err == proj.ErrNoSettings {
|
if errors.Is(err, proj.ErrNoSettings) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reqlog: could not load settings: %v", err)
|
return fmt.Errorf("reqlog: could not load settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
cfg.ProjectService.OnProjectClose(func(_ string) error {
|
cfg.ProjectService.OnProjectClose(func(_ string) error {
|
||||||
|
@ -126,12 +125,13 @@ func (svc *Service) addResponse(
|
||||||
if res.Header.Get("Content-Encoding") == "gzip" {
|
if res.Header.Get("Content-Encoding") == "gzip" {
|
||||||
gzipReader, err := gzip.NewReader(bytes.NewBuffer(body))
|
gzipReader, err := gzip.NewReader(bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("reqlog: could not create gzip reader: %v", err)
|
return nil, fmt.Errorf("reqlog: could not create gzip reader: %w", err)
|
||||||
}
|
}
|
||||||
defer gzipReader.Close()
|
defer gzipReader.Close()
|
||||||
|
|
||||||
body, err = ioutil.ReadAll(gzipReader)
|
body, err = ioutil.ReadAll(gzipReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %v", err)
|
return nil, fmt.Errorf("reqlog: could not read gzipped response body: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,18 +141,23 @@ func (svc *Service) addResponse(
|
||||||
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
|
func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestModifyFunc {
|
||||||
return func(req *http.Request) {
|
return func(req *http.Request) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
next(req)
|
next(req)
|
||||||
|
|
||||||
clone := req.Clone(req.Context())
|
clone := req.Clone(req.Context())
|
||||||
|
|
||||||
var body []byte
|
var body []byte
|
||||||
|
|
||||||
if req.Body != nil {
|
if req.Body != nil {
|
||||||
// TODO: Use io.LimitReader.
|
// TODO: Use io.LimitReader.
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
body, err = ioutil.ReadAll(req.Body)
|
body, err = ioutil.ReadAll(req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERROR] Could not read request body for logging: %v", err)
|
log.Printf("[ERROR] Could not read request body for logging: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,19 +166,21 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
||||||
if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
|
if svc.BypassOutOfScopeRequests && !svc.scope.Match(clone, body) {
|
||||||
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
|
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
|
||||||
*req = *req.WithContext(ctx)
|
*req = *req.WithContext(ctx)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
reqLog, err := svc.addRequest(req.Context(), *clone, body, now)
|
reqLog, err := svc.addRequest(req.Context(), *clone, body, now)
|
||||||
if err == proj.ErrNoProject {
|
if errors.Is(err, proj.ErrNoProject) {
|
||||||
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
|
ctx := context.WithValue(req.Context(), LogBypassedKey, true)
|
||||||
*req = *req.WithContext(ctx)
|
*req = *req.WithContext(ctx)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
} else if err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Printf("[ERROR] Could not store request log: %v", err)
|
log.Printf("[ERROR] Could not store request log: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(req.Context(), proxy.ReqIDKey, reqLog.ID)
|
ctx := context.WithValue(req.Context(), proxy.ReqIDKey, reqLog.ID)
|
||||||
*req = *req.WithContext(ctx)
|
*req = *req.WithContext(ctx)
|
||||||
}
|
}
|
||||||
|
@ -182,6 +189,7 @@ func (svc *Service) RequestModifier(next proxy.RequestModifyFunc) proxy.RequestM
|
||||||
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
|
func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.ResponseModifyFunc {
|
||||||
return func(res *http.Response) error {
|
return func(res *http.Response) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
if err := next(res); err != nil {
|
if err := next(res); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -200,8 +208,9 @@ func (svc *Service) ResponseModifier(next proxy.ResponseModifyFunc) proxy.Respon
|
||||||
// TODO: Use io.LimitReader.
|
// TODO: Use io.LimitReader.
|
||||||
body, err := ioutil.ReadAll(res.Body)
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reqlog: could not read response body: %v", err)
|
return fmt.Errorf("reqlog: could not read response body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
res.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -220,6 +229,7 @@ func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error {
|
||||||
OnlyInScope bool
|
OnlyInScope bool
|
||||||
RawSearchExpr string
|
RawSearchExpr string
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(b, &dto); err != nil {
|
if err := json.Unmarshal(b, &dto); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -234,6 +244,7 @@ func (f *FindRequestsFilter) UnmarshalJSON(b []byte) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
filter.SearchExpr = expr
|
filter.SearchExpr = expr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package scope
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -38,12 +39,13 @@ func New(repo Repository, projService *proj.Service) *Scope {
|
||||||
|
|
||||||
projService.OnProjectOpen(func(_ string) error {
|
projService.OnProjectOpen(func(_ string) error {
|
||||||
err := s.load(context.Background())
|
err := s.load(context.Background())
|
||||||
if err == proj.ErrNoSettings {
|
if errors.Is(err, proj.ErrNoSettings) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("scope: could not load scope: %v", err)
|
return fmt.Errorf("scope: could not load scope: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
projService.OnProjectClose(func(_ string) error {
|
projService.OnProjectClose(func(_ string) error {
|
||||||
|
@ -57,6 +59,7 @@ func New(repo Repository, projService *proj.Service) *Scope {
|
||||||
func (s *Scope) Rules() []Rule {
|
func (s *Scope) Rules() []Rule {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
return s.rules
|
return s.rules
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,12 +68,12 @@ func (s *Scope) load(ctx context.Context) error {
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var rules []Rule
|
var rules []Rule
|
||||||
|
|
||||||
err := s.repo.FindSettingsByModule(ctx, moduleName, &rules)
|
err := s.repo.FindSettingsByModule(ctx, moduleName, &rules)
|
||||||
if err == proj.ErrNoSettings {
|
if errors.Is(err, proj.ErrNoSettings) {
|
||||||
return err
|
return err
|
||||||
}
|
} else if err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("scope: could not load scope settings: %w", err)
|
||||||
return fmt.Errorf("scope: could not load scope settings: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.rules = rules
|
s.rules = rules
|
||||||
|
@ -89,7 +92,7 @@ func (s *Scope) SetRules(ctx context.Context, rules []Rule) error {
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
if err := s.repo.UpsertSettings(ctx, moduleName, rules); err != nil {
|
if err := s.repo.UpsertSettings(ctx, moduleName, rules); err != nil {
|
||||||
return fmt.Errorf("scope: cannot set rules in repository: %v", err)
|
return fmt.Errorf("scope: cannot set rules in repository: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.rules = rules
|
s.rules = rules
|
||||||
|
@ -100,6 +103,7 @@ func (s *Scope) SetRules(ctx context.Context, rules []Rule) error {
|
||||||
func (s *Scope) Match(req *http.Request, body []byte) bool {
|
func (s *Scope) Match(req *http.Request, body []byte) bool {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
for _, rule := range s.rules {
|
for _, rule := range s.rules {
|
||||||
if matches := rule.Match(req, body); matches {
|
if matches := rule.Match(req, body); matches {
|
||||||
return true
|
return true
|
||||||
|
@ -118,11 +122,13 @@ func (r Rule) Match(req *http.Request, body []byte) bool {
|
||||||
|
|
||||||
for key, values := range req.Header {
|
for key, values := range req.Header {
|
||||||
var keyMatches, valueMatches bool
|
var keyMatches, valueMatches bool
|
||||||
|
|
||||||
if r.Header.Key != nil {
|
if r.Header.Key != nil {
|
||||||
if matches := r.Header.Key.MatchString(key); matches {
|
if matches := r.Header.Key.MatchString(key); matches {
|
||||||
keyMatches = true
|
keyMatches = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Header.Value != nil {
|
if r.Header.Value != nil {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if matches := r.Header.Value.MatchString(value); matches {
|
if matches := r.Header.Value.MatchString(value); matches {
|
||||||
|
@ -154,15 +160,17 @@ func (r Rule) Match(req *http.Request, body []byte) bool {
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaler.
|
// MarshalJSON implements json.Marshaler.
|
||||||
func (r Rule) MarshalJSON() ([]byte, error) {
|
func (r Rule) MarshalJSON() ([]byte, error) {
|
||||||
type headerDTO struct {
|
type (
|
||||||
|
headerDTO struct {
|
||||||
Key string
|
Key string
|
||||||
Value string
|
Value string
|
||||||
}
|
}
|
||||||
type ruleDTO struct {
|
ruleDTO struct {
|
||||||
URL string
|
URL string
|
||||||
Header headerDTO
|
Header headerDTO
|
||||||
Body string
|
Body string
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
dto := ruleDTO{
|
dto := ruleDTO{
|
||||||
URL: regexpToString(r.URL),
|
URL: regexpToString(r.URL),
|
||||||
|
@ -178,15 +186,17 @@ func (r Rule) MarshalJSON() ([]byte, error) {
|
||||||
|
|
||||||
// UnmarshalJSON implements json.Unmarshaler.
|
// UnmarshalJSON implements json.Unmarshaler.
|
||||||
func (r *Rule) UnmarshalJSON(data []byte) error {
|
func (r *Rule) UnmarshalJSON(data []byte) error {
|
||||||
type headerDTO struct {
|
type (
|
||||||
|
headerDTO struct {
|
||||||
Key string
|
Key string
|
||||||
Value string
|
Value string
|
||||||
}
|
}
|
||||||
type ruleDTO struct {
|
ruleDTO struct {
|
||||||
URL string
|
URL string
|
||||||
Header headerDTO
|
Header headerDTO
|
||||||
Body string
|
Body string
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
var dto ruleDTO
|
var dto ruleDTO
|
||||||
if err := json.Unmarshal(data, &dto); err != nil {
|
if err := json.Unmarshal(data, &dto); err != nil {
|
||||||
|
@ -197,14 +207,17 @@ func (r *Rule) UnmarshalJSON(data []byte) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
headerKey, err := stringToRegexp(dto.Header.Key)
|
headerKey, err := stringToRegexp(dto.Header.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
headerValue, err := stringToRegexp(dto.Header.Value)
|
headerValue, err := stringToRegexp(dto.Header.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := stringToRegexp(dto.Body)
|
body, err := stringToRegexp(dto.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -226,6 +239,7 @@ func regexpToString(r *regexp.Regexp) string {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.String()
|
return r.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,5 +247,6 @@ func stringToRegexp(s string) (*regexp.Regexp, error) {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return regexp.Compile(s)
|
return regexp.Compile(s)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,6 @@ type PrefixExpression struct {
|
||||||
Right Expression
|
Right Expression
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pe *PrefixExpression) expressionNode() {}
|
|
||||||
func (pe *PrefixExpression) String() string {
|
func (pe *PrefixExpression) String() string {
|
||||||
b := strings.Builder{}
|
b := strings.Builder{}
|
||||||
b.WriteString("(")
|
b.WriteString("(")
|
||||||
|
@ -29,7 +28,6 @@ type InfixExpression struct {
|
||||||
Right Expression
|
Right Expression
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ie *InfixExpression) expressionNode() {}
|
|
||||||
func (ie *InfixExpression) String() string {
|
func (ie *InfixExpression) String() string {
|
||||||
b := strings.Builder{}
|
b := strings.Builder{}
|
||||||
b.WriteString("(")
|
b.WriteString("(")
|
||||||
|
@ -47,7 +45,6 @@ type StringLiteral struct {
|
||||||
Value string
|
Value string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sl *StringLiteral) expressionNode() {}
|
|
||||||
func (sl *StringLiteral) String() string {
|
func (sl *StringLiteral) String() string {
|
||||||
return sl.Value
|
return sl.Value
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,21 +17,21 @@ const eof = 0
|
||||||
|
|
||||||
// Token types.
|
// Token types.
|
||||||
const (
|
const (
|
||||||
// Flow
|
// Flow.
|
||||||
TokInvalid TokenType = iota
|
TokInvalid TokenType = iota
|
||||||
TokEOF
|
TokEOF
|
||||||
TokParenOpen
|
TokParenOpen
|
||||||
TokParenClose
|
TokParenClose
|
||||||
|
|
||||||
// Literals
|
// Literals.
|
||||||
TokString
|
TokString
|
||||||
|
|
||||||
// Boolean operators
|
// Boolean operators.
|
||||||
TokOpNot
|
TokOpNot
|
||||||
TokOpAnd
|
TokOpAnd
|
||||||
TokOpOr
|
TokOpOr
|
||||||
|
|
||||||
// Comparison operators
|
// Comparison operators.
|
||||||
TokOpEq
|
TokOpEq
|
||||||
TokOpNotEq
|
TokOpNotEq
|
||||||
TokOpGt
|
TokOpGt
|
||||||
|
@ -98,6 +98,7 @@ func (tt TokenType) String() string {
|
||||||
if typeString, ok := tokenTypeStrings[tt]; ok {
|
if typeString, ok := tokenTypeStrings[tt]; ok {
|
||||||
return typeString
|
return typeString
|
||||||
}
|
}
|
||||||
|
|
||||||
return "<unknown>"
|
return "<unknown>"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +114,7 @@ func (l *Lexer) read() (r rune) {
|
||||||
l.width = 0
|
l.width = 0
|
||||||
return eof
|
return eof
|
||||||
}
|
}
|
||||||
|
|
||||||
r, l.width = utf8.DecodeRuneInString(l.input[l.pos:])
|
r, l.width = utf8.DecodeRuneInString(l.input[l.pos:])
|
||||||
l.pos += l.width
|
l.pos += l.width
|
||||||
|
|
||||||
|
@ -124,6 +126,7 @@ func (l *Lexer) emit(tokenType TokenType) {
|
||||||
Type: tokenType,
|
Type: tokenType,
|
||||||
Literal: l.input[l.start:l.pos],
|
Literal: l.input[l.start:l.pos],
|
||||||
}
|
}
|
||||||
|
|
||||||
l.start = l.pos
|
l.start = l.pos
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,6 +162,7 @@ func begin(l *Lexer) stateFn {
|
||||||
l.backup()
|
l.backup()
|
||||||
l.emit(TokOpEq)
|
l.emit(TokOpEq)
|
||||||
}
|
}
|
||||||
|
|
||||||
return begin
|
return begin
|
||||||
case '!':
|
case '!':
|
||||||
switch next := l.read(); next {
|
switch next := l.read(); next {
|
||||||
|
@ -169,6 +173,7 @@ func begin(l *Lexer) stateFn {
|
||||||
default:
|
default:
|
||||||
return l.errorf("invalid rune %v", r)
|
return l.errorf("invalid rune %v", r)
|
||||||
}
|
}
|
||||||
|
|
||||||
return begin
|
return begin
|
||||||
case '<':
|
case '<':
|
||||||
if next := l.read(); next == '=' {
|
if next := l.read(); next == '=' {
|
||||||
|
@ -177,6 +182,7 @@ func begin(l *Lexer) stateFn {
|
||||||
l.backup()
|
l.backup()
|
||||||
l.emit(TokOpLt)
|
l.emit(TokOpLt)
|
||||||
}
|
}
|
||||||
|
|
||||||
return begin
|
return begin
|
||||||
case '>':
|
case '>':
|
||||||
if next := l.read(); next == '=' {
|
if next := l.read(); next == '=' {
|
||||||
|
@ -185,6 +191,7 @@ func begin(l *Lexer) stateFn {
|
||||||
l.backup()
|
l.backup()
|
||||||
l.emit(TokOpGt)
|
l.emit(TokOpGt)
|
||||||
}
|
}
|
||||||
|
|
||||||
return begin
|
return begin
|
||||||
case '(':
|
case '(':
|
||||||
l.emit(TokParenOpen)
|
l.emit(TokParenOpen)
|
||||||
|
@ -231,15 +238,18 @@ func unquotedString(l *Lexer) stateFn {
|
||||||
case r == eof:
|
case r == eof:
|
||||||
l.backup()
|
l.backup()
|
||||||
l.emitUnquotedString()
|
l.emitUnquotedString()
|
||||||
|
|
||||||
return begin
|
return begin
|
||||||
case unicode.IsSpace(r):
|
case unicode.IsSpace(r):
|
||||||
l.backup()
|
l.backup()
|
||||||
l.emitUnquotedString()
|
l.emitUnquotedString()
|
||||||
l.skip()
|
l.skip()
|
||||||
|
|
||||||
return begin
|
return begin
|
||||||
case isReserved(r):
|
case isReserved(r):
|
||||||
l.backup()
|
l.backup()
|
||||||
l.emitUnquotedString()
|
l.emitUnquotedString()
|
||||||
|
|
||||||
return begin
|
return begin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -251,6 +261,7 @@ func (l *Lexer) emitUnquotedString() {
|
||||||
l.emit(tokType)
|
l.emit(tokType)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.emit(TokString)
|
l.emit(TokString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,5 +271,6 @@ func isReserved(r rune) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,8 @@ package search
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func TestNextToken(t *testing.T) {
|
func TestNextToken(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input string
|
input string
|
||||||
|
|
|
@ -18,8 +18,10 @@ const (
|
||||||
precGroup
|
precGroup
|
||||||
)
|
)
|
||||||
|
|
||||||
type prefixParser func(*Parser) (Expression, error)
|
type (
|
||||||
type infixParser func(*Parser, Expression) (Expression, error)
|
prefixParser func(*Parser) (Expression, error)
|
||||||
|
infixParser func(*Parser, Expression) (Expression, error)
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
prefixParsers = map[TokenType]prefixParser{}
|
prefixParsers = map[TokenType]prefixParser{}
|
||||||
|
@ -77,7 +79,6 @@ func NewParser(l *Lexer) *Parser {
|
||||||
p.nextToken()
|
p.nextToken()
|
||||||
|
|
||||||
return p
|
return p
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseQuery(input string) (expr Expression, err error) {
|
func ParseQuery(input string) (expr Expression, err error) {
|
||||||
|
@ -91,18 +92,20 @@ func ParseQuery(input string) (expr Expression, err error) {
|
||||||
|
|
||||||
for !p.curTokenIs(TokEOF) {
|
for !p.curTokenIs(TokEOF) {
|
||||||
right, err := p.parseExpression(precLowest)
|
right, err := p.parseExpression(precLowest)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("search: could not parse expression: %v", err)
|
switch {
|
||||||
}
|
case err != nil:
|
||||||
if expr == nil {
|
return nil, fmt.Errorf("search: could not parse expression: %w", err)
|
||||||
|
case expr == nil:
|
||||||
expr = right
|
expr = right
|
||||||
} else {
|
default:
|
||||||
expr = &InfixExpression{
|
expr = &InfixExpression{
|
||||||
Operator: TokOpAnd,
|
Operator: TokOpAnd,
|
||||||
Left: expr,
|
Left: expr,
|
||||||
Right: right,
|
Right: right,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken()
|
p.nextToken()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,18 +125,11 @@ func (p *Parser) peekTokenIs(t TokenType) bool {
|
||||||
return p.peek.Type == t
|
return p.peek.Type == t
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) expectPeek(t TokenType) error {
|
|
||||||
if !p.peekTokenIs(t) {
|
|
||||||
return fmt.Errorf("expected next token to be %v, got %v", t, p.peek.Type)
|
|
||||||
}
|
|
||||||
p.nextToken()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Parser) curPrecedence() precedence {
|
func (p *Parser) curPrecedence() precedence {
|
||||||
if p, ok := tokenPrecedences[p.cur.Type]; ok {
|
if p, ok := tokenPrecedences[p.cur.Type]; ok {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
return precLowest
|
return precLowest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,6 +137,7 @@ func (p *Parser) peekPrecedence() precedence {
|
||||||
if p, ok := tokenPrecedences[p.peek.Type]; ok {
|
if p, ok := tokenPrecedences[p.peek.Type]; ok {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
return precLowest
|
return precLowest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,7 +149,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) {
|
||||||
|
|
||||||
expr, err := prefixParser(p)
|
expr, err := prefixParser(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse expression prefix: %v", err)
|
return nil, fmt.Errorf("could not parse expression prefix: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for !p.peekTokenIs(eof) && prec < p.peekPrecedence() {
|
for !p.peekTokenIs(eof) && prec < p.peekPrecedence() {
|
||||||
|
@ -165,7 +162,7 @@ func (p *Parser) parseExpression(prec precedence) (Expression, error) {
|
||||||
|
|
||||||
expr, err = infixParser(p, expr)
|
expr, err = infixParser(p, expr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse infix expression: %v", err)
|
return nil, fmt.Errorf("could not parse infix expression: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,8 +178,9 @@ func parsePrefixExpression(p *Parser) (Expression, error) {
|
||||||
|
|
||||||
right, err := p.parseExpression(precPrefix)
|
right, err := p.parseExpression(precPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse expression for right operand: %v", err)
|
return nil, fmt.Errorf("could not parse expression for right operand: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expr.Right = right
|
expr.Right = right
|
||||||
|
|
||||||
return expr, nil
|
return expr, nil
|
||||||
|
@ -199,8 +197,9 @@ func parseInfixExpression(p *Parser, left Expression) (Expression, error) {
|
||||||
|
|
||||||
right, err := p.parseExpression(prec)
|
right, err := p.parseExpression(prec)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse expression for right operand: %v", err)
|
return nil, fmt.Errorf("could not parse expression for right operand: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expr.Right = right
|
expr.Right = right
|
||||||
|
|
||||||
return expr, nil
|
return expr, nil
|
||||||
|
@ -215,17 +214,19 @@ func parseGroupedExpression(p *Parser) (Expression, error) {
|
||||||
|
|
||||||
expr, err := p.parseExpression(precLowest)
|
expr, err := p.parseExpression(precLowest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse grouped expression: %v", err)
|
return nil, fmt.Errorf("could not parse grouped expression: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for p.nextToken(); !p.curTokenIs(TokParenClose); p.nextToken() {
|
for p.nextToken(); !p.curTokenIs(TokParenClose); p.nextToken() {
|
||||||
if p.curTokenIs(TokEOF) {
|
if p.curTokenIs(TokEOF) {
|
||||||
return nil, fmt.Errorf("unexpected EOF: unmatched parentheses")
|
return nil, fmt.Errorf("unexpected EOF: unmatched parentheses")
|
||||||
}
|
}
|
||||||
|
|
||||||
right, err := p.parseExpression(precLowest)
|
right, err := p.parseExpression(precLowest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse expression: %v", err)
|
return nil, fmt.Errorf("could not parse expression: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expr = &InfixExpression{
|
expr = &InfixExpression{
|
||||||
Operator: TokOpAnd,
|
Operator: TokOpAnd,
|
||||||
Left: expr,
|
Left: expr,
|
||||||
|
|
|
@ -7,6 +7,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseQuery(t *testing.T) {
|
func TestParseQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input string
|
input string
|
||||||
|
@ -233,6 +235,8 @@ func TestParseQuery(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertError(t *testing.T, exp, got error) {
|
func assertError(t *testing.T, exp, got error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case exp == nil && got != nil:
|
case exp == nil && got != nil:
|
||||||
t.Fatalf("expected: nil, got: %v", got)
|
t.Fatalf("expected: nil, got: %v", got)
|
||||||
|
|
Loading…
Reference in a new issue