JDBC indeterminacy (#1507)

This PR adds an indeterminacy check to the JDBC verifiers.
This commit is contained in:
Cody Rose 2023-07-19 16:57:57 -04:00 committed by GitHub
parent 8fad5fff79
commit 20b7793828
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 190 additions and 79 deletions

View file

@ -86,7 +86,14 @@ matchLoop:
} }
ctx, cancel := context.WithTimeout(ctx, 5*time.Second) ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
s.Verified = j.ping(ctx) pingRes := j.ping(ctx)
s.Verified = pingRes.err == nil
// If there's a ping error that is marked as "determinate" we throw it away. We do this because this was the
// behavior before tri-state verification was introduced and preserving it allows us to gradually migrate
// detectors to use tri-state verification.
if pingRes.err != nil && !pingRes.determinate {
s.VerificationError = pingRes.err
}
// TODO: specialized redaction // TODO: specialized redaction
} }
@ -198,8 +205,13 @@ var supportedSubprotocols = map[string]func(string) (jdbc, error){
"sqlserver": parseSqlServer, "sqlserver": parseSqlServer,
} }
type pingResult struct {
err error
determinate bool
}
type jdbc interface { type jdbc interface {
ping(context.Context) bool ping(context.Context) pingResult
} }
func newJDBC(conn string) (jdbc, error) { func newJDBC(conn string) (jdbc, error) {
@ -220,13 +232,16 @@ func newJDBC(conn string) (jdbc, error) {
return parser(subname) return parser(subname)
} }
func ping(ctx context.Context, driverName string, candidateConns ...string) bool { func ping(ctx context.Context, driverName string, isDeterminate func(error) bool, candidateConns ...string) pingResult {
var indeterminateErrors []error
for _, c := range candidateConns { for _, c := range candidateConns {
if err := pingErr(ctx, driverName, c); err == nil { err := pingErr(ctx, driverName, c)
return true if err == nil || isDeterminate(err) {
return pingResult{err, true}
} }
indeterminateErrors = append(indeterminateErrors, err)
} }
return false return pingResult{errors.Join(indeterminateErrors...), false}
} }
func pingErr(ctx context.Context, driverName, conn string) error { func pingErr(ctx context.Context, driverName, conn string) error {

View file

@ -3,9 +3,8 @@ package jdbc
import ( import (
"context" "context"
"errors" "errors"
"github.com/go-sql-driver/mysql"
"strings" "strings"
_ "github.com/go-sql-driver/mysql"
) )
type mysqlJDBC struct { type mysqlJDBC struct {
@ -16,8 +15,8 @@ type mysqlJDBC struct {
params string params string
} }
func (s *mysqlJDBC) ping(ctx context.Context) bool { func (s *mysqlJDBC) ping(ctx context.Context) pingResult {
return ping(ctx, "mysql", return ping(ctx, "mysql", isMySQLErrorDeterminate,
s.conn, s.conn,
buildMySQLConnectionString(s.host, s.database, s.userPass, s.params), buildMySQLConnectionString(s.host, s.database, s.userPass, s.params),
buildMySQLConnectionString(s.host, "", s.userPass, s.params)) buildMySQLConnectionString(s.host, "", s.userPass, s.params))
@ -34,6 +33,22 @@ func buildMySQLConnectionString(host, database, userPass, params string) string
return conn return conn
} }
func isMySQLErrorDeterminate(err error) bool {
// MySQL error numbers from https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html
if mySQLErr, isMySQLErr := err.(*mysql.MySQLError); isMySQLErr {
switch mySQLErr.Number {
case 1044:
// User access denied to a particular database
return false // "Indeterminate" so that other connection variations will be tried
case 1045:
// User access denied
return true
}
}
return false
}
func parseMySQL(subname string) (jdbc, error) { func parseMySQL(subname string) (jdbc, error) {
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]] // expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]
hostAndDB, params, _ := strings.Cut(subname, "?") hostAndDB, params, _ := strings.Cut(subname, "?")

View file

@ -21,45 +21,54 @@ const (
) )
func TestMySQL(t *testing.T) { func TestMySQL(t *testing.T) {
type result struct {
parseErr bool
pingOk bool
pingDeterminate bool
}
tests := []struct { tests := []struct {
input string input string
wantErr bool want result
wantPing bool
}{ }{
{ {
input: "", input: "",
wantErr: true, want: result{parseErr: true},
}, },
{ {
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/" + mysqlDatabase, input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/" + mysqlDatabase,
wantPing: true, want: result{pingOk: true, pingDeterminate: true},
}, },
{ {
input: "//wrongUser:wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase, input: "//wrongUser:wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase,
wantPing: false, want: result{pingOk: false, pingDeterminate: true},
}, },
{ {
input: "//" + mysqlUser + ":wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase, input: "//" + mysqlUser + ":wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase,
wantPing: false, want: result{pingOk: false, pingDeterminate: true},
}, },
{ {
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/", input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/",
wantPing: true, want: result{pingOk: true, pingDeterminate: true},
}, },
{ {
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/wrongDB", input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/wrongDB",
wantPing: true, want: result{pingOk: true, pingDeterminate: true},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) { t.Run(tt.input, func(t *testing.T) {
j, err := parseMySQL(tt.input) j, err := parseMySQL(tt.input)
if tt.wantErr {
assert.Error(t, err) if err != nil {
got := result{parseErr: true}
assert.Equal(t, tt.want, got)
return return
} }
assert.NoError(t, err)
assert.Equal(t, tt.wantPing, j.ping(context.Background())) pr := j.ping(context.Background())
got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate}
assert.Equal(t, tt.want, got)
}) })
} }
} }

View file

@ -4,9 +4,8 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/lib/pq"
"strings" "strings"
_ "github.com/lib/pq"
) )
type postgresJDBC struct { type postgresJDBC struct {
@ -14,12 +13,36 @@ type postgresJDBC struct {
params map[string]string params map[string]string
} }
func (s *postgresJDBC) ping(ctx context.Context) bool { func (s *postgresJDBC) ping(ctx context.Context) pingResult {
return ping(ctx, "postgres", // It is crucial that we try to build a connection string ourselves before using the one we found. This is because
s.conn, // if the found connection string doesn't include a username, the driver will attempt to connect using the current
"postgres://"+s.conn, // user's name, which will fail in a way that looks like a determinate failure, thus terminating the waterfall. In
// contrast, when we build a connection string ourselves, if there's no username, we try 'postgres' instead, which
// actually has a chance of working.
return ping(ctx, "postgres", isPostgresErrorDeterminate,
buildPostgresConnectionString(s.params, true), buildPostgresConnectionString(s.params, true),
buildPostgresConnectionString(s.params, false)) buildPostgresConnectionString(s.params, false),
s.conn,
"postgres://"+s.conn)
}
func isPostgresErrorDeterminate(err error) bool {
// Postgres codes from https://www.postgresql.org/docs/current/errcodes-appendix.html
if pqErr, isPostgresError := err.(*pq.Error); isPostgresError {
switch pqErr.Code {
case "28P01":
// Invalid username/password
return true
case "3D000":
// Unknown database
return false // "Indeterminate" so that other connection variations will be tried
case "3F000":
// Unknown schema
return false // "Indeterminate" so that other connection variations will be tried
}
}
return false
} }
func joinKeyValues(m map[string]string, sep string) string { func joinKeyValues(m map[string]string, sep string) string {

View file

@ -20,49 +20,62 @@ const (
) )
func TestPostgres(t *testing.T) { func TestPostgres(t *testing.T) {
type result struct {
parseErr bool
pingOk bool
pingDeterminate bool
}
tests := []struct { tests := []struct {
input string input string
wantErr bool want result
wantPing bool
}{ }{
{ {
input: "//localhost:5432/foo?sslmode=disable&password=" + postgresPass, input: "//localhost:5432/foo?sslmode=disable&password=" + postgresPass,
wantPing: true, want: result{pingOk: true, pingDeterminate: true},
}, },
{ {
input: "//localhost:5432/foo?sslmode=disable&user=" + postgresUser + "&password=" + postgresPass, input: "//localhost:5432/foo?sslmode=disable&user=" + postgresUser + "&password=" + postgresPass,
wantPing: true, want: result{pingOk: true, pingDeterminate: true},
}, },
{ {
input: "//localhost/foo?sslmode=disable&port=5432&password=" + postgresPass, input: "//localhost/foo?sslmode=disable&port=5432&password=" + postgresPass,
wantPing: true, want: result{pingOk: true, pingDeterminate: true},
}, },
{ {
input: "//localhost:5432/foo?password=" + postgresPass, input: "//localhost:5432/foo?password=" + postgresPass,
wantPing: false, want: result{pingOk: false, pingDeterminate: false},
}, },
{ {
input: "//localhost:5432/foo?sslmode=disable&password=foo", input: "//localhost:5432/foo?sslmode=disable&password=foo",
wantPing: false, want: result{pingOk: false, pingDeterminate: true},
}, },
{ {
input: "//localhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass, input: "//localhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass,
wantPing: false, want: result{pingOk: false, pingDeterminate: true},
}, },
{ {
input: "invalid", input: "//badhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass,
wantErr: true, want: result{pingOk: false, pingDeterminate: false},
},
{
input: "invalid",
want: result{parseErr: true},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) { t.Run(tt.input, func(t *testing.T) {
j, err := parsePostgres(tt.input) j, err := parsePostgres(tt.input)
if tt.wantErr {
assert.Error(t, err) if err != nil {
got := result{parseErr: true}
assert.Equal(t, tt.want, got)
return return
} }
assert.NoError(t, err)
assert.Equal(t, tt.wantPing, j.ping(context.Background())) pr := j.ping(context.Background())
got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate}
assert.Equal(t, tt.want, got)
}) })
} }
} }

View file

@ -14,12 +14,18 @@ type sqliteJDBC struct {
testing bool testing bool
} }
func (s *sqliteJDBC) ping(ctx context.Context) bool { var cannotVerifySqliteError error = errors.New("sqlite credentials cannot be verified")
func (s *sqliteJDBC) ping(ctx context.Context) pingResult {
if !s.testing { if !s.testing {
// sqlite is not a networked database, so we cannot verify // sqlite is not a networked database, so we cannot verify
return false return pingResult{cannotVerifySqliteError, true}
} }
return ping(ctx, "sqlite3", s.filename) return ping(ctx, "sqlite3", isSqliteErrorDeterminate, s.filename)
}
func isSqliteErrorDeterminate(err error) bool {
return true
} }
func parseSqlite(subname string) (jdbc, error) { func parseSqlite(subname string) (jdbc, error) {

View file

@ -39,7 +39,7 @@ func TestParseSqlite(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, j.ping(context.Background())) assert.True(t, j.ping(context.Background()).err == nil)
} }
}) })
} }

