[analyze] Add Analyzer for Postgres (#3192)

* implement analyzer interface for postgres

* added unit test for postgres analyzer

* refactored code in postgres analyzer

* generate permissions for postgres analyzer

* renamed variable

* [chore] corrected the variable name.

* appended hostname to distinguish the resources.
updated the test.

---------

Co-authored-by: Abdul Basit <abasit@folio3.com>
This commit is contained in:
Abdul Basit 2024-09-07 00:42:55 +05:00 committed by GitHub
parent a43d451c4d
commit 93d09c78b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 492 additions and 1 deletions

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,141 @@
// Code generated by go generate; DO NOT EDIT.
package postgres
import "errors"
type Permission int
const (
Invalid Permission = iota
BypassRls Permission = iota
Connect Permission = iota
Create Permission = iota
CreateDb Permission = iota
CreateRole Permission = iota
Delete Permission = iota
InheritanceOfPrivs Permission = iota
Insert Permission = iota
Login Permission = iota
References Permission = iota
Replication Permission = iota
Select Permission = iota
Superuser Permission = iota
Temp Permission = iota
Trigger Permission = iota
Truncate Permission = iota
Update Permission = iota
)
var (
PermissionStrings = map[Permission]string{
BypassRls: "bypass_rls",
Connect: "connect",
Create: "create",
CreateDb: "create_db",
CreateRole: "create_role",
Delete: "delete",
InheritanceOfPrivs: "inheritance_of_privs",
Insert: "insert",
Login: "login",
References: "references",
Replication: "replication",
Select: "select",
Superuser: "superuser",
Temp: "temp",
Trigger: "trigger",
Truncate: "truncate",
Update: "update",
}
StringToPermission = map[string]Permission{
"bypass_rls": BypassRls,
"connect": Connect,
"create": Create,
"create_db": CreateDb,
"create_role": CreateRole,
"delete": Delete,
"inheritance_of_privs": InheritanceOfPrivs,
"insert": Insert,
"login": Login,
"references": References,
"replication": Replication,
"select": Select,
"superuser": Superuser,
"temp": Temp,
"trigger": Trigger,
"truncate": Truncate,
"update": Update,
}
PermissionIDs = map[Permission]int{
BypassRls: 1,
Connect: 2,
Create: 3,
CreateDb: 4,
CreateRole: 5,
Delete: 6,
InheritanceOfPrivs: 7,
Insert: 8,
Login: 9,
References: 10,
Replication: 11,
Select: 12,
Superuser: 13,
Temp: 14,
Trigger: 15,
Truncate: 16,
Update: 17,
}
IdToPermission = map[int]Permission{
1: BypassRls,
2: Connect,
3: Create,
4: CreateDb,
5: CreateRole,
6: Delete,
7: InheritanceOfPrivs,
8: Insert,
9: Login,
10: References,
11: Replication,
12: Select,
13: Superuser,
14: Temp,
15: Trigger,
16: Truncate,
17: Update,
}
)
// ToString converts a Permission enum to its string representation
func (p Permission) ToString() (string, error) {
if str, ok := PermissionStrings[p]; ok {
return str, nil
}
return "", errors.New("invalid permission")
}
// ToID converts a Permission enum to its ID
func (p Permission) ToID() (int, error) {
if id, ok := PermissionIDs[p]; ok {
return id, nil
}
return 0, errors.New("invalid permission")
}
// PermissionFromString converts a string representation to its Permission enum
func PermissionFromString(s string) (Permission, error) {
if p, ok := StringToPermission[s]; ok {
return p, nil
}
return 0, errors.New("invalid permission string")
}
// PermissionFromID converts an ID to its Permission enum
func PermissionFromID(id int) (Permission, error) {
if p, ok := IdToPermission[id]; ok {
return p, nil
}
return 0, errors.New("invalid permission ID")
}

View file

@ -0,0 +1,18 @@
permissions:
- bypass_rls
- connect
- create
- create_db
- create_role
- delete
- inheritance_of_privs
- insert
- login
- references
- replication
- select
- superuser
- temp
- trigger
- truncate
- update

View file

@ -1,3 +1,5 @@
//go:generate generate_permissions permissions.yaml permissions.go postgres
package postgres
import (
@ -14,8 +16,168 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/analyzers"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/pb/analyzerpb"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)
var _ analyzers.Analyzer = (*Analyzer)(nil)
type Analyzer struct {
Cfg *config.Config
}
func (Analyzer) Type() analyzerpb.AnalyzerType { return analyzerpb.AnalyzerType_Postgres }
func (a Analyzer) Analyze(_ context.Context, credInfo map[string]string) (*analyzers.AnalyzerResult, error) {
uri, ok := credInfo["connection_string"]
if !ok {
return nil, errors.New("connection string not found in credInfo")
}
info, err := AnalyzePermissions(a.Cfg, uri)
if err != nil {
return nil, err
}
return secretInfoToAnalyzerResult(info), nil
}
func secretInfoToAnalyzerResult(info *SecretInfo) *analyzers.AnalyzerResult {
if info == nil {
return nil
}
result := analyzers.AnalyzerResult{
AnalyzerType: analyzerpb.AnalyzerType_Postgres,
Metadata: nil,
Bindings: []analyzers.Binding{},
}
// set user related bindings in result
userResource, userBindings := bakeUserBindings(info)
result.Bindings = append(result.Bindings, userBindings...)
// add user's database priviliges to bindings
dbNameToResourceMap, dbBindings := bakeDatabaseBindings(userResource, info)
result.Bindings = append(result.Bindings, dbBindings...)
// add user's table priviliges to bindings
tableBindings := bakeTableBindings(dbNameToResourceMap, info)
result.Bindings = append(result.Bindings, tableBindings...)
return &result
}
func bakeUserBindings(info *SecretInfo) (analyzers.Resource, []analyzers.Binding) {
userResource := analyzers.Resource{
Name: info.User,
FullyQualifiedName: info.Host + "/" + info.User,
Type: "user",
Metadata: map[string]any{
"role": info.Role,
},
}
var bindings []analyzers.Binding
for rolePriv, exists := range info.RolePrivs {
if exists {
bindings = append(bindings, analyzers.Binding{
Resource: userResource,
Permission: analyzers.Permission{
Value: rolePriv,
},
})
}
}
return userResource, bindings
}
func bakeDatabaseBindings(userResource analyzers.Resource, info *SecretInfo) (map[string]*analyzers.Resource, []analyzers.Binding) {
dbNameToResourceMap := map[string]*analyzers.Resource{}
dbBindings := []analyzers.Binding{}
for _, db := range info.DBs {
dbResource := analyzers.Resource{
Name: db.DatabaseName,
FullyQualifiedName: info.Host + "/" + db.DatabaseName,
Type: "database",
Metadata: map[string]any{
"owner": db.Owner,
},
Parent: &userResource,
}
// populate map to reference later for tables
dbNameToResourceMap[db.DatabaseName] = &dbResource
dbPriviliges := map[string]bool{
"connect": db.Connect,
"create": db.Create,
"temp": db.CreateTemp,
}
for priv, exists := range dbPriviliges {
if exists {
dbBindings = append(dbBindings, analyzers.Binding{
Resource: dbResource,
Permission: analyzers.Permission{
Value: priv,
},
})
}
}
}
return dbNameToResourceMap, dbBindings
}
func bakeTableBindings(dbNameToResourceMap map[string]*analyzers.Resource, info *SecretInfo) []analyzers.Binding {
var tableBindings []analyzers.Binding
for dbName, tableMap := range info.TablePrivs {
dbResource, ok := dbNameToResourceMap[dbName]
if !ok {
continue
}
for tableName, tableData := range tableMap {
tableResource := analyzers.Resource{
Name: tableName,
FullyQualifiedName: info.Host + "/" + dbResource.Name + "/" + tableName,
Type: "table",
Metadata: map[string]any{
"size": tableData.Size,
"rows": tableData.Rows,
},
Parent: dbResource,
}
tablePrivsMap := map[string]bool{
"select": tableData.Privs.Select,
"insert": tableData.Privs.Insert,
"update": tableData.Privs.Update,
"delete": tableData.Privs.Delete,
"truncate": tableData.Privs.Truncate,
"references": tableData.Privs.References,
"trigger": tableData.Privs.Trigger,
}
for priv, exists := range tablePrivsMap {
if exists {
tableBindings = append(tableBindings, analyzers.Binding{
Resource: tableResource,
Permission: analyzers.Permission{
Value: priv,
},
})
}
}
}
}
return tableBindings
}
type DBPrivs struct {
Connect bool
Create bool
@ -62,6 +224,7 @@ const (
var connStrPartPattern = regexp.MustCompile(`([[:alpha:]]+)='(.+?)' ?`)
type SecretInfo struct {
Host string
User string
Role string
RolePrivs map[string]bool
@ -132,6 +295,7 @@ func AnalyzePermissions(cfg *config.Config, connectionStr string) (*SecretInfo,
}
return &SecretInfo{
Host: params[pg_host],
User: currentUser,
Role: role,
RolePrivs: privs,

View file

@ -0,0 +1,164 @@
package postgres
import (
"bytes"
_ "embed"
"encoding/json"
"errors"
"fmt"
"os/exec"
"sort"
"strings"
"testing"
"time"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/analyzers"
"github.com/trufflesecurity/trufflehog/v3/pkg/analyzer/config"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)
const (
postgresUser = "postgres"
postgresPass = "23201da=b56ca236f3dc6736c0f9afad"
postgresHost = "localhost"
postgresPort = "5434" // Do not use 5433, as local dev environments can use it for other things
defaultPort = "5432"
)
//go:embed expected_output.json
var expectedOutput []byte
func TestAnalyzer_Analyze(t *testing.T) {
if err := startPostgres(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
t.Fatalf("could not start local postgres: %v w/stderr:\n%s", err, string(exitErr.Stderr))
} else {
t.Fatalf("could not start local postgres: %v", err)
}
}
defer stopPostgres()
tests := []struct {
name string
connectionString string
want []byte // JSON string
wantErr bool
}{
{
name: "valid Postgres connection",
connectionString: fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres`, postgresUser, postgresPass, postgresHost, postgresPort),
want: expectedOutput,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := Analyzer{Cfg: &config.Config{}}
got, err := a.Analyze(context.Background(), map[string]string{"connection_string": tt.connectionString})
if (err != nil) != tt.wantErr {
t.Errorf("Analyzer.Analyze() error = %v, wantErr %v", err, tt.wantErr)
return
}
// bindings need to be in the same order to be comparable
sortBindings(got.Bindings)
// Marshal the actual result to JSON
gotJSON, err := json.Marshal(got)
if err != nil {
t.Fatalf("could not marshal got to JSON: %s", err)
}
// Parse the expected JSON string
var wantObj analyzers.AnalyzerResult
if err := json.Unmarshal(tt.want, &wantObj); err != nil {
t.Fatalf("could not unmarshal want JSON string: %s", err)
}
// bindings need to be in the same order to be comparable
sortBindings(wantObj.Bindings)
// Marshal the expected result to JSON (to normalize)
wantJSON, err := json.Marshal(wantObj)
if err != nil {
t.Fatalf("could not marshal want to JSON: %s", err)
}
// Compare the JSON strings
if string(gotJSON) != string(wantJSON) {
// Pretty-print both JSON strings for easier comparison
var gotIndented, wantIndented []byte
gotIndented, err = json.MarshalIndent(got, "", " ")
if err != nil {
t.Fatalf("could not marshal got to indented JSON: %s", err)
}
wantIndented, err = json.MarshalIndent(wantObj, "", " ")
if err != nil {
t.Fatalf("could not marshal want to indented JSON: %s", err)
}
t.Errorf("Analyzer.Analyze() = %s, want %s", gotIndented, wantIndented)
}
})
}
}
// Helper function to sort bindings
func sortBindings(bindings []analyzers.Binding) {
sort.SliceStable(bindings, func(i, j int) bool {
if bindings[i].Resource.Name == bindings[j].Resource.Name {
return bindings[i].Permission.Value < bindings[j].Permission.Value
}
return bindings[i].Resource.Name < bindings[j].Resource.Name
})
}
var postgresDockerHash string
func dockerLogLine(hash string, needle string) chan struct{} {
ch := make(chan struct{}, 1)
go func() {
for {
out, err := exec.Command("docker", "logs", hash).CombinedOutput()
if err != nil {
panic(err)
}
if strings.Contains(string(out), needle) {
ch <- struct{}{}
return
}
time.Sleep(1 * time.Second)
}
}()
return ch
}
func startPostgres() error {
cmd := exec.Command(
"docker", "run", "--rm", "-p", postgresPort+":"+defaultPort,
"-e", "POSTGRES_PASSWORD="+postgresPass,
"-e", "POSTGRES_USER="+postgresUser,
"-d", "postgres",
)
fmt.Println(cmd.String())
out, err := cmd.Output()
if err != nil {
return err
}
postgresDockerHash = string(bytes.TrimSpace(out))
select {
case <-dockerLogLine(postgresDockerHash, "PostgreSQL init process complete; ready for start up."):
return nil
case <-time.After(30 * time.Second):
stopPostgres()
return errors.New("timeout waiting for postgres database to be ready")
}
}
func stopPostgres() {
err := exec.Command("docker", "kill", postgresDockerHash).Run()
if err != nil {
fmt.Println("could not stop postgres container:", err)
}
}

View file

@ -131,6 +131,9 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete
isVerified, verificationErr := verifyPostgres(params)
result.Verified = isVerified
result.SetVerificationError(verificationErr, password)
result.AnalysisInfo = map[string]string{
"connection_string": string(raw),
}
}
// We gather SSL information into ExtraData in case it's useful for later reporting.

View file

@ -337,7 +337,7 @@ func TestPostgres_FromChunk(t *testing.T) {
t.Fatalf("wantVerificationError = %v, verification error = %v", tt.want[i].VerificationError(), got[i].VerificationError())
}
}
ignoreOpts := cmpopts.IgnoreFields(detectors.Result{}, "verificationError")
ignoreOpts := cmpopts.IgnoreFields(detectors.Result{}, "verificationError", "AnalysisInfo")
if diff := cmp.Diff(got, tt.want, ignoreOpts); diff != "" {
t.Errorf("Postgres.FromData() %s diff: (-got +want)\n%s", tt.name, diff)
}