// Mirror of https://github.com/superseriousbusiness/gotosocial,
// synced 2024-12-27 13:13:10 +00:00 (1348 lines, 29 KiB, Go).

package bun
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/uptrace/bun/dialect/feature"
|
|
"github.com/uptrace/bun/internal"
|
|
"github.com/uptrace/bun/schema"
|
|
)
|
|
|
|
// Bits stored in baseQuery.flags that control soft-delete filtering.
// The values come from iota, so the declaration order matters.
const (
	// forceDeleteFlag makes isSoftDelete report false unless deletedFlag is
	// also set. It is set outside this part of the file (presumably by a
	// force-delete option on delete queries — confirm).
	forceDeleteFlag internal.Flag = 1 << iota
	// deletedFlag is set by whereDeleted: the soft-delete condition selects
	// rows that are soft-deleted instead of live rows.
	deletedFlag
	// allWithDeletedFlag is set by whereAllWithDeleted: no soft-delete
	// condition is added at all.
	allWithDeletedFlag
)
|
|
|
|
// withQuery is one common table expression registered by addWith and
// rendered by appendWith/appendCTE as `name [(columns)] AS (query)`.
type withQuery struct {
	// name is written as an identifier via fmter.AppendIdent.
	name string
	// query is rendered inside AS (...).
	query schema.QueryAppender
	// recursive requests the RECURSIVE keyword for the WITH clause.
	recursive bool
}
|
|
|
|
// IConn is the query-execution method set shared by *sql.DB, *sql.Conn and
// *sql.Tx, and also implemented by bun's *DB, *Conn and *Tx, so a query can
// run on any of them.
type IConn interface {
	// ExecContext runs a statement that returns no rows.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	// QueryContext runs a statement that returns rows.
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	// QueryRowContext runs a statement expected to return at most one row.
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
|
|
|
|
var (
|
|
_ IConn = (*sql.DB)(nil)
|
|
_ IConn = (*sql.Conn)(nil)
|
|
_ IConn = (*sql.Tx)(nil)
|
|
_ IConn = (*DB)(nil)
|
|
_ IConn = (*Conn)(nil)
|
|
_ IConn = (*Tx)(nil)
|
|
)
|
|
|
|
// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx.
|
|
type IDB interface {
|
|
IConn
|
|
Dialect() schema.Dialect
|
|
|
|
NewValues(model interface{}) *ValuesQuery
|
|
NewSelect() *SelectQuery
|
|
NewInsert() *InsertQuery
|
|
NewUpdate() *UpdateQuery
|
|
NewDelete() *DeleteQuery
|
|
NewRaw(query string, args ...interface{}) *RawQuery
|
|
NewCreateTable() *CreateTableQuery
|
|
NewDropTable() *DropTableQuery
|
|
NewCreateIndex() *CreateIndexQuery
|
|
NewDropIndex() *DropIndexQuery
|
|
NewTruncateTable() *TruncateTableQuery
|
|
NewAddColumn() *AddColumnQuery
|
|
NewDropColumn() *DropColumnQuery
|
|
|
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
|
|
RunInTx(ctx context.Context, opts *sql.TxOptions, f func(ctx context.Context, tx Tx) error) error
|
|
}
|
|
|
|
var (
|
|
_ IDB = (*DB)(nil)
|
|
_ IDB = (*Conn)(nil)
|
|
_ IDB = (*Tx)(nil)
|
|
)
|
|
|
|
// QueryBuilder is used for common query methods
|
|
type QueryBuilder interface {
|
|
Query
|
|
Where(query string, args ...interface{}) QueryBuilder
|
|
WhereGroup(sep string, fn func(QueryBuilder) QueryBuilder) QueryBuilder
|
|
WhereOr(query string, args ...interface{}) QueryBuilder
|
|
WhereDeleted() QueryBuilder
|
|
WhereAllWithDeleted() QueryBuilder
|
|
WherePK(cols ...string) QueryBuilder
|
|
Unwrap() interface{}
|
|
}
|
|
|
|
var (
|
|
_ QueryBuilder = (*selectQueryBuilder)(nil)
|
|
_ QueryBuilder = (*updateQueryBuilder)(nil)
|
|
_ QueryBuilder = (*deleteQueryBuilder)(nil)
|
|
)
|
|
|
|
// baseQuery is the state shared by the query builders in this package: the
// owning DB and connection, the model and its table metadata, CTEs, extra
// tables and columns, soft-delete flags, and the first error recorded while
// building the query.
type baseQuery struct {
	// db is the owning DB; conn is where the query executes (see setConn).
	db   *DB
	conn IConn

	// model is the destination/source model; err is the first error recorded
	// by setErr.
	model Model
	err   error

	// tableModel and table are set together by setModel when the model is a
	// TableModel; table caches tableModel.Table().
	tableModel TableModel
	table      *schema.Table

	// with holds CTEs (appendWith). modelTableName, when non-zero, is used
	// instead of table's name. tables are extra table expressions. columns
	// are column expressions; entries without Args are plain identifiers.
	with           []withQuery
	modelTableName schema.QueryWithArgs
	tables         []schema.QueryWithArgs
	columns        []schema.QueryWithArgs

	// flags holds the soft-delete bits (forceDeleteFlag, deletedFlag,
	// allWithDeletedFlag).
	flags internal.Flag
}
|
|
|
|
func (q *baseQuery) DB() *DB {
|
|
return q.db
|
|
}
|
|
|
|
func (q *baseQuery) GetConn() IConn {
|
|
return q.conn
|
|
}
|
|
|
|
func (q *baseQuery) GetModel() Model {
|
|
return q.model
|
|
}
|
|
|
|
func (q *baseQuery) GetTableName() string {
|
|
if q.table != nil {
|
|
return q.table.Name
|
|
}
|
|
|
|
for _, wq := range q.with {
|
|
if v, ok := wq.query.(Query); ok {
|
|
if model := v.GetModel(); model != nil {
|
|
return v.GetTableName()
|
|
}
|
|
}
|
|
}
|
|
|
|
if q.modelTableName.Query != "" {
|
|
return q.modelTableName.Query
|
|
}
|
|
|
|
if len(q.tables) > 0 {
|
|
b, _ := q.tables[0].AppendQuery(q.db.fmter, nil)
|
|
if len(b) < 64 {
|
|
return string(b)
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func (q *baseQuery) setConn(db IConn) {
|
|
// Unwrap Bun wrappers to not call query hooks twice.
|
|
switch db := db.(type) {
|
|
case *DB:
|
|
q.conn = db.DB
|
|
case Conn:
|
|
q.conn = db.Conn
|
|
case Tx:
|
|
q.conn = db.Tx
|
|
default:
|
|
q.conn = db
|
|
}
|
|
}
|
|
|
|
func (q *baseQuery) setModel(modeli interface{}) {
|
|
model, err := newSingleModel(q.db, modeli)
|
|
if err != nil {
|
|
q.setErr(err)
|
|
return
|
|
}
|
|
|
|
q.model = model
|
|
if tm, ok := model.(TableModel); ok {
|
|
q.tableModel = tm
|
|
q.table = tm.Table()
|
|
}
|
|
}
|
|
|
|
func (q *baseQuery) setErr(err error) {
|
|
if q.err == nil {
|
|
q.err = err
|
|
}
|
|
}
|
|
|
|
func (q *baseQuery) getModel(dest []interface{}) (Model, error) {
|
|
if len(dest) > 0 {
|
|
return newModel(q.db, dest)
|
|
}
|
|
if q.model != nil {
|
|
return q.model, nil
|
|
}
|
|
return nil, errNilModel
|
|
}
|
|
|
|
func (q *baseQuery) beforeAppendModel(ctx context.Context, query Query) error {
|
|
if q.tableModel != nil {
|
|
return q.tableModel.BeforeAppendModel(ctx, query)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (q *baseQuery) hasFeature(feature feature.Feature) bool {
|
|
return q.db.features.Has(feature)
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
func (q *baseQuery) checkSoftDelete() error {
|
|
if q.table == nil {
|
|
return errors.New("bun: can't use soft deletes without a table")
|
|
}
|
|
if q.table.SoftDeleteField == nil {
|
|
return fmt.Errorf("%s does not have a soft delete field", q.table)
|
|
}
|
|
if q.tableModel == nil {
|
|
return errors.New("bun: can't use soft deletes without a table model")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models.
|
|
func (q *baseQuery) whereDeleted() {
|
|
if err := q.checkSoftDelete(); err != nil {
|
|
q.setErr(err)
|
|
return
|
|
}
|
|
q.flags = q.flags.Set(deletedFlag)
|
|
q.flags = q.flags.Remove(allWithDeletedFlag)
|
|
}
|
|
|
|
// AllWithDeleted changes query to return all rows including soft deleted ones.
|
|
func (q *baseQuery) whereAllWithDeleted() {
|
|
if err := q.checkSoftDelete(); err != nil {
|
|
q.setErr(err)
|
|
return
|
|
}
|
|
q.flags = q.flags.Set(allWithDeletedFlag).Remove(deletedFlag)
|
|
}
|
|
|
|
func (q *baseQuery) isSoftDelete() bool {
|
|
if q.table != nil {
|
|
return q.table.SoftDeleteField != nil &&
|
|
!q.flags.Has(allWithDeletedFlag) &&
|
|
(!q.flags.Has(forceDeleteFlag) || q.flags.Has(deletedFlag))
|
|
}
|
|
return false
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) {
|
|
q.with = append(q.with, withQuery{
|
|
name: name,
|
|
query: query,
|
|
recursive: recursive,
|
|
})
|
|
}
|
|
|
|
func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) {
|
|
if len(q.with) == 0 {
|
|
return b, nil
|
|
}
|
|
|
|
b = append(b, "WITH "...)
|
|
for i, with := range q.with {
|
|
if i > 0 {
|
|
b = append(b, ", "...)
|
|
}
|
|
|
|
if with.recursive {
|
|
b = append(b, "RECURSIVE "...)
|
|
}
|
|
|
|
b, err = q.appendCTE(fmter, b, with)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
b = append(b, ' ')
|
|
return b, nil
|
|
}
|
|
|
|
func (q *baseQuery) appendCTE(
|
|
fmter schema.Formatter, b []byte, cte withQuery,
|
|
) (_ []byte, err error) {
|
|
if !fmter.Dialect().Features().Has(feature.WithValues) {
|
|
if values, ok := cte.query.(*ValuesQuery); ok {
|
|
return q.appendSelectFromValues(fmter, b, cte, values)
|
|
}
|
|
}
|
|
|
|
b = fmter.AppendIdent(b, cte.name)
|
|
|
|
if q, ok := cte.query.(schema.ColumnsAppender); ok {
|
|
b = append(b, " ("...)
|
|
b, err = q.AppendColumns(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b = append(b, ")"...)
|
|
}
|
|
|
|
b = append(b, " AS ("...)
|
|
|
|
b, err = cte.query.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
b = append(b, ")"...)
|
|
return b, nil
|
|
}
|
|
|
|
func (q *baseQuery) appendSelectFromValues(
|
|
fmter schema.Formatter, b []byte, cte withQuery, values *ValuesQuery,
|
|
) (_ []byte, err error) {
|
|
b = fmter.AppendIdent(b, cte.name)
|
|
b = append(b, " AS (SELECT * FROM ("...)
|
|
|
|
b, err = cte.query.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
b = append(b, ") AS t"...)
|
|
if q, ok := cte.query.(schema.ColumnsAppender); ok {
|
|
b = append(b, " ("...)
|
|
b, err = q.AppendColumns(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b = append(b, ")"...)
|
|
}
|
|
b = append(b, ")"...)
|
|
|
|
return b, nil
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
func (q *baseQuery) addTable(table schema.QueryWithArgs) {
|
|
q.tables = append(q.tables, table)
|
|
}
|
|
|
|
func (q *baseQuery) addColumn(column schema.QueryWithArgs) {
|
|
q.columns = append(q.columns, column)
|
|
}
|
|
|
|
func (q *baseQuery) excludeColumn(columns []string) {
|
|
if q.table == nil {
|
|
q.setErr(errNilModel)
|
|
return
|
|
}
|
|
|
|
if q.columns == nil {
|
|
for _, f := range q.table.Fields {
|
|
q.columns = append(q.columns, schema.UnsafeIdent(f.Name))
|
|
}
|
|
}
|
|
|
|
if len(columns) == 1 && columns[0] == "*" {
|
|
q.columns = make([]schema.QueryWithArgs, 0)
|
|
return
|
|
}
|
|
|
|
for _, column := range columns {
|
|
if !q._excludeColumn(column) {
|
|
q.setErr(fmt.Errorf("bun: can't find column=%q", column))
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (q *baseQuery) _excludeColumn(column string) bool {
|
|
for i, col := range q.columns {
|
|
if col.Args == nil && col.Query == column {
|
|
q.columns = append(q.columns[:i], q.columns[i+1:]...)
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
func (q *baseQuery) modelHasTableName() bool {
|
|
if !q.modelTableName.IsZero() {
|
|
return q.modelTableName.Query != ""
|
|
}
|
|
return q.table != nil
|
|
}
|
|
|
|
func (q *baseQuery) hasTables() bool {
|
|
return q.modelHasTableName() || len(q.tables) > 0
|
|
}
|
|
|
|
func (q *baseQuery) appendTables(
|
|
fmter schema.Formatter, b []byte,
|
|
) (_ []byte, err error) {
|
|
return q._appendTables(fmter, b, false)
|
|
}
|
|
|
|
func (q *baseQuery) appendTablesWithAlias(
|
|
fmter schema.Formatter, b []byte,
|
|
) (_ []byte, err error) {
|
|
return q._appendTables(fmter, b, true)
|
|
}
|
|
|
|
func (q *baseQuery) _appendTables(
|
|
fmter schema.Formatter, b []byte, withAlias bool,
|
|
) (_ []byte, err error) {
|
|
startLen := len(b)
|
|
|
|
if q.modelHasTableName() {
|
|
if !q.modelTableName.IsZero() {
|
|
b, err = q.modelTableName.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects))
|
|
if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects {
|
|
b = append(b, " AS "...)
|
|
b = append(b, q.table.SQLAlias...)
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, table := range q.tables {
|
|
if len(b) > startLen {
|
|
b = append(b, ", "...)
|
|
}
|
|
b, err = table.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return b, nil
|
|
}
|
|
|
|
func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) {
|
|
return q._appendFirstTable(fmter, b, false)
|
|
}
|
|
|
|
func (q *baseQuery) appendFirstTableWithAlias(
|
|
fmter schema.Formatter, b []byte,
|
|
) ([]byte, error) {
|
|
return q._appendFirstTable(fmter, b, true)
|
|
}
|
|
|
|
func (q *baseQuery) _appendFirstTable(
|
|
fmter schema.Formatter, b []byte, withAlias bool,
|
|
) ([]byte, error) {
|
|
if !q.modelTableName.IsZero() {
|
|
return q.modelTableName.AppendQuery(fmter, b)
|
|
}
|
|
|
|
if q.table != nil {
|
|
b = fmter.AppendQuery(b, string(q.table.SQLName))
|
|
if withAlias {
|
|
b = append(b, " AS "...)
|
|
b = append(b, q.table.SQLAlias...)
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
if len(q.tables) > 0 {
|
|
return q.tables[0].AppendQuery(fmter, b)
|
|
}
|
|
|
|
return nil, errors.New("bun: query does not have a table")
|
|
}
|
|
|
|
func (q *baseQuery) hasMultiTables() bool {
|
|
if q.modelHasTableName() {
|
|
return len(q.tables) >= 1
|
|
}
|
|
return len(q.tables) >= 2
|
|
}
|
|
|
|
func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
|
|
tables := q.tables
|
|
if !q.modelHasTableName() {
|
|
tables = tables[1:]
|
|
}
|
|
for i, table := range tables {
|
|
if i > 0 {
|
|
b = append(b, ", "...)
|
|
}
|
|
b, err = table.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
|
|
for i, f := range q.columns {
|
|
if i > 0 {
|
|
b = append(b, ", "...)
|
|
}
|
|
b, err = f.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func (q *baseQuery) getFields() ([]*schema.Field, error) {
|
|
if len(q.columns) == 0 {
|
|
if q.table == nil {
|
|
return nil, errNilModel
|
|
}
|
|
return q.table.Fields, nil
|
|
}
|
|
return q._getFields(false)
|
|
}
|
|
|
|
func (q *baseQuery) getDataFields() ([]*schema.Field, error) {
|
|
if len(q.columns) == 0 {
|
|
if q.table == nil {
|
|
return nil, errNilModel
|
|
}
|
|
return q.table.DataFields, nil
|
|
}
|
|
return q._getFields(true)
|
|
}
|
|
|
|
func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) {
|
|
fields := make([]*schema.Field, 0, len(q.columns))
|
|
for _, col := range q.columns {
|
|
if col.Args != nil {
|
|
continue
|
|
}
|
|
|
|
field, err := q.table.Field(col.Query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if omitPK && field.IsPK {
|
|
continue
|
|
}
|
|
|
|
fields = append(fields, field)
|
|
}
|
|
return fields, nil
|
|
}
|
|
|
|
// scan runs query on q.conn, scans the rows into model and returns the number
// of scanned rows as the result's RowsAffected.
//
// The DB's beforeQuery/afterQuery hooks wrap the call. afterQuery runs on
// every return path, before the deferred rows.Close. When hasDest is set and
// model is a single-row model that received no rows, the result is returned
// together with sql.ErrNoRows.
func (q *baseQuery) scan(
	ctx context.Context,
	iquery Query,
	query string,
	model Model,
	hasDest bool,
) (sql.Result, error) {
	// NOTE(review): hooks receive q.model, not the model being scanned into
	// (they differ when explicit destinations were given) — confirm intended.
	ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)

	rows, err := q.conn.QueryContext(ctx, query)
	if err != nil {
		q.db.afterQuery(ctx, event, nil, err)
		return nil, err
	}
	defer rows.Close()

	numRow, err := model.ScanRows(ctx, rows)
	if err != nil {
		q.db.afterQuery(ctx, event, nil, err)
		return nil, err
	}

	if numRow == 0 && hasDest && isSingleRowModel(model) {
		err = sql.ErrNoRows
	}

	res := driver.RowsAffected(numRow)
	q.db.afterQuery(ctx, event, res, err)

	return res, err
}
|
|
|
|
// exec runs query on q.conn with ExecContext, wrapped in the DB's
// beforeQuery/afterQuery hooks. The driver's result and error are returned
// unchanged.
func (q *baseQuery) exec(
	ctx context.Context,
	iquery Query,
	query string,
) (sql.Result, error) {
	ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
	res, err := q.conn.ExecContext(ctx, query)
	q.db.afterQuery(ctx, event, res, err)
	return res, err
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) {
|
|
if q.table == nil {
|
|
return b, false
|
|
}
|
|
|
|
if m, ok := q.tableModel.(*structTableModel); ok {
|
|
if b, ok := m.AppendNamedArg(fmter, b, name); ok {
|
|
return b, ok
|
|
}
|
|
}
|
|
|
|
switch name {
|
|
case "TableName":
|
|
b = fmter.AppendQuery(b, string(q.table.SQLName))
|
|
return b, true
|
|
case "TableAlias":
|
|
b = fmter.AppendQuery(b, string(q.table.SQLAlias))
|
|
return b, true
|
|
case "PKs":
|
|
b = appendColumns(b, "", q.table.PKs)
|
|
return b, true
|
|
case "TablePKs":
|
|
b = appendColumns(b, q.table.SQLAlias, q.table.PKs)
|
|
return b, true
|
|
case "Columns":
|
|
b = appendColumns(b, "", q.table.Fields)
|
|
return b, true
|
|
case "TableColumns":
|
|
b = appendColumns(b, q.table.SQLAlias, q.table.Fields)
|
|
return b, true
|
|
}
|
|
|
|
return b, false
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
func (q *baseQuery) Dialect() schema.Dialect {
|
|
return q.db.Dialect()
|
|
}
|
|
|
|
func (q *baseQuery) NewValues(model interface{}) *ValuesQuery {
|
|
return NewValuesQuery(q.db, model).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewSelect() *SelectQuery {
|
|
return NewSelectQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewInsert() *InsertQuery {
|
|
return NewInsertQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewUpdate() *UpdateQuery {
|
|
return NewUpdateQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewDelete() *DeleteQuery {
|
|
return NewDeleteQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewRaw(query string, args ...interface{}) *RawQuery {
|
|
return NewRawQuery(q.db, query, args...).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewCreateTable() *CreateTableQuery {
|
|
return NewCreateTableQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewDropTable() *DropTableQuery {
|
|
return NewDropTableQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewCreateIndex() *CreateIndexQuery {
|
|
return NewCreateIndexQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewDropIndex() *DropIndexQuery {
|
|
return NewDropIndexQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewTruncateTable() *TruncateTableQuery {
|
|
return NewTruncateTableQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewAddColumn() *AddColumnQuery {
|
|
return NewAddColumnQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
func (q *baseQuery) NewDropColumn() *DropColumnQuery {
|
|
return NewDropColumnQuery(q.db).Conn(q.conn)
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte {
|
|
for i, f := range fields {
|
|
if i > 0 {
|
|
b = append(b, ", "...)
|
|
}
|
|
|
|
if len(table) > 0 {
|
|
b = append(b, table...)
|
|
b = append(b, '.')
|
|
}
|
|
b = append(b, f.SQLName...)
|
|
}
|
|
return b
|
|
}
|
|
|
|
func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter {
|
|
if fmter.IsNop() {
|
|
return fmter
|
|
}
|
|
return fmter.WithArg(model)
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// whereBaseQuery extends baseQuery with the state used to render a WHERE
// clause.
type whereBaseQuery struct {
	baseQuery

	// where holds conditions in order. Each entry's Sep joins it to the
	// previous one, and entries with an empty Query contribute only their Sep.
	// whereFields are the columns set once by addWhereCols (WherePK) and
	// matched against the model's values.
	where       []schema.QueryWithSep
	whereFields []*schema.Field
}
|
|
|
|
func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) {
|
|
q.where = append(q.where, where)
|
|
}
|
|
|
|
func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) {
|
|
if len(where) == 0 {
|
|
return
|
|
}
|
|
|
|
q.addWhere(schema.SafeQueryWithSep("", nil, sep))
|
|
q.addWhere(schema.SafeQueryWithSep("", nil, "("))
|
|
|
|
where[0].Sep = ""
|
|
q.where = append(q.where, where...)
|
|
|
|
q.addWhere(schema.SafeQueryWithSep("", nil, ")"))
|
|
}
|
|
|
|
func (q *whereBaseQuery) addWhereCols(cols []string) {
|
|
if q.table == nil {
|
|
err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
|
|
q.setErr(err)
|
|
return
|
|
}
|
|
if q.whereFields != nil {
|
|
err := errors.New("bun: WherePK can only be called once")
|
|
q.setErr(err)
|
|
return
|
|
}
|
|
|
|
if cols == nil {
|
|
if err := q.table.CheckPKs(); err != nil {
|
|
q.setErr(err)
|
|
return
|
|
}
|
|
q.whereFields = q.table.PKs
|
|
return
|
|
}
|
|
|
|
q.whereFields = make([]*schema.Field, len(cols))
|
|
for i, col := range cols {
|
|
field, err := q.table.Field(col)
|
|
if err != nil {
|
|
q.setErr(err)
|
|
return
|
|
}
|
|
q.whereFields[i] = field
|
|
}
|
|
}
|
|
|
|
func (q *whereBaseQuery) mustAppendWhere(
|
|
fmter schema.Formatter, b []byte, withAlias bool,
|
|
) ([]byte, error) {
|
|
if len(q.where) == 0 && q.whereFields == nil && !q.flags.Has(deletedFlag) {
|
|
err := errors.New("bun: Update and Delete queries require at least one Where")
|
|
return nil, err
|
|
}
|
|
return q.appendWhere(fmter, b, withAlias)
|
|
}
|
|
|
|
// appendWhere renders " WHERE " followed by, in order, the user conditions
// (q.where), the soft-delete condition (when isSoftDelete reports true) and
// the WherePK condition (q.whereFields), joined with " AND ". When none of
// these apply, b is returned unchanged.
//
// withAlias qualifies the soft-delete column with the table alias instead of
// the table's SQL name. It is also passed on to appendWhereFields.
func (q *whereBaseQuery) appendWhere(
	fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
	if len(q.where) == 0 && q.whereFields == nil && !q.isSoftDelete() {
		return b, nil
	}

	b = append(b, " WHERE "...)
	// Anything written past startLen means an earlier condition exists, so
	// the next one needs an " AND " separator.
	startLen := len(b)

	if len(q.where) > 0 {
		b, err = appendWhere(fmter, b, q.where)
		if err != nil {
			return nil, err
		}
	}

	if q.isSoftDelete() {
		if len(b) > startLen {
			b = append(b, " AND "...)
		}

		if withAlias {
			b = append(b, q.tableModel.Table().SQLAlias...)
		} else {
			b = append(b, q.tableModel.Table().SQLName...)
		}
		b = append(b, '.')

		field := q.tableModel.Table().SoftDeleteField
		b = append(b, field.SQLName...)

		// Pointer or nullzero soft-delete columns are tested for NULL. Other
		// columns are compared with the zero time as rendered by the dialect.
		// deletedFlag selects soft-deleted rows instead of live ones.
		if field.IsPtr || field.NullZero {
			if q.flags.Has(deletedFlag) {
				b = append(b, " IS NOT NULL"...)
			} else {
				b = append(b, " IS NULL"...)
			}
		} else {
			if q.flags.Has(deletedFlag) {
				b = append(b, " != "...)
			} else {
				b = append(b, " = "...)
			}
			b = fmter.Dialect().AppendTime(b, time.Time{})
		}
	}

	if q.whereFields != nil {
		if len(b) > startLen {
			b = append(b, " AND "...)
		}
		b, err = q.appendWhereFields(fmter, b, q.whereFields, withAlias)
		if err != nil {
			return nil, err
		}
	}

	return b, nil
}
|
|
|
|
func appendWhere(
|
|
fmter schema.Formatter, b []byte, where []schema.QueryWithSep,
|
|
) (_ []byte, err error) {
|
|
for i, where := range where {
|
|
if i > 0 {
|
|
b = append(b, where.Sep...)
|
|
}
|
|
|
|
if where.Query == "" {
|
|
continue
|
|
}
|
|
|
|
b = append(b, '(')
|
|
b, err = where.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b = append(b, ')')
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func (q *whereBaseQuery) appendWhereFields(
|
|
fmter schema.Formatter, b []byte, fields []*schema.Field, withAlias bool,
|
|
) (_ []byte, err error) {
|
|
if q.table == nil {
|
|
err := fmt.Errorf("bun: got %T, but WherePK requires struct or slice-based model", q.model)
|
|
return nil, err
|
|
}
|
|
|
|
switch model := q.tableModel.(type) {
|
|
case *structTableModel:
|
|
return q.appendWhereStructFields(fmter, b, model, fields, withAlias)
|
|
case *sliceTableModel:
|
|
return q.appendWhereSliceFields(fmter, b, model, fields, withAlias)
|
|
default:
|
|
return nil, fmt.Errorf("bun: WhereColumn does not support %T", q.tableModel)
|
|
}
|
|
}
|
|
|
|
func (q *whereBaseQuery) appendWhereStructFields(
|
|
fmter schema.Formatter,
|
|
b []byte,
|
|
model *structTableModel,
|
|
fields []*schema.Field,
|
|
withAlias bool,
|
|
) (_ []byte, err error) {
|
|
if !model.strct.IsValid() {
|
|
return nil, errNilModel
|
|
}
|
|
|
|
isTemplate := fmter.IsNop()
|
|
b = append(b, '(')
|
|
for i, f := range fields {
|
|
if i > 0 {
|
|
b = append(b, " AND "...)
|
|
}
|
|
if withAlias {
|
|
b = append(b, q.table.SQLAlias...)
|
|
b = append(b, '.')
|
|
}
|
|
b = append(b, f.SQLName...)
|
|
b = append(b, " = "...)
|
|
if isTemplate {
|
|
b = append(b, '?')
|
|
} else {
|
|
b = f.AppendValue(fmter, b, model.strct)
|
|
}
|
|
}
|
|
b = append(b, ')')
|
|
return b, nil
|
|
}
|
|
|
|
// appendWhereSliceFields renders "<cols> IN (<tuples>)" for a slice model,
// with one value tuple per slice element taken from fields. Multi-column
// keys use row syntax on both sides, e.g. "(a, b) IN ((1, 2), (3, 4))".
// Columns are alias-qualified when withAlias is set.
//
// With a no-op formatter (template mode), only the first element is
// rendered, with "?" placeholders instead of values.
//
// NOTE(review): an empty slice renders "IN ()", which most SQL dialects
// reject — confirm callers guard against empty slices.
func (q *whereBaseQuery) appendWhereSliceFields(
	fmter schema.Formatter,
	b []byte,
	model *sliceTableModel,
	fields []*schema.Field,
	withAlias bool,
) (_ []byte, err error) {
	if len(fields) > 1 {
		b = append(b, '(')
	}
	if withAlias {
		b = appendColumns(b, q.table.SQLAlias, fields)
	} else {
		b = appendColumns(b, "", fields)
	}
	if len(fields) > 1 {
		b = append(b, ')')
	}

	b = append(b, " IN ("...)

	isTemplate := fmter.IsNop()
	slice := model.slice
	sliceLen := slice.Len()
	for i := 0; i < sliceLen; i++ {
		if i > 0 {
			// Templates render a single tuple of placeholders only.
			if isTemplate {
				break
			}
			b = append(b, ", "...)
		}

		el := indirect(slice.Index(i))

		if len(fields) > 1 {
			b = append(b, '(')
		}
		// This loop's i shadows the element index of the outer loop.
		for i, f := range fields {
			if i > 0 {
				b = append(b, ", "...)
			}
			if isTemplate {
				b = append(b, '?')
			} else {
				b = f.AppendValue(fmter, b, el)
			}
		}
		if len(fields) > 1 {
			b = append(b, ')')
		}
	}

	b = append(b, ')')

	return b, nil
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// returningQuery collects the columns of a RETURNING list. Explicit
// expressions (returning) take precedence. When there are none, the
// collected model fields (returningFields) are rendered, optionally prefixed
// with a table name (appendOutput uses "INSERTED").
type returningQuery struct {
	returning       []schema.QueryWithArgs
	returningFields []*schema.Field
}
|
|
|
|
func (q *returningQuery) addReturning(ret schema.QueryWithArgs) {
|
|
q.returning = append(q.returning, ret)
|
|
}
|
|
|
|
func (q *returningQuery) addReturningField(field *schema.Field) {
|
|
if len(q.returning) > 0 {
|
|
return
|
|
}
|
|
for _, f := range q.returningFields {
|
|
if f == field {
|
|
return
|
|
}
|
|
}
|
|
q.returningFields = append(q.returningFields, field)
|
|
}
|
|
|
|
func (q *returningQuery) appendReturning(
|
|
fmter schema.Formatter, b []byte,
|
|
) (_ []byte, err error) {
|
|
return q._appendReturning(fmter, b, "")
|
|
}
|
|
|
|
func (q *returningQuery) appendOutput(
|
|
fmter schema.Formatter, b []byte,
|
|
) (_ []byte, err error) {
|
|
return q._appendReturning(fmter, b, "INSERTED")
|
|
}
|
|
|
|
func (q *returningQuery) _appendReturning(
|
|
fmter schema.Formatter, b []byte, table string,
|
|
) (_ []byte, err error) {
|
|
for i, f := range q.returning {
|
|
if i > 0 {
|
|
b = append(b, ", "...)
|
|
}
|
|
b, err = f.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if len(q.returning) > 0 {
|
|
return b, nil
|
|
}
|
|
|
|
b = appendColumns(b, schema.Safe(table), q.returningFields)
|
|
return b, nil
|
|
}
|
|
|
|
func (q *returningQuery) hasReturning() bool {
|
|
if len(q.returning) == 1 {
|
|
if ret := q.returning[0]; len(ret.Args) == 0 {
|
|
switch ret.Query {
|
|
case "", "null", "NULL":
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return len(q.returning) > 0 || len(q.returningFields) > 0
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// columnValue is a custom value for a column that addValue did not find among
// the model table's fields.
type columnValue struct {
	column string
	value  schema.QueryWithArgs
}
|
|
|
|
// customValueQuery stores user-supplied column values. addValue splits them
// into values for model fields (modelValues, keyed by column name, with the
// last write winning) and values for other columns (extraValues, in
// insertion order).
type customValueQuery struct {
	modelValues map[string]schema.QueryWithArgs
	extraValues []columnValue
}
|
|
|
|
func (q *customValueQuery) addValue(
|
|
table *schema.Table, column string, value string, args []interface{},
|
|
) {
|
|
ok := false
|
|
if table != nil {
|
|
_, ok = table.FieldMap[column]
|
|
}
|
|
|
|
if ok {
|
|
if q.modelValues == nil {
|
|
q.modelValues = make(map[string]schema.QueryWithArgs)
|
|
}
|
|
q.modelValues[column] = schema.SafeQuery(value, args)
|
|
} else {
|
|
q.extraValues = append(q.extraValues, columnValue{
|
|
column: column,
|
|
value: schema.SafeQuery(value, args),
|
|
})
|
|
}
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// setQuery accumulates SET expressions. appendSet renders them
// comma-separated.
type setQuery struct {
	set []schema.QueryWithArgs
}
|
|
|
|
func (q *setQuery) addSet(set schema.QueryWithArgs) {
|
|
q.set = append(q.set, set)
|
|
}
|
|
|
|
func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) {
|
|
for i, f := range q.set {
|
|
if i > 0 {
|
|
b = append(b, ", "...)
|
|
}
|
|
b, err = f.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// cascadeQuery records the CASCADE / RESTRICT options. appendCascade emits
// them only for dialects with feature.TableCascade.
type cascadeQuery struct {
	cascade  bool
	restrict bool
}
|
|
|
|
func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte {
|
|
if !fmter.HasFeature(feature.TableCascade) {
|
|
return b
|
|
}
|
|
if q.cascade {
|
|
b = append(b, " CASCADE"...)
|
|
}
|
|
if q.restrict {
|
|
b = append(b, " RESTRICT"...)
|
|
}
|
|
return b
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// idxHintsQuery holds the USE, IGNORE and FORCE INDEX hints. Each set is
// allocated lazily (lazyUse, lazyIgnore, lazyForce) and stays nil when
// unused.
type idxHintsQuery struct {
	use    *indexHints
	ignore *indexHints
	force  *indexHints
}
|
|
|
|
// indexHints lists index names for one hint kind, per scope: plain,
// FOR JOIN, FOR ORDER BY and FOR GROUP BY.
type indexHints struct {
	names      []schema.QueryWithArgs
	forJoin    []schema.QueryWithArgs
	forOrderBy []schema.QueryWithArgs
	forGroupBy []schema.QueryWithArgs
}
|
|
|
|
func (ih *idxHintsQuery) lazyUse() *indexHints {
|
|
if ih.use == nil {
|
|
ih.use = new(indexHints)
|
|
}
|
|
return ih.use
|
|
}
|
|
|
|
func (ih *idxHintsQuery) lazyIgnore() *indexHints {
|
|
if ih.ignore == nil {
|
|
ih.ignore = new(indexHints)
|
|
}
|
|
return ih.ignore
|
|
}
|
|
|
|
func (ih *idxHintsQuery) lazyForce() *indexHints {
|
|
if ih.force == nil {
|
|
ih.force = new(indexHints)
|
|
}
|
|
return ih.force
|
|
}
|
|
|
|
func (ih *idxHintsQuery) appendIndexes(hints []schema.QueryWithArgs, indexes ...string) []schema.QueryWithArgs {
|
|
for _, idx := range indexes {
|
|
hints = append(hints, schema.UnsafeIdent(idx))
|
|
}
|
|
return hints
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addUseIndex(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyUse().names = ih.appendIndexes(ih.use.names, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addUseIndexForJoin(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyUse().forJoin = ih.appendIndexes(ih.use.forJoin, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addUseIndexForOrderBy(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyUse().forOrderBy = ih.appendIndexes(ih.use.forOrderBy, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addUseIndexForGroupBy(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyUse().forGroupBy = ih.appendIndexes(ih.use.forGroupBy, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addIgnoreIndex(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyIgnore().names = ih.appendIndexes(ih.ignore.names, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addIgnoreIndexForJoin(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyIgnore().forJoin = ih.appendIndexes(ih.ignore.forJoin, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addIgnoreIndexForOrderBy(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyIgnore().forOrderBy = ih.appendIndexes(ih.ignore.forOrderBy, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addIgnoreIndexForGroupBy(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyIgnore().forGroupBy = ih.appendIndexes(ih.ignore.forGroupBy, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addForceIndex(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyForce().names = ih.appendIndexes(ih.force.names, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addForceIndexForJoin(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyForce().forJoin = ih.appendIndexes(ih.force.forJoin, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addForceIndexForOrderBy(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyForce().forOrderBy = ih.appendIndexes(ih.force.forOrderBy, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) addForceIndexForGroupBy(indexes ...string) {
|
|
if len(indexes) == 0 {
|
|
return
|
|
}
|
|
ih.lazyForce().forGroupBy = ih.appendIndexes(ih.force.forGroupBy, indexes...)
|
|
}
|
|
|
|
func (ih *idxHintsQuery) appendIndexHints(
|
|
fmter schema.Formatter, b []byte,
|
|
) ([]byte, error) {
|
|
type IdxHint struct {
|
|
Name string
|
|
Values []schema.QueryWithArgs
|
|
}
|
|
|
|
var hints []IdxHint
|
|
if ih.use != nil {
|
|
hints = append(hints, []IdxHint{
|
|
{
|
|
Name: "USE INDEX",
|
|
Values: ih.use.names,
|
|
},
|
|
{
|
|
Name: "USE INDEX FOR JOIN",
|
|
Values: ih.use.forJoin,
|
|
},
|
|
{
|
|
Name: "USE INDEX FOR ORDER BY",
|
|
Values: ih.use.forOrderBy,
|
|
},
|
|
{
|
|
Name: "USE INDEX FOR GROUP BY",
|
|
Values: ih.use.forGroupBy,
|
|
},
|
|
}...)
|
|
}
|
|
|
|
if ih.ignore != nil {
|
|
hints = append(hints, []IdxHint{
|
|
{
|
|
Name: "IGNORE INDEX",
|
|
Values: ih.ignore.names,
|
|
},
|
|
{
|
|
Name: "IGNORE INDEX FOR JOIN",
|
|
Values: ih.ignore.forJoin,
|
|
},
|
|
{
|
|
Name: "IGNORE INDEX FOR ORDER BY",
|
|
Values: ih.ignore.forOrderBy,
|
|
},
|
|
{
|
|
Name: "IGNORE INDEX FOR GROUP BY",
|
|
Values: ih.ignore.forGroupBy,
|
|
},
|
|
}...)
|
|
}
|
|
|
|
if ih.force != nil {
|
|
hints = append(hints, []IdxHint{
|
|
{
|
|
Name: "FORCE INDEX",
|
|
Values: ih.force.names,
|
|
},
|
|
{
|
|
Name: "FORCE INDEX FOR JOIN",
|
|
Values: ih.force.forJoin,
|
|
},
|
|
{
|
|
Name: "FORCE INDEX FOR ORDER BY",
|
|
Values: ih.force.forOrderBy,
|
|
},
|
|
{
|
|
Name: "FORCE INDEX FOR GROUP BY",
|
|
Values: ih.force.forGroupBy,
|
|
},
|
|
}...)
|
|
}
|
|
|
|
var err error
|
|
for _, h := range hints {
|
|
b, err = ih.bufIndexHint(h.Name, h.Values, fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func (ih *idxHintsQuery) bufIndexHint(
|
|
name string,
|
|
hints []schema.QueryWithArgs,
|
|
fmter schema.Formatter, b []byte,
|
|
) ([]byte, error) {
|
|
var err error
|
|
if len(hints) == 0 {
|
|
return b, nil
|
|
}
|
|
b = append(b, fmt.Sprintf(" %s (", name)...)
|
|
for i, f := range hints {
|
|
if i > 0 {
|
|
b = append(b, ", "...)
|
|
}
|
|
b, err = f.AppendQuery(fmter, b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
b = append(b, ")"...)
|
|
return b, nil
|
|
}
|