View file

@ -3,9 +3,8 @@ package jdbc
import ( import (
"context" "context"
"errors" "errors"
mssql "github.com/denisenkom/go-mssqldb"
"strings" "strings"
_ "github.com/denisenkom/go-mssqldb"
) )
type sqlServerJDBC struct { type sqlServerJDBC struct {
@ -13,13 +12,27 @@ type sqlServerJDBC struct {
params map[string]string params map[string]string
} }
func (s *sqlServerJDBC) ping(ctx context.Context) bool { func (s *sqlServerJDBC) ping(ctx context.Context) pingResult {
return ping(ctx, "mssql", return ping(ctx, "mssql", isSqlServerErrorDeterminate,
s.conn,
joinKeyValues(s.params, ";"), joinKeyValues(s.params, ";"),
s.conn,
"sqlserver://"+s.conn) "sqlserver://"+s.conn)
} }
func isSqlServerErrorDeterminate(err error) bool {
// Error numbers from https://learn.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors?view=sql-server-ver16
if mssqlError, isMssqlError := err.(mssql.Error); isMssqlError {
switch mssqlError.Number {
case 18456:
// Login failed
// This is a determinate failure iff we tried to use a real user
return mssqlError.Message != "login error: Login failed for user ''."
}
}
return false
}
func parseSqlServer(subname string) (jdbc, error) { func parseSqlServer(subname string) (jdbc, error) {
if !strings.HasPrefix(subname, "//") { if !strings.HasPrefix(subname, "//") {
return nil, errors.New("expected connection to start with //") return nil, errors.New("expected connection to start with //")

View file

@ -21,33 +21,50 @@ const (
) )
func TestSqlServer(t *testing.T) { func TestSqlServer(t *testing.T) {
type result struct {
parseErr bool
pingOk bool
pingDeterminate bool
}
tests := []struct { tests := []struct {
input string input string
wantErr bool want result
wantPing bool
}{ }{
{ {
input: "", input: "",
wantErr: true, want: result{parseErr: true},
}, },
{ {
input: "//odbc:server=localhost;user id=sa;database=master;password=" + sqlServerPass, input: "//server=localhost;user id=sa;database=master;password=" + sqlServerPass,
wantPing: true, want: result{pingOk: true, pingDeterminate: true},
}, },
{ {
input: "//localhost;database=master;spring.datasource.password=" + sqlServerPass, input: "//server=badhost;user id=sa;database=master;password=" + sqlServerPass,
wantPing: true, want: result{pingOk: false, pingDeterminate: false},
},
{
input: "//localhost;database=master;spring.datasource.password=" + sqlServerPass,
want: result{pingOk: true, pingDeterminate: true},
},
{
input: "//localhost;database=master;spring.datasource.password=badpassword",
want: result{pingOk: false, pingDeterminate: true},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) { t.Run(tt.input, func(t *testing.T) {
j, err := parseSqlServer(tt.input) j, err := parseSqlServer(tt.input)
if tt.wantErr {
assert.Error(t, err) if err != nil {
got := result{parseErr: true}
assert.Equal(t, tt.want, got)
return return
} }
assert.NoError(t, err)
assert.Equal(t, tt.wantPing, j.ping(context.Background())) pr := j.ping(context.Background())
got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate}
assert.Equal(t, tt.want, got)
}) })
} }
} }