dendrite/internal/sqlutil/sql.go
oliverpool a4db43e096
Don't overwrite global err before return (#1293)
Signed-off-by: Olivier Charvin <git@olivier.pfad.fr>
2020-08-25 13:11:52 +01:00

109 lines
3.1 KiB
Go

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlutil
import (
"database/sql"
"errors"
"fmt"
"runtime"
)
// ErrUserExists is the sentinel error reported when a username is already
// present in the database. Compare against it with errors.Is or ==.
var (
	ErrUserExists = errors.New("Username already exists")
)
// A Transaction is something that can be committed or rolled back.
// EndTransaction and EndTransactionWithCheck accept any value implementing it;
// WithTransaction passes the value returned by (*sql.DB).Begin as one.
type Transaction interface {
// Commit applies the transaction. A non-nil error means the commit failed.
Commit() error
// Rollback abandons the transaction. A non-nil error means the rollback failed.
Rollback() error
}
// EndTransaction finishes txn according to the flag that succeeded points to:
// the transaction is committed when *succeeded is true and rolled back when it
// is false. succeeded is dereferenced unconditionally, so it must not be nil.
//
// The returned error is whatever Commit or Rollback reported. Callers MUST
// check it to know whether the transaction was actually applied; for example,
// sqlite 'database is locked' errors surface here.
func EndTransaction(txn Transaction, succeeded *bool) error {
	if !*succeeded {
		return txn.Rollback()
	}
	return txn.Commit()
}
// EndTransactionWithCheck ends txn via EndTransaction (commit when *succeeded
// is true, rollback otherwise) and records a failure to do so in *err, but
// only when *err is currently nil: an error the caller has already set is
// never overwritten.
//
// It returns nothing so that it can be used directly in a defer statement
// together with a named error result; use EndTransaction when the error can be
// handled inline.
func EndTransactionWithCheck(txn Transaction, succeeded *bool, err *error) {
	endErr := EndTransaction(txn, succeeded)
	if endErr == nil || *err != nil {
		// Either ending the transaction worked, or there is an earlier
		// error that takes precedence.
		return
	}
	*err = endErr
}
// WithTransaction begins a transaction on db, passes it to fn, and ends the
// transaction once fn is done.
//
// The transaction is committed only if fn returns nil. If fn returns an error
// or panics, the deferred EndTransactionWithCheck call rolls it back instead.
//
// Returned error:
//   - a failure to begin the transaction, wrapped with the prefix
//     "sqlutil.WithTransaction.Begin" (fn is not called in that case);
//   - otherwise the error returned by fn, unchanged;
//   - otherwise, if fn succeeded, any error from committing.
//
// A rollback error is dropped when fn has already returned an error, because
// EndTransactionWithCheck never overwrites a non-nil error.
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
	txn, err := db.Begin()
	if err != nil {
		return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err)
	}

	// commit stays false unless fn completes without error, so a panic or
	// an error return makes the deferred call roll back. The deferred call
	// writes through &err, which is why err is a named result.
	commit := false
	defer EndTransactionWithCheck(txn, &commit, &err)

	if err = fn(txn); err == nil {
		commit = true
	}
	return err
}
// TxStmt binds a prepared statement to an optional transaction.
//
// When transaction is nil the statement is returned unchanged and will run
// outside of any transaction. Otherwise the result of
// transaction.Stmt(statement) is returned, i.e. a transaction-specific copy of
// the statement.
func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
	if transaction == nil {
		return statement
	}
	return transaction.Stmt(statement)
}
// QueryVariadic returns a parenthesised, comma-separated list of count
// positional SQL placeholders numbered from $1; for example QueryVariadic(3)
// returns "($1, $2, $3)". It is shorthand for QueryVariadicOffset(count, 0).
//
// Originally annotated by its author as the "hack of the century".
func QueryVariadic(count int) string {
return QueryVariadicOffset(count, 0)
}
// QueryVariadicOffset returns a parenthesised, comma-separated list of count
// positional SQL placeholders. Numbering starts at $(offset+1), so
// QueryVariadicOffset(2, 5) returns "($6, $7)". This allows the list to follow
// offset placeholders that are already used earlier in the same query.
//
// A count of zero or less yields "()".
func QueryVariadicOffset(count, offset int) string {
	// Append into a single byte buffer instead of growing a string with +=,
	// which would re-copy everything built so far on each iteration.
	buf := []byte{'('}
	for i := 0; i < count; i++ {
		if i > 0 {
			buf = append(buf, ", "...)
		}
		buf = append(buf, fmt.Sprintf("$%d", i+offset+1)...)
	}
	buf = append(buf, ')')
	return string(buf)
}
// SQLiteDriverName returns the SQLite driver name for the current build
// target: "sqlite3_js" when runtime.GOOS is "js", and "sqlite3" on every other
// operating system.
func SQLiteDriverName() string {
	switch runtime.GOOS {
	case "js":
		return "sqlite3_js"
	default:
		return "sqlite3"
	}
}