JDBC indeterminacy (#1507)

This PR adds an indeterminacy check to the JDBC verifiers.
This commit is contained in:
Cody Rose 2023-07-19 16:57:57 -04:00 committed by GitHub
parent 8fad5fff79
commit 20b7793828
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 190 additions and 79 deletions

View file

@ -86,7 +86,14 @@ matchLoop:
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
s.Verified = j.ping(ctx)
pingRes := j.ping(ctx)
s.Verified = pingRes.err == nil
// If there's a ping error that is marked as "determinate" we throw it away. We do this because this was the
// behavior before tri-state verification was introduced and preserving it allows us to gradually migrate
// detectors to use tri-state verification.
if pingRes.err != nil && !pingRes.determinate {
s.VerificationError = pingRes.err
}
// TODO: specialized redaction
}
@ -198,8 +205,13 @@ var supportedSubprotocols = map[string]func(string) (jdbc, error){
"sqlserver": parseSqlServer,
}
type pingResult struct {
err error
determinate bool
}
type jdbc interface {
ping(context.Context) bool
ping(context.Context) pingResult
}
func newJDBC(conn string) (jdbc, error) {
@ -220,13 +232,16 @@ func newJDBC(conn string) (jdbc, error) {
return parser(subname)
}
func ping(ctx context.Context, driverName string, candidateConns ...string) bool {
func ping(ctx context.Context, driverName string, isDeterminate func(error) bool, candidateConns ...string) pingResult {
var indeterminateErrors []error
for _, c := range candidateConns {
if err := pingErr(ctx, driverName, c); err == nil {
return true
err := pingErr(ctx, driverName, c)
if err == nil || isDeterminate(err) {
return pingResult{err, true}
}
indeterminateErrors = append(indeterminateErrors, err)
}
return false
return pingResult{errors.Join(indeterminateErrors...), false}
}
func pingErr(ctx context.Context, driverName, conn string) error {

View file

@ -3,9 +3,8 @@ package jdbc
import (
"context"
"errors"
"github.com/go-sql-driver/mysql"
"strings"
_ "github.com/go-sql-driver/mysql"
)
type mysqlJDBC struct {
@ -16,8 +15,8 @@ type mysqlJDBC struct {
params string
}
func (s *mysqlJDBC) ping(ctx context.Context) bool {
return ping(ctx, "mysql",
func (s *mysqlJDBC) ping(ctx context.Context) pingResult {
return ping(ctx, "mysql", isMySQLErrorDeterminate,
s.conn,
buildMySQLConnectionString(s.host, s.database, s.userPass, s.params),
buildMySQLConnectionString(s.host, "", s.userPass, s.params))
@ -34,6 +33,22 @@ func buildMySQLConnectionString(host, database, userPass, params string) string
return conn
}
func isMySQLErrorDeterminate(err error) bool {
// MySQL error numbers from https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html
if mySQLErr, isMySQLErr := err.(*mysql.MySQLError); isMySQLErr {
switch mySQLErr.Number {
case 1044:
// User access denied to a particular database
return false // "Indeterminate" so that other connection variations will be tried
case 1045:
// User access denied
return true
}
}
return false
}
func parseMySQL(subname string) (jdbc, error) {
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]
hostAndDB, params, _ := strings.Cut(subname, "?")

View file

@ -21,45 +21,54 @@ const (
)
func TestMySQL(t *testing.T) {
type result struct {
parseErr bool
pingOk bool
pingDeterminate bool
}
tests := []struct {
input string
wantErr bool
wantPing bool
input string
want result
}{
{
input: "",
wantErr: true,
input: "",
want: result{parseErr: true},
},
{
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/" + mysqlDatabase,
wantPing: true,
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/" + mysqlDatabase,
want: result{pingOk: true, pingDeterminate: true},
},
{
input: "//wrongUser:wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase,
wantPing: false,
input: "//wrongUser:wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase,
want: result{pingOk: false, pingDeterminate: true},
},
{
input: "//" + mysqlUser + ":wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase,
wantPing: false,
input: "//" + mysqlUser + ":wrongPass@tcp(127.0.0.1:3306)/" + mysqlDatabase,
want: result{pingOk: false, pingDeterminate: true},
},
{
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/",
wantPing: true,
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/",
want: result{pingOk: true, pingDeterminate: true},
},
{
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/wrongDB",
wantPing: true,
input: "//" + mysqlUser + ":" + mysqlPass + "@tcp(127.0.0.1:3306)/wrongDB",
want: result{pingOk: true, pingDeterminate: true},
},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
j, err := parseMySQL(tt.input)
if tt.wantErr {
assert.Error(t, err)
if err != nil {
got := result{parseErr: true}
assert.Equal(t, tt.want, got)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantPing, j.ping(context.Background()))
pr := j.ping(context.Background())
got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate}
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -4,9 +4,8 @@ import (
"context"
"errors"
"fmt"
"github.com/lib/pq"
"strings"
_ "github.com/lib/pq"
)
type postgresJDBC struct {
@ -14,12 +13,36 @@ type postgresJDBC struct {
params map[string]string
}
func (s *postgresJDBC) ping(ctx context.Context) bool {
return ping(ctx, "postgres",
s.conn,
"postgres://"+s.conn,
func (s *postgresJDBC) ping(ctx context.Context) pingResult {
// It is crucial that we try to build a connection string ourselves before using the one we found. This is because
// if the found connection string doesn't include a username, the driver will attempt to connect using the current
// user's name, which will fail in a way that looks like a determinate failure, thus terminating the waterfall. In
// contrast, when we build a connection string ourselves, if there's no username, we try 'postgres' instead, which
// actually has a chance of working.
return ping(ctx, "postgres", isPostgresErrorDeterminate,
buildPostgresConnectionString(s.params, true),
buildPostgresConnectionString(s.params, false))
buildPostgresConnectionString(s.params, false),
s.conn,
"postgres://"+s.conn)
}
func isPostgresErrorDeterminate(err error) bool {
// Postgres codes from https://www.postgresql.org/docs/current/errcodes-appendix.html
if pqErr, isPostgresError := err.(*pq.Error); isPostgresError {
switch pqErr.Code {
case "28P01":
// Invalid username/password
return true
case "3D000":
// Unknown database
return false // "Indeterminate" so that other connection variations will be tried
case "3F000":
// Unknown schema
return false // "Indeterminate" so that other connection variations will be tried
}
}
return false
}
func joinKeyValues(m map[string]string, sep string) string {

View file

@ -20,49 +20,62 @@ const (
)
func TestPostgres(t *testing.T) {
type result struct {
parseErr bool
pingOk bool
pingDeterminate bool
}
tests := []struct {
input string
wantErr bool
wantPing bool
input string
want result
}{
{
input: "//localhost:5432/foo?sslmode=disable&password=" + postgresPass,
wantPing: true,
input: "//localhost:5432/foo?sslmode=disable&password=" + postgresPass,
want: result{pingOk: true, pingDeterminate: true},
},
{
input: "//localhost:5432/foo?sslmode=disable&user=" + postgresUser + "&password=" + postgresPass,
wantPing: true,
input: "//localhost:5432/foo?sslmode=disable&user=" + postgresUser + "&password=" + postgresPass,
want: result{pingOk: true, pingDeterminate: true},
},
{
input: "//localhost/foo?sslmode=disable&port=5432&password=" + postgresPass,
wantPing: true,
input: "//localhost/foo?sslmode=disable&port=5432&password=" + postgresPass,
want: result{pingOk: true, pingDeterminate: true},
},
{
input: "//localhost:5432/foo?password=" + postgresPass,
wantPing: false,
input: "//localhost:5432/foo?password=" + postgresPass,
want: result{pingOk: false, pingDeterminate: false},
},
{
input: "//localhost:5432/foo?sslmode=disable&password=foo",
wantPing: false,
input: "//localhost:5432/foo?sslmode=disable&password=foo",
want: result{pingOk: false, pingDeterminate: true},
},
{
input: "//localhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass,
wantPing: false,
input: "//localhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass,
want: result{pingOk: false, pingDeterminate: true},
},
{
input: "invalid",
wantErr: true,
input: "//badhost:5432/foo?sslmode=disable&user=foo&password=" + postgresPass,
want: result{pingOk: false, pingDeterminate: false},
},
{
input: "invalid",
want: result{parseErr: true},
},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
j, err := parsePostgres(tt.input)
if tt.wantErr {
assert.Error(t, err)
if err != nil {
got := result{parseErr: true}
assert.Equal(t, tt.want, got)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantPing, j.ping(context.Background()))
pr := j.ping(context.Background())
got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate}
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -14,12 +14,18 @@ type sqliteJDBC struct {
testing bool
}
func (s *sqliteJDBC) ping(ctx context.Context) bool {
var cannotVerifySqliteError error = errors.New("sqlite credentials cannot be verified")
func (s *sqliteJDBC) ping(ctx context.Context) pingResult {
if !s.testing {
// sqlite is not a networked database, so we cannot verify
return false
return pingResult{cannotVerifySqliteError, true}
}
return ping(ctx, "sqlite3", s.filename)
return ping(ctx, "sqlite3", isSqliteErrorDeterminate, s.filename)
}
func isSqliteErrorDeterminate(err error) bool {
return true
}
func parseSqlite(subname string) (jdbc, error) {

View file

@ -39,7 +39,7 @@ func TestParseSqlite(t *testing.T) {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.True(t, j.ping(context.Background()))
assert.True(t, j.ping(context.Background()).err == nil)
}
})
}

View file

@ -3,9 +3,8 @@ package jdbc
import (
"context"
"errors"
mssql "github.com/denisenkom/go-mssqldb"
"strings"
_ "github.com/denisenkom/go-mssqldb"
)
type sqlServerJDBC struct {
@ -13,13 +12,27 @@ type sqlServerJDBC struct {
params map[string]string
}
func (s *sqlServerJDBC) ping(ctx context.Context) bool {
return ping(ctx, "mssql",
s.conn,
func (s *sqlServerJDBC) ping(ctx context.Context) pingResult {
return ping(ctx, "mssql", isSqlServerErrorDeterminate,
joinKeyValues(s.params, ";"),
s.conn,
"sqlserver://"+s.conn)
}
func isSqlServerErrorDeterminate(err error) bool {
// Error numbers from https://learn.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors?view=sql-server-ver16
if mssqlError, isMssqlError := err.(mssql.Error); isMssqlError {
switch mssqlError.Number {
case 18456:
// Login failed
// This is a determinate failure iff we tried to use a real user
return mssqlError.Message != "login error: Login failed for user ''."
}
}
return false
}
func parseSqlServer(subname string) (jdbc, error) {
if !strings.HasPrefix(subname, "//") {
return nil, errors.New("expected connection to start with //")

View file

@ -21,33 +21,50 @@ const (
)
func TestSqlServer(t *testing.T) {
type result struct {
parseErr bool
pingOk bool
pingDeterminate bool
}
tests := []struct {
input string
wantErr bool
wantPing bool
input string
want result
}{
{
input: "",
wantErr: true,
input: "",
want: result{parseErr: true},
},
{
input: "//odbc:server=localhost;user id=sa;database=master;password=" + sqlServerPass,
wantPing: true,
input: "//server=localhost;user id=sa;database=master;password=" + sqlServerPass,
want: result{pingOk: true, pingDeterminate: true},
},
{
input: "//localhost;database=master;spring.datasource.password=" + sqlServerPass,
wantPing: true,
input: "//server=badhost;user id=sa;database=master;password=" + sqlServerPass,
want: result{pingOk: false, pingDeterminate: false},
},
{
input: "//localhost;database=master;spring.datasource.password=" + sqlServerPass,
want: result{pingOk: true, pingDeterminate: true},
},
{
input: "//localhost;database=master;spring.datasource.password=badpassword",
want: result{pingOk: false, pingDeterminate: true},
},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
j, err := parseSqlServer(tt.input)
if tt.wantErr {
assert.Error(t, err)
if err != nil {
got := result{parseErr: true}
assert.Equal(t, tt.want, got)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantPing, j.ping(context.Background()))
pr := j.ping(context.Background())
got := result{pingOk: pr.err == nil, pingDeterminate: pr.determinate}
assert.Equal(t, tt.want, got)
})
}
}