mirror of
https://github.com/superseriousbusiness/gotosocial
synced 2024-12-22 10:43:12 +00:00
323 lines
7.4 KiB
Go
323 lines
7.4 KiB
Go
|
package validator
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
)
|
||
|
|
||
|
// tagType classifies a parsed validation tag so the validator can tell
// structural/control tags apart from ordinary validation functions.
type tagType uint8

const (
	// typeDefault is the zero value: an ordinary validation tag.
	typeDefault = tagType(iota)
	// typeOmitEmpty marks the omitempty tag.
	typeOmitEmpty
	// typeIsDefault marks the isdefault tag.
	typeIsDefault
	// typeNoStructLevel marks the no-struct-level tag.
	typeNoStructLevel
	// typeStructOnly marks the struct-only tag.
	typeStructOnly
	// typeDive marks the dive tag.
	typeDive
	// typeOr marks a tag that is one member of an or-separated group.
	typeOr
	// typeKeys marks the keys tag that opens a map-key validation section.
	typeKeys
	// typeEndKeys marks the endkeys tag that closes a map-key validation section.
	typeEndKeys
)
|
||
|
|
||
|
const (
|
||
|
invalidValidation = "Invalid validation tag on field '%s'"
|
||
|
undefinedValidation = "Undefined validation function '%s' on field '%s'"
|
||
|
keysTagNotDefined = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
|
||
|
)
|
||
|
|
||
|
type structCache struct {
|
||
|
lock sync.Mutex
|
||
|
m atomic.Value // map[reflect.Type]*cStruct
|
||
|
}
|
||
|
|
||
|
func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
|
||
|
c, found = sc.m.Load().(map[reflect.Type]*cStruct)[key]
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (sc *structCache) Set(key reflect.Type, value *cStruct) {
|
||
|
m := sc.m.Load().(map[reflect.Type]*cStruct)
|
||
|
nm := make(map[reflect.Type]*cStruct, len(m)+1)
|
||
|
for k, v := range m {
|
||
|
nm[k] = v
|
||
|
}
|
||
|
nm[key] = value
|
||
|
sc.m.Store(nm)
|
||
|
}
|
||
|
|
||
|
type tagCache struct {
|
||
|
lock sync.Mutex
|
||
|
m atomic.Value // map[string]*cTag
|
||
|
}
|
||
|
|
||
|
func (tc *tagCache) Get(key string) (c *cTag, found bool) {
|
||
|
c, found = tc.m.Load().(map[string]*cTag)[key]
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (tc *tagCache) Set(key string, value *cTag) {
|
||
|
m := tc.m.Load().(map[string]*cTag)
|
||
|
nm := make(map[string]*cTag, len(m)+1)
|
||
|
for k, v := range m {
|
||
|
nm[k] = v
|
||
|
}
|
||
|
nm[key] = value
|
||
|
tc.m.Store(nm)
|
||
|
}
|
||
|
|
||
|
type cStruct struct {
|
||
|
name string
|
||
|
fields []*cField
|
||
|
fn StructLevelFuncCtx
|
||
|
}
|
||
|
|
||
|
type cField struct {
|
||
|
idx int
|
||
|
name string
|
||
|
altName string
|
||
|
namesEqual bool
|
||
|
cTags *cTag
|
||
|
}
|
||
|
|
||
|
type cTag struct {
|
||
|
tag string
|
||
|
aliasTag string
|
||
|
actualAliasTag string
|
||
|
param string
|
||
|
keys *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation
|
||
|
next *cTag
|
||
|
fn FuncCtx
|
||
|
typeof tagType
|
||
|
hasTag bool
|
||
|
hasAlias bool
|
||
|
hasParam bool // true if parameter used eg. eq= where the equal sign has been set
|
||
|
isBlockEnd bool // indicates the current tag represents the last validation in the block
|
||
|
runValidationWhenNil bool
|
||
|
}
|
||
|
|
||
|
func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
|
||
|
v.structCache.lock.Lock()
|
||
|
defer v.structCache.lock.Unlock() // leave as defer! because if inner panics, it will never get unlocked otherwise!
|
||
|
|
||
|
typ := current.Type()
|
||
|
|
||
|
// could have been multiple trying to access, but once first is done this ensures struct
|
||
|
// isn't parsed again.
|
||
|
cs, ok := v.structCache.Get(typ)
|
||
|
if ok {
|
||
|
return cs
|
||
|
}
|
||
|
|
||
|
cs = &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
|
||
|
|
||
|
numFields := current.NumField()
|
||
|
|
||
|
var ctag *cTag
|
||
|
var fld reflect.StructField
|
||
|
var tag string
|
||
|
var customName string
|
||
|
|
||
|
for i := 0; i < numFields; i++ {
|
||
|
|
||
|
fld = typ.Field(i)
|
||
|
|
||
|
if !fld.Anonymous && len(fld.PkgPath) > 0 {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
tag = fld.Tag.Get(v.tagName)
|
||
|
|
||
|
if tag == skipValidationTag {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
customName = fld.Name
|
||
|
|
||
|
if v.hasTagNameFunc {
|
||
|
name := v.tagNameFunc(fld)
|
||
|
if len(name) > 0 {
|
||
|
customName = name
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NOTE: cannot use shared tag cache, because tags may be equal, but things like alias may be different
|
||
|
// and so only struct level caching can be used instead of combined with Field tag caching
|
||
|
|
||
|
if len(tag) > 0 {
|
||
|
ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
|
||
|
} else {
|
||
|
// even if field doesn't have validations need cTag for traversing to potential inner/nested
|
||
|
// elements of the field.
|
||
|
ctag = new(cTag)
|
||
|
}
|
||
|
|
||
|
cs.fields = append(cs.fields, &cField{
|
||
|
idx: i,
|
||
|
name: fld.Name,
|
||
|
altName: customName,
|
||
|
cTags: ctag,
|
||
|
namesEqual: fld.Name == customName,
|
||
|
})
|
||
|
}
|
||
|
v.structCache.Set(typ, cs)
|
||
|
return cs
|
||
|
}
|
||
|
|
||
|
func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
|
||
|
var t string
|
||
|
noAlias := len(alias) == 0
|
||
|
tags := strings.Split(tag, tagSeparator)
|
||
|
|
||
|
for i := 0; i < len(tags); i++ {
|
||
|
t = tags[i]
|
||
|
if noAlias {
|
||
|
alias = t
|
||
|
}
|
||
|
|
||
|
// check map for alias and process new tags, otherwise process as usual
|
||
|
if tagsVal, found := v.aliases[t]; found {
|
||
|
if i == 0 {
|
||
|
firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
|
||
|
} else {
|
||
|
next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
|
||
|
current.next, current = next, curr
|
||
|
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
var prevTag tagType
|
||
|
|
||
|
if i == 0 {
|
||
|
current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true, typeof: typeDefault}
|
||
|
firstCtag = current
|
||
|
} else {
|
||
|
prevTag = current.typeof
|
||
|
current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
|
||
|
current = current.next
|
||
|
}
|
||
|
|
||
|
switch t {
|
||
|
case diveTag:
|
||
|
current.typeof = typeDive
|
||
|
continue
|
||
|
|
||
|
case keysTag:
|
||
|
current.typeof = typeKeys
|
||
|
|
||
|
if i == 0 || prevTag != typeDive {
|
||
|
panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
|
||
|
}
|
||
|
|
||
|
current.typeof = typeKeys
|
||
|
|
||
|
// need to pass along only keys tag
|
||
|
// need to increment i to skip over the keys tags
|
||
|
b := make([]byte, 0, 64)
|
||
|
|
||
|
i++
|
||
|
|
||
|
for ; i < len(tags); i++ {
|
||
|
|
||
|
b = append(b, tags[i]...)
|
||
|
b = append(b, ',')
|
||
|
|
||
|
if tags[i] == endKeysTag {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
|
||
|
continue
|
||
|
|
||
|
case endKeysTag:
|
||
|
current.typeof = typeEndKeys
|
||
|
|
||
|
// if there are more in tags then there was no keysTag defined
|
||
|
// and an error should be thrown
|
||
|
if i != len(tags)-1 {
|
||
|
panic(keysTagNotDefined)
|
||
|
}
|
||
|
return
|
||
|
|
||
|
case omitempty:
|
||
|
current.typeof = typeOmitEmpty
|
||
|
continue
|
||
|
|
||
|
case structOnlyTag:
|
||
|
current.typeof = typeStructOnly
|
||
|
continue
|
||
|
|
||
|
case noStructLevelTag:
|
||
|
current.typeof = typeNoStructLevel
|
||
|
continue
|
||
|
|
||
|
default:
|
||
|
if t == isdefault {
|
||
|
current.typeof = typeIsDefault
|
||
|
}
|
||
|
// if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
|
||
|
orVals := strings.Split(t, orSeparator)
|
||
|
|
||
|
for j := 0; j < len(orVals); j++ {
|
||
|
vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
|
||
|
if noAlias {
|
||
|
alias = vals[0]
|
||
|
current.aliasTag = alias
|
||
|
} else {
|
||
|
current.actualAliasTag = t
|
||
|
}
|
||
|
|
||
|
if j > 0 {
|
||
|
current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
|
||
|
current = current.next
|
||
|
}
|
||
|
current.hasParam = len(vals) > 1
|
||
|
|
||
|
current.tag = vals[0]
|
||
|
if len(current.tag) == 0 {
|
||
|
panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
|
||
|
}
|
||
|
|
||
|
if wrapper, ok := v.validations[current.tag]; ok {
|
||
|
current.fn = wrapper.fn
|
||
|
current.runValidationWhenNil = wrapper.runValidatinOnNil
|
||
|
} else {
|
||
|
panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
|
||
|
}
|
||
|
|
||
|
if len(orVals) > 1 {
|
||
|
current.typeof = typeOr
|
||
|
}
|
||
|
|
||
|
if len(vals) > 1 {
|
||
|
current.param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
|
||
|
}
|
||
|
}
|
||
|
current.isBlockEnd = true
|
||
|
}
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (v *Validate) fetchCacheTag(tag string) *cTag {
|
||
|
// find cached tag
|
||
|
ctag, found := v.tagCache.Get(tag)
|
||
|
if !found {
|
||
|
v.tagCache.lock.Lock()
|
||
|
defer v.tagCache.lock.Unlock()
|
||
|
|
||
|
// could have been multiple trying to access, but once first is done this ensures tag
|
||
|
// isn't parsed again.
|
||
|
ctag, found = v.tagCache.Get(tag)
|
||
|
if !found {
|
||
|
ctag, _ = v.parseFieldTagsRecursive(tag, "", "", false)
|
||
|
v.tagCache.Set(tag, ctag)
|
||
|
}
|
||
|
}
|
||
|
return ctag
|
||
|
}
|