mirror of
https://github.com/superseriousbusiness/gotosocial
synced 2025-01-01 07:28:46 +00:00
231 lines
7 KiB
Go
231 lines
7 KiB
Go
|
// GoToSocial
|
||
|
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||
|
//
|
||
|
// This program is free software: you can redistribute it and/or modify
|
||
|
// it under the terms of the GNU Affero General Public License as published by
|
||
|
// the Free Software Foundation, either version 3 of the License, or
|
||
|
// (at your option) any later version.
|
||
|
//
|
||
|
// This program is distributed in the hope that it will be useful,
|
||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||
|
// GNU Affero General Public License for more details.
|
||
|
//
|
||
|
// You should have received a copy of the GNU Affero General Public License
|
||
|
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||
|
|
||
|
package bundb
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||
|
"github.com/uptrace/bun"
|
||
|
"github.com/uptrace/bun/dialect"
|
||
|
)
|
||
|
|
||
|
// UpsertQuery is a wrapper around an insert query that can update if an insert fails.
|
||
|
// Doesn't implement the full set of Bun query methods, but we can add more if we need them.
|
||
|
// See https://bun.uptrace.dev/guide/query-insert.html#upsert
|
||
|
type UpsertQuery struct {
|
||
|
db bun.IDB
|
||
|
model interface{}
|
||
|
constraints []string
|
||
|
columns []string
|
||
|
}
|
||
|
|
||
|
func NewUpsert(idb bun.IDB) *UpsertQuery {
|
||
|
// note: passing in rawtx as conn iface so no double query-hook
|
||
|
// firing when passed through the bun.Tx.Query___() functions.
|
||
|
return &UpsertQuery{db: idb}
|
||
|
}
|
||
|
|
||
|
// Model sets the model or models to upsert.
|
||
|
func (u *UpsertQuery) Model(model interface{}) *UpsertQuery {
|
||
|
u.model = model
|
||
|
return u
|
||
|
}
|
||
|
|
||
|
// Constraint sets the columns or indexes that are used to check for conflicts.
|
||
|
// This is required.
|
||
|
func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery {
|
||
|
u.constraints = constraints
|
||
|
return u
|
||
|
}
|
||
|
|
||
|
// Column sets the columns to update if an insert does't happen.
|
||
|
// If empty, all columns not being used for constraints will be updated.
|
||
|
// Cannot overlap with Constraint.
|
||
|
func (u *UpsertQuery) Column(columns ...string) *UpsertQuery {
|
||
|
u.columns = columns
|
||
|
return u
|
||
|
}
|
||
|
|
||
|
// insertDialect errors if we're using a dialect in which we don't know how to upsert.
|
||
|
func (u *UpsertQuery) insertDialect() error {
|
||
|
dialectName := u.db.Dialect().Name()
|
||
|
switch dialectName {
|
||
|
case dialect.PG, dialect.SQLite:
|
||
|
return nil
|
||
|
default:
|
||
|
// FUTURE: MySQL has its own variation on upserts, but the syntax is different.
|
||
|
return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// insertConstraints checks that we have constraints and returns them.
|
||
|
func (u *UpsertQuery) insertConstraints() ([]string, error) {
|
||
|
if len(u.constraints) == 0 {
|
||
|
return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided")
|
||
|
}
|
||
|
return u.constraints, nil
|
||
|
}
|
||
|
|
||
|
// insertColumns returns the non-constraint columns we'll be updating.
|
||
|
func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) {
|
||
|
// Constraints as a set.
|
||
|
constraintSet := make(map[string]struct{}, len(constraints))
|
||
|
for _, constraint := range constraints {
|
||
|
constraintSet[constraint] = struct{}{}
|
||
|
}
|
||
|
|
||
|
var columns []string
|
||
|
var err error
|
||
|
if len(u.columns) == 0 {
|
||
|
columns, err = u.insertColumnsDefault(constraintSet)
|
||
|
} else {
|
||
|
columns, err = u.insertColumnsSpecified(constraintSet)
|
||
|
}
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if len(columns) == 0 {
|
||
|
return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting")
|
||
|
}
|
||
|
|
||
|
return columns, nil
|
||
|
}
|
||
|
|
||
|
// hasElem reports whether modelType is an array, chan, map, pointer or slice
// type, i.e. one on which [reflect.Type.Elem] can be called without panicking.
func hasElem(modelType reflect.Type) bool {
	k := modelType.Kind()
	return k == reflect.Array ||
		k == reflect.Chan ||
		k == reflect.Map ||
		k == reflect.Pointer ||
		k == reflect.Slice
}
|
||
|
|
||
|
// insertColumnsDefault returns all non-constraint columns from the model schema.
|
||
|
func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) {
|
||
|
// Get underlying struct type.
|
||
|
modelType := reflect.TypeOf(u.model)
|
||
|
for hasElem(modelType) {
|
||
|
modelType = modelType.Elem()
|
||
|
}
|
||
|
|
||
|
table := u.db.Dialect().Tables().Get(modelType)
|
||
|
if table == nil {
|
||
|
return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model)
|
||
|
}
|
||
|
|
||
|
columns := make([]string, 0, len(u.columns))
|
||
|
for _, field := range table.Fields {
|
||
|
column := field.Name
|
||
|
if _, overlaps := constraintSet[column]; !overlaps {
|
||
|
columns = append(columns, column)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return columns, nil
|
||
|
}
|
||
|
|
||
|
// insertColumnsSpecified ensures constraints and specified columns to update don't overlap.
|
||
|
func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) {
|
||
|
overlapping := make([]string, 0, min(len(u.constraints), len(u.columns)))
|
||
|
for _, column := range u.columns {
|
||
|
if _, overlaps := constraintSet[column]; overlaps {
|
||
|
overlapping = append(overlapping, column)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(overlapping) > 0 {
|
||
|
return nil, gtserror.Newf(
|
||
|
"UpsertQuery: the following columns can't be used for both constraints and columns to update: %s",
|
||
|
strings.Join(overlapping, ", "),
|
||
|
)
|
||
|
}
|
||
|
|
||
|
return u.columns, nil
|
||
|
}
|
||
|
|
||
|
// insert tries to create a Bun insert query from an upsert query.
|
||
|
func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) {
|
||
|
var err error
|
||
|
|
||
|
err = u.insertDialect()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
constraints, err := u.insertConstraints()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
columns, err := u.insertColumns(constraints)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Build the parts of the query that need us to generate SQL.
|
||
|
constraintIDPlaceholders := make([]string, 0, len(constraints))
|
||
|
constraintIDs := make([]interface{}, 0, len(constraints))
|
||
|
for _, constraint := range constraints {
|
||
|
constraintIDPlaceholders = append(constraintIDPlaceholders, "?")
|
||
|
constraintIDs = append(constraintIDs, bun.Ident(constraint))
|
||
|
}
|
||
|
onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update"
|
||
|
|
||
|
setClauses := make([]string, 0, len(columns))
|
||
|
setIDs := make([]interface{}, 0, 2*len(columns))
|
||
|
for _, column := range columns {
|
||
|
// "excluded" is a special table that contains only the row involved in a conflict.
|
||
|
setClauses = append(setClauses, "? = excluded.?")
|
||
|
setIDs = append(setIDs, bun.Ident(column), bun.Ident(column))
|
||
|
}
|
||
|
setSQL := strings.Join(setClauses, ", ")
|
||
|
|
||
|
insertQuery := u.db.
|
||
|
NewInsert().
|
||
|
Model(u.model).
|
||
|
On(onSQL, constraintIDs...).
|
||
|
Set(setSQL, setIDs...)
|
||
|
|
||
|
return insertQuery, nil
|
||
|
}
|
||
|
|
||
|
// Exec builds a Bun insert query from the upsert query, and executes it.
|
||
|
func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
|
||
|
insertQuery, err := u.insertQuery()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return insertQuery.Exec(ctx, dest...)
|
||
|
}
|
||
|
|
||
|
// Scan builds a Bun insert query from the upsert query, and scans it.
|
||
|
func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error {
|
||
|
insertQuery, err := u.insertQuery()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return insertQuery.Scan(ctx, dest...)
|
||
|
}
|