gotosocial/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go
tobi f2e5bedea6
migrate go version to 1.17 (#203)
* migrate go version to 1.17

* update contributing
2021-09-10 14:42:14 +02:00

303 lines
6.4 KiB
Go

package pgdialect
import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"time"
"unicode/utf8"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/schema"
)
var (
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
stringType = reflect.TypeOf((*string)(nil)).Elem()
sliceStringType = reflect.TypeOf([]string(nil))
intType = reflect.TypeOf((*int)(nil)).Elem()
sliceIntType = reflect.TypeOf([]int(nil))
int64Type = reflect.TypeOf((*int64)(nil)).Elem()
sliceInt64Type = reflect.TypeOf([]int64(nil))
float64Type = reflect.TypeOf((*float64)(nil)).Elem()
sliceFloat64Type = reflect.TypeOf([]float64(nil))
)
func customAppender(typ reflect.Type) schema.AppenderFunc {
switch typ.Kind() {
case reflect.Uint32:
return appendUint32ValueAsInt
case reflect.Uint, reflect.Uint64:
return appendUint64ValueAsInt
}
return nil
}
func appendUint32ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, int64(int32(v.Uint())), 10)
}
func appendUint64ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, int64(v.Uint()), 10)
}
//------------------------------------------------------------------------------
func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
switch v := v.(type) {
case int64:
return strconv.AppendInt(b, v, 10)
case float64:
return dialect.AppendFloat64(b, v)
case bool:
return dialect.AppendBool(b, v)
case []byte:
return dialect.AppendBytes(b, v)
case string:
return arrayAppendString(b, v)
case time.Time:
return dialect.AppendTime(b, v)
default:
err := fmt.Errorf("pgdialect: can't append %T", v)
return dialect.AppendError(b, err)
}
}
func arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Kind() == reflect.String {
return arrayAppendStringValue
}
if typ.Implements(driverValuerType) {
return arrayAppendDriverValue
}
return schema.Appender(typ, customAppender)
}
func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendString(b, v.String())
}
func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
iface, err := v.Interface().(driver.Valuer).Value()
if err != nil {
return dialect.AppendError(b, err)
}
return arrayAppend(fmter, b, iface)
}
//------------------------------------------------------------------------------
func arrayAppender(typ reflect.Type) schema.AppenderFunc {
kind := typ.Kind()
switch kind {
case reflect.Ptr:
if fn := arrayAppender(typ.Elem()); fn != nil {
return schema.PtrAppender(fn)
}
case reflect.Slice, reflect.Array:
// ok:
default:
return nil
}
elemType := typ.Elem()
if kind == reflect.Slice {
switch elemType {
case stringType:
return appendStringSliceValue
case intType:
return appendIntSliceValue
case int64Type:
return appendInt64SliceValue
case float64Type:
return appendFloat64SliceValue
}
}
appendElem := arrayElemAppender(elemType)
if appendElem == nil {
panic(fmt.Errorf("pgdialect: %s is not supported", typ))
}
return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
kind := v.Kind()
switch kind {
case reflect.Ptr, reflect.Slice:
if v.IsNil() {
return dialect.AppendNull(b)
}
}
if kind == reflect.Ptr {
v = v.Elem()
}
b = append(b, '\'')
b = append(b, '{')
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
b = appendElem(fmter, b, elem)
b = append(b, ',')
}
if v.Len() > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
}
func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ss := v.Convert(sliceStringType).Interface().([]string)
return appendStringSlice(b, ss)
}
func appendStringSlice(b []byte, ss []string) []byte {
if ss == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, s := range ss {
b = arrayAppendString(b, s)
b = append(b, ',')
}
if len(ss) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ints := v.Convert(sliceIntType).Interface().([]int)
return appendIntSlice(b, ints)
}
func appendIntSlice(b []byte, ints []int) []byte {
if ints == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range ints {
b = strconv.AppendInt(b, int64(n), 10)
b = append(b, ',')
}
if len(ints) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ints := v.Convert(sliceInt64Type).Interface().([]int64)
return appendInt64Slice(b, ints)
}
func appendInt64Slice(b []byte, ints []int64) []byte {
if ints == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range ints {
b = strconv.AppendInt(b, n, 10)
b = append(b, ',')
}
if len(ints) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
floats := v.Convert(sliceFloat64Type).Interface().([]float64)
return appendFloat64Slice(b, floats)
}
func appendFloat64Slice(b []byte, floats []float64) []byte {
if floats == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range floats {
b = dialect.AppendFloat64(b, n)
b = append(b, ',')
}
if len(floats) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
//------------------------------------------------------------------------------
func arrayAppendString(b []byte, s string) []byte {
b = append(b, '"')
for _, r := range s {
switch r {
case 0:
// ignore
case '\'':
b = append(b, "'''"...)
case '"':
b = append(b, '\\', '"')
case '\\':
b = append(b, '\\', '\\')
default:
if r < utf8.RuneSelf {
b = append(b, byte(r))
break
}
l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
}
b = append(b, '"')
return b
}