mirror of
https://github.com/superseriousbusiness/gotosocial
synced 2024-11-10 06:54:16 +00:00
Api/v1/accounts (#8)
* start work on accounts module * plodding away on the accounts endpoint * groundwork for other account routes * add password validator * validation utils * require account approval flags * comments * comments * go fmt * comments * add distributor stub * rename api to federator * tidy a bit * validate new account requests * rename r router * comments * add domain blocks * add some more shortcuts * add some more shortcuts * check email + username availability * email block checking for signups * chunking away at it * tick off a few more things * some fiddling with tests * add mock package * relocate repo * move mocks around * set app id on new signups * initialize oauth server properly * rename oauth server * proper mocking tests * go fmt ./... * add required fields * change name of func * move validation to account.go * more tests! * add some file utility tools * add mediaconfig * new shortcut * add some more fields * add followrequest model * add notify * update mastotypes * mock out storage interface * start building media interface * start on update credentials * mess about with media a bit more * test image manipulation * media more or less working * account update nearly working * rearranging my package ;) ;) ;) * phew big stuff!!!! * fix type checking * *fiddles* * Add CreateTables func * account registration flow working * tidy * script to step through auth flow * add a lil helper for generating user uris * fiddling with federation a bit * update progress * Tidying and linting
This commit is contained in:
parent
aa9ce272dc
commit
71a49e2b43
94 changed files with 6585 additions and 955 deletions
10
PROGRESS.md
10
PROGRESS.md
|
@ -11,10 +11,10 @@
|
|||
* [x] /auth/sign_in GET (Show form for user signin)
|
||||
* [x] /auth/sign_in POST (Validate username and password and sign user in)
|
||||
* [ ] Accounts
|
||||
* [ ] /api/v1/accounts POST (Register a new account)
|
||||
* [ ] /api/v1/accounts/verify_credentials GET (Verify account credentials with a user token)
|
||||
* [ ] /api/v1/accounts/update_credentials PATCH (Update user's display name/preferences)
|
||||
* [ ] /api/v1/accounts/:id GET (Get account information)
|
||||
* [x] /api/v1/accounts POST (Register a new account)
|
||||
* [x] /api/v1/accounts/verify_credentials GET (Verify account credentials with a user token)
|
||||
* [x] /api/v1/accounts/update_credentials PATCH (Update user's display name/preferences)
|
||||
* [x] /api/v1/accounts/:id GET (Get account information)
|
||||
* [ ] /api/v1/accounts/:id/statuses GET (Get an account's statuses)
|
||||
* [ ] /api/v1/accounts/:id/followers GET (Get an account's followers)
|
||||
* [ ] /api/v1/accounts/:id/following GET (Get an account's following)
|
||||
|
@ -184,7 +184,7 @@
|
|||
* [ ] Cache
|
||||
* [ ] In-memory cache
|
||||
* [ ] Security features
|
||||
* [ ] Authorization middleware
|
||||
* [x] Authorization middleware
|
||||
* [ ] Rate limiting middleware
|
||||
* [ ] Scope middleware
|
||||
* [ ] Permissions/acl middleware for admins+moderators
|
||||
|
|
|
@ -22,12 +22,12 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/gotosocial/gotosocial/internal/action"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/gotosocial"
|
||||
"github.com/gotosocial/gotosocial/internal/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/action"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
@ -111,10 +111,70 @@ func main() {
|
|||
// TEMPLATE FLAGS
|
||||
&cli.StringFlag{
|
||||
Name: flagNames.TemplateBaseDir,
|
||||
Usage: "Basedir for html templating files for rendering pages and composing emails",
|
||||
Usage: "Basedir for html templating files for rendering pages and composing emails.",
|
||||
Value: "./web/template/",
|
||||
EnvVars: []string{envNames.TemplateBaseDir},
|
||||
},
|
||||
|
||||
// ACCOUNTS FLAGS
|
||||
&cli.BoolFlag{
|
||||
Name: flagNames.AccountsOpenRegistration,
|
||||
Usage: "Allow anyone to submit an account signup request. If false, server will be invite-only.",
|
||||
Value: true,
|
||||
EnvVars: []string{envNames.AccountsOpenRegistration},
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: flagNames.AccountsRequireApproval,
|
||||
Usage: "Do account signups require approval by an admin or moderator before user can log in? If false, new registrations will be automatically approved.",
|
||||
Value: true,
|
||||
EnvVars: []string{envNames.AccountsRequireApproval},
|
||||
},
|
||||
|
||||
// MEDIA FLAGS
|
||||
&cli.IntFlag{
|
||||
Name: flagNames.MediaMaxImageSize,
|
||||
Usage: "Max size of accepted images in bytes",
|
||||
Value: 1048576, // 1mb
|
||||
EnvVars: []string{envNames.MediaMaxImageSize},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: flagNames.MediaMaxVideoSize,
|
||||
Usage: "Max size of accepted videos in bytes",
|
||||
Value: 5242880, // 5mb
|
||||
EnvVars: []string{envNames.MediaMaxVideoSize},
|
||||
},
|
||||
|
||||
// STORAGE FLAGS
|
||||
&cli.StringFlag{
|
||||
Name: flagNames.StorageBackend,
|
||||
Usage: "Storage backend to use for media attachments",
|
||||
Value: "local",
|
||||
EnvVars: []string{envNames.StorageBackend},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: flagNames.StorageBasePath,
|
||||
Usage: "Full path to an already-created directory where gts should store/retrieve media files",
|
||||
Value: "/opt/gotosocial",
|
||||
EnvVars: []string{envNames.StorageBasePath},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: flagNames.StorageServeProtocol,
|
||||
Usage: "Protocol to use for serving media attachments (use https if storage is local)",
|
||||
Value: "https",
|
||||
EnvVars: []string{envNames.StorageServeProtocol},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: flagNames.StorageServeHost,
|
||||
Usage: "Hostname to serve media attachments from (use the same value as host if storage is local)",
|
||||
Value: "localhost",
|
||||
EnvVars: []string{envNames.StorageServeHost},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: flagNames.StorageServeBasePath,
|
||||
Usage: "Path to append to protocol and hostname to create the base path from which media files will be served (default will mostly be fine)",
|
||||
Value: "/fileserver/media",
|
||||
EnvVars: []string{envNames.StorageServeBasePath},
|
||||
},
|
||||
},
|
||||
Commands: []*cli.Command{
|
||||
{
|
||||
|
|
|
@ -14,10 +14,9 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
###################
|
||||
##### CONFIG ######
|
||||
###################
|
||||
|
||||
###########################
|
||||
##### GENERAL CONFIG ######
|
||||
###########################
|
||||
# String. Log level to use throughout the application. Must be lower-case.
|
||||
# Options: ["debug","info","warn","error","fatal"]
|
||||
# Default: "info"
|
||||
|
@ -39,6 +38,9 @@ host: "localhost"
|
|||
# Default: "https"
|
||||
protocol: "https"
|
||||
|
||||
############################
|
||||
##### DATABASE CONFIG ######
|
||||
############################
|
||||
# Config pertaining to the Gotosocial database connection
|
||||
db:
|
||||
# String. Database type.
|
||||
|
@ -72,9 +74,26 @@ db:
|
|||
# Default: "postgres"
|
||||
database: "postgres"
|
||||
|
||||
###############################
|
||||
##### WEB TEMPLATE CONFIG #####
|
||||
###############################
|
||||
# Config pertaining to templating of web pages/email notifications and the like
|
||||
template:
|
||||
# String. Directory from which gotosocial will attempt to load html templates (.tmpl files).
|
||||
# Examples: ["/some/absolute/path/", "./relative/path/", "../../some/weird/path/"]
|
||||
# Default: "./web/template/"
|
||||
baseDir: "./web/template/"
|
||||
|
||||
###########################
|
||||
##### ACCOUNTS CONFIG #####
|
||||
###########################
|
||||
# Config pertaining to creation and maintenance of accounts on the server, as well as defaults for new accounts.
|
||||
accounts:
|
||||
# Bool. Do we want people to be able to just submit sign up requests, or do we want invite only?
|
||||
# Options: [true, false]
|
||||
# Default: true
|
||||
openRegistration: true
|
||||
# Bool. Do sign up requests require approval from an admin/moderator before an account can sign in/use the server?
|
||||
# Options: [true, false]
|
||||
# Default: true
|
||||
requireApproval: true
|
||||
|
|
13
go.mod
13
go.mod
|
@ -1,8 +1,10 @@
|
|||
module github.com/gotosocial/gotosocial
|
||||
module github.com/superseriousbusiness/gotosocial
|
||||
|
||||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/buckket/go-blurhash v1.1.0
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
|
||||
github.com/gin-contrib/sessions v0.0.3
|
||||
github.com/gin-gonic/gin v1.6.3
|
||||
github.com/go-fed/activity v1.0.0
|
||||
|
@ -10,16 +12,23 @@ require (
|
|||
github.com/go-pg/pg/v10 v10.8.0
|
||||
github.com/golang/mock v1.4.4 // indirect
|
||||
github.com/google/uuid v1.2.0
|
||||
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88
|
||||
github.com/h2non/filetype v1.1.1
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.1 // indirect
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/onsi/ginkgo v1.15.0 // indirect
|
||||
github.com/onsi/gomega v1.10.5 // indirect
|
||||
github.com/sirupsen/logrus v1.8.0
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203
|
||||
github.com/superseriousbusiness/oauth2/v4 v4.2.1-0.20210327102222-902aba1ef45f
|
||||
github.com/tidwall/btree v0.4.2 // indirect
|
||||
github.com/tidwall/buntdb v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.1.0 // indirect
|
||||
github.com/urfave/cli/v2 v2.3.0
|
||||
github.com/wagslane/go-password-validator v0.3.0
|
||||
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b
|
||||
golang.org/x/text v0.3.3
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
gopkg.in/yaml.v2 v2.3.0
|
||||
)
|
||||
|
|
57
go.sum
57
go.sum
|
@ -7,16 +7,37 @@ github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu
|
|||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
|
||||
github.com/bradfitz/gomemcache v0.0.0-20190329173943-551aad21a668/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA=
|
||||
github.com/bradleypeabody/gorilla-sessions-memcache v0.0.0-20181103040241-659414f458e1/go.mod h1:dkChI7Tbtx7H1Tj7TqGSZMOeGpMP5gLHtjroHd4agiI=
|
||||
github.com/buckket/go-blurhash v1.1.0 h1:X5M6r0LIvwdvKiUtiNcRL2YlmOfMzYobI3VCKCZc9Do=
|
||||
github.com/buckket/go-blurhash v1.1.0/go.mod h1:aT2iqo5W9vu9GpyoLErKfTHwgODsZp3bQfXjXJUxNb8=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||
github.com/dave/jennifer v1.3.0/go.mod h1:fIb+770HOpJ2fmN9EPPKOqm1vMGhB+TwXKMZhrIygKg=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/dsoprea/go-exif v0.0.0-20210131231135-d154f10435cc h1:AuzYp98IFVOi0NU/WcZyGDQ6vAh/zkCjxGD3kt8aLzA=
|
||||
github.com/dsoprea/go-exif v0.0.0-20210131231135-d154f10435cc/go.mod h1:lOaOt7+UEppOgyvRy749v3do836U/hw0YVJNjoyPaEs=
|
||||
github.com/dsoprea/go-exif/v2 v2.0.0-20200321225314-640175a69fe4/go.mod h1:Lm2lMM2zx8p4a34ZemkaUV95AnMl4ZvLbCUbwOvLC2E=
|
||||
github.com/dsoprea/go-exif/v2 v2.0.0-20200604193436-ca8584a0e1c4 h1:Mg7pY7kxDQD2Bkvr1N+XW4BESSIQ7tTTR7Vv+Gi2CsM=
|
||||
github.com/dsoprea/go-exif/v2 v2.0.0-20200604193436-ca8584a0e1c4/go.mod h1:9EXlPeHfblFFnwu5UOqmP2eoZfJyAZ2Ri/Vki33ajO0=
|
||||
github.com/dsoprea/go-iptc v0.0.0-20200609062250-162ae6b44feb h1:gwjJjUr6FY7zAWVEueFPrcRHhd9+IK81TcItbqw2du4=
|
||||
github.com/dsoprea/go-iptc v0.0.0-20200609062250-162ae6b44feb/go.mod h1:kYIdx9N9NaOyD7U6D+YtExN7QhRm+5kq7//yOsRXQtM=
|
||||
github.com/dsoprea/go-jpeg-image-structure v0.0.0-20210128210355-86b1014917f2 h1:ULCSN6v0WISNbALxomGPXh4dSjRKPW+7+seYoMz8UTc=
|
||||
github.com/dsoprea/go-jpeg-image-structure v0.0.0-20210128210355-86b1014917f2/go.mod h1:ZoOP3yUG0HD1T4IUjIFsz/2OAB2yB4YX6NSm4K+uJRg=
|
||||
github.com/dsoprea/go-logging v0.0.0-20190624164917-c4f10aab7696/go.mod h1:Nm/x2ZUNRW6Fe5C3LxdY1PyZY5wmDv/s5dkPJ/VB3iA=
|
||||
github.com/dsoprea/go-logging v0.0.0-20200517223158-a10564966e9d h1:F/7L5wr/fP/SKeO5HuMlNEX9Ipyx2MbH2rV9G4zJRpk=
|
||||
github.com/dsoprea/go-logging v0.0.0-20200517223158-a10564966e9d/go.mod h1:7I+3Pe2o/YSU88W0hWlm9S22W7XI1JFNJ86U0zPKMf8=
|
||||
github.com/dsoprea/go-photoshop-info-format v0.0.0-20200609050348-3db9b63b202c h1:7j5aWACOzROpr+dvMtu8GnI97g9ShLWD72XIELMgn+c=
|
||||
github.com/dsoprea/go-photoshop-info-format v0.0.0-20200609050348-3db9b63b202c/go.mod h1:pqKB+ijp27cEcrHxhXVgUUMlSDRuGJJp1E+20Lj5H0E=
|
||||
github.com/dsoprea/go-png-image-structure v0.0.0-20200807080309-a98d4e94ac82 h1:RdwKOEEe2ND/JmoKh6I/EQlR9idKJTDOMffPFK6vN2M=
|
||||
github.com/dsoprea/go-png-image-structure v0.0.0-20200807080309-a98d4e94ac82/go.mod h1:aDYQkL/5gfRNZkoxiLTSWU4Y8/gV/4MVsy/MU9uwTak=
|
||||
github.com/dsoprea/go-utility v0.0.0-20200512094054-1abbbc781176 h1:CfXezFYb2STGOd1+n1HshvE191zVx+QX3A1nML5xxME=
|
||||
github.com/dsoprea/go-utility v0.0.0-20200512094054-1abbbc781176/go.mod h1:95+K3z2L0mqsVYd6yveIv1lmtT3tcQQ3dVakPySffW8=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8=
|
||||
|
@ -35,6 +56,9 @@ github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmC
|
|||
github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14=
|
||||
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
|
||||
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
|
||||
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
||||
github.com/go-errors/errors v1.0.2 h1:xMxH9j2fNg/L4hLn/4y3M0IUsn0M6Wbu/Uh9QlOfBh4=
|
||||
github.com/go-errors/errors v1.0.2/go.mod h1:psDX2osz5VnTOnFWbDeWwS7yejl+uV3FEWEp4lssFEs=
|
||||
github.com/go-fed/activity v1.0.0 h1:j7w3auHZnVCjUcgA1mE+UqSOjFBhvW2Z2res3vNol+o=
|
||||
github.com/go-fed/activity v1.0.0/go.mod h1:v4QoPaAzjWZ8zN2VFVGL5ep9C02mst0hQYHUpQwso4Q=
|
||||
github.com/go-fed/httpsig v0.1.1-0.20190914113940-c2de3672e5b5 h1:WLvFZqoXnuVTBKA6U/1FnEHNQ0Rq0QM0rGhY8Tx6R1g=
|
||||
|
@ -58,6 +82,11 @@ github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1
|
|||
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
|
||||
github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg=
|
||||
github.com/go-test/deep v1.0.1/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
|
||||
github.com/go-xmlfmt/xmlfmt v0.0.0-20191208150333-d5b6f63a941b h1:khEcpUM4yFcxg4/FHQWkvVRmgijNXRfzkIDHh23ggEo=
|
||||
github.com/go-xmlfmt/xmlfmt v0.0.0-20191208150333-d5b6f63a941b/go.mod h1:aUCEOzzezBEjDBbFBoSiya/gduyIiWYRP6CnSFIV8AM=
|
||||
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
|
||||
github.com/golang/geo v0.0.0-20200319012246-673a6f80352d h1:C/hKUcHT483btRbeGkrRjJz+Zbcj8audldIi9tRJDCc=
|
||||
github.com/golang/geo v0.0.0-20200319012246-673a6f80352d/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
|
@ -103,11 +132,12 @@ github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9R
|
|||
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
|
||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88 h1:YJ//HmHOYJ4srm/LA6VPNjNisneMbY6TTM1xttV/ZQU=
|
||||
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8=
|
||||
github.com/h2non/filetype v1.1.1 h1:xvOwnXKAckvtLWsN398qS9QhlxlnVXBjXBydK2/UFB4=
|
||||
github.com/h2non/filetype v1.1.1/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
|
||||
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
|
@ -136,12 +166,16 @@ github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2y
|
|||
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/memcachier/mc v2.0.1+incompatible/go.mod h1:7bkvFE61leUBvXz+yxsOnGBQSZpBSPIMUQSmmSHvuXc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=
|
||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
|
||||
github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
|
@ -174,6 +208,7 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykE
|
|||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
|
||||
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
|
||||
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
|
||||
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
|
@ -182,6 +217,10 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
|
|||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203 h1:1SWXcTphBQjYGWRRxLFIAR1LVtQEj4eR7xPtyeOVM/c=
|
||||
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203/go.mod h1:0Xw5cYMOYpgaWs+OOSx41ugycl2qvKTi9tlMMcZhFyY=
|
||||
github.com/superseriousbusiness/oauth2/v4 v4.2.1-0.20210327102222-902aba1ef45f h1:0YcjA/ieDuDFHJPg5w2hk3r5kIWNvEyl7GsoArxdI3s=
|
||||
github.com/superseriousbusiness/oauth2/v4 v4.2.1-0.20210327102222-902aba1ef45f/go.mod h1:8p0a/BEN9hhsGzE3tPaFFlIZgxAaLyLN5KY0bPg9ZBc=
|
||||
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
|
||||
github.com/tidwall/btree v0.3.0/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
|
||||
github.com/tidwall/btree v0.4.2 h1:aLwwJlG+InuFzdAPuBf9YCAR1LvSQ9zhC5aorFPlIPs=
|
||||
|
@ -235,6 +274,8 @@ github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vb
|
|||
github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I=
|
||||
github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
|
||||
|
@ -280,9 +321,14 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
|
|||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200320220750-118fecf932d8/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
|
@ -371,6 +417,7 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD
|
|||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
|
|
@ -21,8 +21,8 @@ package action
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
// GTSAction defines one *action* that can be taken by the gotosocial cli command.
|
||||
|
|
32
internal/action/mock_GTSAction.go
Normal file
32
internal/action/mock_GTSAction.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package action
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
config "github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockGTSAction is an autogenerated mock type for the GTSAction type
|
||||
type MockGTSAction struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Execute provides a mock function with given fields: _a0, _a1, _a2
|
||||
func (_m *MockGTSAction) Execute(_a0 context.Context, _a1 *config.Config, _a2 *logrus.Logger) error {
|
||||
ret := _m.Called(_a0, _a1, _a2)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *config.Config, *logrus.Logger) error); ok {
|
||||
r0 = rf(_a0, _a1, _a2)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
100
internal/apimodule/account/account.go
Normal file
100
internal/apimodule/account/account.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
)
|
||||
|
||||
const (
|
||||
idKey = "id"
|
||||
basePath = "/api/v1/accounts"
|
||||
basePathWithID = basePath + "/:" + idKey
|
||||
verifyPath = basePath + "/verify_credentials"
|
||||
updateCredentialsPath = basePath + "/update_credentials"
|
||||
)
|
||||
|
||||
type accountModule struct {
|
||||
config *config.Config
|
||||
db db.DB
|
||||
oauthServer oauth.Server
|
||||
mediaHandler media.MediaHandler
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
// New returns a new account module
|
||||
func New(config *config.Config, db db.DB, oauthServer oauth.Server, mediaHandler media.MediaHandler, log *logrus.Logger) apimodule.ClientAPIModule {
|
||||
return &accountModule{
|
||||
config: config,
|
||||
db: db,
|
||||
oauthServer: oauthServer,
|
||||
mediaHandler: mediaHandler,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Route attaches all routes from this module to the given router
|
||||
func (m *accountModule) Route(r router.Router) error {
|
||||
r.AttachHandler(http.MethodPost, basePath, m.accountCreatePOSTHandler)
|
||||
r.AttachHandler(http.MethodGet, basePathWithID, m.muxHandler)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *accountModule) CreateTables(db db.DB) error {
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Follow{},
|
||||
&model.FollowRequest{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
&model.EmailDomainBlock{},
|
||||
&model.MediaAttachment{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
if err := db.CreateTable(m); err != nil {
|
||||
return fmt.Errorf("error creating table: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *accountModule) muxHandler(c *gin.Context) {
|
||||
ru := c.Request.RequestURI
|
||||
if strings.HasPrefix(ru, verifyPath) {
|
||||
m.accountVerifyGETHandler(c)
|
||||
} else if strings.HasPrefix(ru, updateCredentialsPath) {
|
||||
m.accountUpdateCredentialsPATCHHandler(c)
|
||||
} else {
|
||||
m.accountGETHandler(c)
|
||||
}
|
||||
}
|
155
internal/apimodule/account/accountcreate.go
Normal file
155
internal/apimodule/account/accountcreate.go
Normal file
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
)
|
||||
|
||||
// accountCreatePOSTHandler handles create account requests, validates them,
|
||||
// and puts them in the database if they're valid.
|
||||
// It should be served as a POST at /api/v1/accounts
|
||||
func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "accountCreatePOSTHandler")
|
||||
authed, err := oauth.MustAuth(c, true, true, false, false)
|
||||
if err != nil {
|
||||
l.Debugf("couldn't auth: %s", err)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
l.Trace("parsing request form")
|
||||
form := &mastotypes.AccountCreateRequest{}
|
||||
if err := c.ShouldBind(form); err != nil || form == nil {
|
||||
l.Debugf("could not parse form from request: %s", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing one or more required form values"})
|
||||
return
|
||||
}
|
||||
|
||||
l.Tracef("validating form %+v", form)
|
||||
if err := validateCreateAccount(form, m.config.AccountsConfig, m.db); err != nil {
|
||||
l.Debugf("error validating form: %s", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
l.Tracef("attempting to parse client ip address %s", clientIP)
|
||||
signUpIP := net.ParseIP(clientIP)
|
||||
if signUpIP == nil {
|
||||
l.Debugf("error validating sign up ip address %s", clientIP)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "ip address could not be parsed from request"})
|
||||
return
|
||||
}
|
||||
|
||||
ti, err := m.accountCreate(form, signUpIP, authed.Token, authed.Application)
|
||||
if err != nil {
|
||||
l.Errorf("internal server error while creating new account: %s", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ti)
|
||||
}
|
||||
|
||||
// accountCreate does the dirty work of making an account and user in the database.
|
||||
// It then returns a token to the caller, for use with the new account, as per the
|
||||
// spec here: https://docs.joinmastodon.org/methods/accounts/
|
||||
func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, signUpIP net.IP, token oauth2.TokenInfo, app *model.Application) (*mastotypes.Token, error) {
|
||||
l := m.log.WithField("func", "accountCreate")
|
||||
|
||||
// don't store a reason if we don't require one
|
||||
reason := form.Reason
|
||||
if !m.config.AccountsConfig.ReasonRequired {
|
||||
reason = ""
|
||||
}
|
||||
|
||||
l.Trace("creating new username and account")
|
||||
user, err := m.db.NewSignup(form.Username, reason, m.config.AccountsConfig.RequireApproval, form.Email, form.Password, signUpIP, form.Locale, app.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new signup in the database: %s", err)
|
||||
}
|
||||
|
||||
l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID)
|
||||
accessToken, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err)
|
||||
}
|
||||
|
||||
return &mastotypes.Token{
|
||||
AccessToken: accessToken.GetAccess(),
|
||||
TokenType: "Bearer",
|
||||
Scope: accessToken.GetScope(),
|
||||
CreatedAt: accessToken.GetAccessCreateAt().Unix(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validateCreateAccount checks through all the necessary prerequisites for creating a new account,
|
||||
// according to the provided account create request. If the account isn't eligible, an error will be returned.
|
||||
func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.AccountsConfig, database db.DB) error {
|
||||
if !c.OpenRegistration {
|
||||
return errors.New("registration is not open for this server")
|
||||
}
|
||||
|
||||
if err := util.ValidateUsername(form.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := util.ValidateEmail(form.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := util.ValidateNewPassword(form.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !form.Agreement {
|
||||
return errors.New("agreement to terms and conditions not given")
|
||||
}
|
||||
|
||||
if err := util.ValidateLanguage(form.Locale); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := util.ValidateSignUpReason(form.Reason, c.ReasonRequired); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := database.IsEmailAvailable(form.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := database.IsUsernameAvailable(form.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
545
internal/apimodule/account/accountcreate_test.go
Normal file
545
internal/apimodule/account/accountcreate_test.go
Normal file
|
@ -0,0 +1,545 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
"github.com/superseriousbusiness/oauth2/v4/models"
|
||||
oauthmodels "github.com/superseriousbusiness/oauth2/v4/models"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type AccountCreateTestSuite struct {
|
||||
suite.Suite
|
||||
config *config.Config
|
||||
log *logrus.Logger
|
||||
testAccountLocal *model.Account
|
||||
testApplication *model.Application
|
||||
testToken oauth2.TokenInfo
|
||||
mockOauthServer *oauth.MockServer
|
||||
mockStorage *storage.MockStorage
|
||||
mediaHandler media.MediaHandler
|
||||
db db.DB
|
||||
accountModule *accountModule
|
||||
newUserFormHappyPath url.Values
|
||||
}
|
||||
|
||||
/*
|
||||
TEST INFRASTRUCTURE
|
||||
*/
|
||||
|
||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||
func (suite *AccountCreateTestSuite) SetupSuite() {
|
||||
// some of our subsequent entities need a log so create this here
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
suite.log = log
|
||||
|
||||
suite.testAccountLocal = &model.Account{
|
||||
ID: uuid.NewString(),
|
||||
Username: "test_user",
|
||||
}
|
||||
|
||||
// can use this test application throughout
|
||||
suite.testApplication = &model.Application{
|
||||
ID: "weeweeeeeeeeeeeeee",
|
||||
Name: "a test application",
|
||||
Website: "https://some-application-website.com",
|
||||
RedirectURI: "http://localhost:8080",
|
||||
ClientID: "a-known-client-id",
|
||||
ClientSecret: "some-secret",
|
||||
Scopes: "read",
|
||||
VapidKey: "aaaaaa-aaaaaaaa-aaaaaaaaaaa",
|
||||
}
|
||||
|
||||
// can use this test token throughout
|
||||
suite.testToken = &oauthmodels.Token{
|
||||
ClientID: "a-known-client-id",
|
||||
RedirectURI: "http://localhost:8080",
|
||||
Scope: "read",
|
||||
Code: "123456789",
|
||||
CodeCreateAt: time.Now(),
|
||||
CodeExpiresIn: time.Duration(10 * time.Minute),
|
||||
}
|
||||
|
||||
// Direct config to local postgres instance
|
||||
c := config.Empty()
|
||||
c.Protocol = "http"
|
||||
c.Host = "localhost"
|
||||
c.DBConfig = &config.DBConfig{
|
||||
Type: "postgres",
|
||||
Address: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "postgres",
|
||||
Database: "postgres",
|
||||
ApplicationName: "gotosocial",
|
||||
}
|
||||
c.MediaConfig = &config.MediaConfig{
|
||||
MaxImageSize: 2 << 20,
|
||||
}
|
||||
c.StorageConfig = &config.StorageConfig{
|
||||
Backend: "local",
|
||||
BasePath: "/tmp",
|
||||
ServeProtocol: "http",
|
||||
ServeHost: "localhost",
|
||||
ServeBasePath: "/fileserver/media",
|
||||
}
|
||||
suite.config = c
|
||||
|
||||
// use an actual database for this, because it's just easier than mocking one out
|
||||
database, err := db.New(context.Background(), c, log)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
suite.db = database
|
||||
|
||||
// we need to mock the oauth server because account creation needs it to create a new token
|
||||
suite.mockOauthServer = &oauth.MockServer{}
|
||||
suite.mockOauthServer.On("GenerateUserAccessToken", suite.testToken, suite.testApplication.ClientSecret, mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
|
||||
l := suite.log.WithField("func", "GenerateUserAccessToken")
|
||||
token := args.Get(0).(oauth2.TokenInfo)
|
||||
l.Infof("received token %+v", token)
|
||||
clientSecret := args.Get(1).(string)
|
||||
l.Infof("received clientSecret %+v", clientSecret)
|
||||
userID := args.Get(2).(string)
|
||||
l.Infof("received userID %+v", userID)
|
||||
}).Return(&models.Token{
|
||||
Code: "we're authorized now!",
|
||||
}, nil)
|
||||
|
||||
suite.mockStorage = &storage.MockStorage{}
|
||||
// We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage
|
||||
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
|
||||
|
||||
// set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar)
|
||||
suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log)
|
||||
|
||||
// and finally here's the thing we're actually testing!
|
||||
suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule)
|
||||
}
|
||||
|
||||
func (suite *AccountCreateTestSuite) TearDownSuite() {
|
||||
if err := suite.db.Stop(context.Background()); err != nil {
|
||||
logrus.Panicf("error closing db connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTest creates a db connection and creates necessary tables before each test
|
||||
func (suite *AccountCreateTestSuite) SetupTest() {
|
||||
// create all the tables we might need in thie suite
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Follow{},
|
||||
&model.FollowRequest{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
&model.EmailDomainBlock{},
|
||||
&model.MediaAttachment{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.CreateTable(m); err != nil {
|
||||
logrus.Panicf("db connection error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// form to submit for happy path account create requests -- this will be changed inside tests so it's better to set it before each test
|
||||
suite.newUserFormHappyPath = url.Values{
|
||||
"reason": []string{"a very good reason that's at least 40 characters i swear"},
|
||||
"username": []string{"test_user"},
|
||||
"email": []string{"user@example.org"},
|
||||
"password": []string{"very-strong-password"},
|
||||
"agreement": []string{"true"},
|
||||
"locale": []string{"en"},
|
||||
}
|
||||
|
||||
// same with accounts config
|
||||
suite.config.AccountsConfig = &config.AccountsConfig{
|
||||
OpenRegistration: true,
|
||||
RequireApproval: true,
|
||||
ReasonRequired: true,
|
||||
}
|
||||
}
|
||||
|
||||
// TearDownTest drops tables to make sure there's no data in the db
|
||||
func (suite *AccountCreateTestSuite) TearDownTest() {
|
||||
|
||||
// remove all the tables we might have used so it's clear for the next test
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Follow{},
|
||||
&model.FollowRequest{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
&model.EmailDomainBlock{},
|
||||
&model.MediaAttachment{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
||||
logrus.Panicf("error dropping table: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
ACTUAL TESTS
|
||||
*/
|
||||
|
||||
/*
|
||||
TESTING: AccountCreatePOSTHandler
|
||||
*/
|
||||
|
||||
// TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid,
|
||||
// and at the end of it a new user and account should be added into the database.
|
||||
//
|
||||
// This is the handler served at /api/v1/accounts as POST
|
||||
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerSuccessful() {
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
|
||||
ctx.Request.Form = suite.newUserFormHappyPath
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
|
||||
// check response
|
||||
|
||||
// 1. we should have OK from our call to the function
|
||||
suite.EqualValues(http.StatusOK, recorder.Code)
|
||||
|
||||
// 2. we should have a token in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
b, err := ioutil.ReadAll(result.Body)
|
||||
assert.NoError(suite.T(), err)
|
||||
t := &mastotypes.Token{}
|
||||
err = json.Unmarshal(b, t)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), "we're authorized now!", t.AccessToken)
|
||||
|
||||
// check new account
|
||||
|
||||
// 1. we should be able to get the new account from the db
|
||||
acct := &model.Account{}
|
||||
err = suite.db.GetWhere("username", "test_user", acct)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.NotNil(suite.T(), acct)
|
||||
// 2. reason should be set
|
||||
assert.Equal(suite.T(), suite.newUserFormHappyPath.Get("reason"), acct.Reason)
|
||||
// 3. display name should be equal to username by default
|
||||
assert.Equal(suite.T(), suite.newUserFormHappyPath.Get("username"), acct.DisplayName)
|
||||
// 4. domain should be nil because this is a local account
|
||||
assert.Nil(suite.T(), nil, acct.Domain)
|
||||
// 5. id should be set and parseable as a uuid
|
||||
assert.NotNil(suite.T(), acct.ID)
|
||||
_, err = uuid.Parse(acct.ID)
|
||||
assert.Nil(suite.T(), err)
|
||||
// 6. private and public key should be set
|
||||
assert.NotNil(suite.T(), acct.PrivateKey)
|
||||
assert.NotNil(suite.T(), acct.PublicKey)
|
||||
|
||||
// check new user
|
||||
|
||||
// 1. we should be able to get the new user from the db
|
||||
usr := &model.User{}
|
||||
err = suite.db.GetWhere("unconfirmed_email", suite.newUserFormHappyPath.Get("email"), usr)
|
||||
assert.Nil(suite.T(), err)
|
||||
assert.NotNil(suite.T(), usr)
|
||||
|
||||
// 2. user should have account id set to account we got above
|
||||
assert.Equal(suite.T(), acct.ID, usr.AccountID)
|
||||
|
||||
// 3. id should be set and parseable as a uuid
|
||||
assert.NotNil(suite.T(), usr.ID)
|
||||
_, err = uuid.Parse(usr.ID)
|
||||
assert.Nil(suite.T(), err)
|
||||
|
||||
// 4. locale should be equal to what we requested
|
||||
assert.Equal(suite.T(), suite.newUserFormHappyPath.Get("locale"), usr.Locale)
|
||||
|
||||
// 5. created by application id should be equal to the app id
|
||||
assert.Equal(suite.T(), suite.testApplication.ID, usr.CreatedByApplicationID)
|
||||
|
||||
// 6. password should be matcheable to what we set above
|
||||
err = bcrypt.CompareHashAndPassword([]byte(usr.EncryptedPassword), []byte(suite.newUserFormHappyPath.Get("password")))
|
||||
assert.Nil(suite.T(), err)
|
||||
}
|
||||
|
||||
// TestAccountCreatePOSTHandlerNoAuth makes sure that the handler fails when no authorization is provided:
|
||||
// only registered applications can create accounts, and we don't provide one here.
|
||||
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerNoAuth() {
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
|
||||
ctx.Request.Form = suite.newUserFormHappyPath
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
|
||||
// check response
|
||||
|
||||
// 1. we should have forbidden from our call to the function because we didn't auth
|
||||
suite.EqualValues(http.StatusForbidden, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
b, err := ioutil.ReadAll(result.Body)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b))
|
||||
}
|
||||
|
||||
// TestAccountCreatePOSTHandlerNoAuth makes sure that the handler fails when no form is provided at all.
|
||||
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerNoForm() {
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
|
||||
// check response
|
||||
suite.EqualValues(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
b, err := ioutil.ReadAll(result.Body)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), `{"error":"missing one or more required form values"}`, string(b))
|
||||
}
|
||||
|
||||
// TestAccountCreatePOSTHandlerWeakPassword makes sure that the handler fails when a weak password is provided
|
||||
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerWeakPassword() {
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
|
||||
ctx.Request.Form = suite.newUserFormHappyPath
|
||||
// set a weak password
|
||||
ctx.Request.Form.Set("password", "weak")
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
|
||||
// check response
|
||||
suite.EqualValues(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
b, err := ioutil.ReadAll(result.Body)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), `{"error":"insecure password, try including more special characters, using uppercase letters, using numbers or using a longer password"}`, string(b))
|
||||
}
|
||||
|
||||
// TestAccountCreatePOSTHandlerWeirdLocale makes sure that the handler fails when a weird locale is provided
|
||||
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerWeirdLocale() {
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
|
||||
ctx.Request.Form = suite.newUserFormHappyPath
|
||||
// set an invalid locale
|
||||
ctx.Request.Form.Set("locale", "neverneverland")
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
|
||||
// check response
|
||||
suite.EqualValues(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
b, err := ioutil.ReadAll(result.Body)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), `{"error":"language: tag is not well-formed"}`, string(b))
|
||||
}
|
||||
|
||||
// TestAccountCreatePOSTHandlerRegistrationsClosed makes sure that the handler fails when registrations are closed
|
||||
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerRegistrationsClosed() {
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
|
||||
ctx.Request.Form = suite.newUserFormHappyPath
|
||||
|
||||
// close registrations
|
||||
suite.config.AccountsConfig.OpenRegistration = false
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
|
||||
// check response
|
||||
suite.EqualValues(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
b, err := ioutil.ReadAll(result.Body)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), `{"error":"registration is not open for this server"}`, string(b))
|
||||
}
|
||||
|
||||
// TestAccountCreatePOSTHandlerReasonNotProvided makes sure that the handler fails when no reason is provided but one is required
|
||||
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerReasonNotProvided() {
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
|
||||
ctx.Request.Form = suite.newUserFormHappyPath
|
||||
|
||||
// remove reason
|
||||
ctx.Request.Form.Set("reason", "")
|
||||
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
|
||||
// check response
|
||||
suite.EqualValues(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
b, err := ioutil.ReadAll(result.Body)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), `{"error":"no reason provided"}`, string(b))
|
||||
}
|
||||
|
||||
// TestAccountCreatePOSTHandlerReasonNotProvided makes sure that the handler fails when a crappy reason is presented but a good one is required
|
||||
func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerInsufficientReason() {
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
|
||||
ctx.Request.Form = suite.newUserFormHappyPath
|
||||
|
||||
// remove reason
|
||||
ctx.Request.Form.Set("reason", "just cuz")
|
||||
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
|
||||
// check response
|
||||
suite.EqualValues(http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
b, err := ioutil.ReadAll(result.Body)
|
||||
assert.NoError(suite.T(), err)
|
||||
assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b))
|
||||
}
|
||||
|
||||
/*
|
||||
TESTING: AccountUpdateCredentialsPATCHHandler
|
||||
*/
|
||||
|
||||
func (suite *AccountCreateTestSuite) TestAccountUpdateCredentialsPATCHHandler() {
|
||||
|
||||
// put test local account in db
|
||||
err := suite.db.Put(suite.testAccountLocal)
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
// attach avatar to request
|
||||
aviFile, err := os.Open("../../media/test/test-jpeg.jpg")
|
||||
assert.NoError(suite.T(), err)
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
part, err := writer.CreateFormFile("avatar", "test-jpeg.jpg")
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
_, err = io.Copy(part, aviFile)
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
err = aviFile.Close()
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
err = writer.Close()
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting
|
||||
ctx.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx)
|
||||
|
||||
// check response
|
||||
|
||||
// 1. we should have OK because our request was valid
|
||||
suite.EqualValues(http.StatusOK, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
// TODO: implement proper checks here
|
||||
//
|
||||
// b, err := ioutil.ReadAll(result.Body)
|
||||
// assert.NoError(suite.T(), err)
|
||||
// assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b))
|
||||
}
|
||||
|
||||
func TestAccountCreateTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(AccountCreateTestSuite))
|
||||
}
|
57
internal/apimodule/account/accountget.go
Normal file
57
internal/apimodule/account/accountget.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
)
|
||||
|
||||
// accountGetHandler serves the account information held by the server in response to a GET
|
||||
// request. It should be served as a GET at /api/v1/accounts/:id.
|
||||
//
|
||||
// See: https://docs.joinmastodon.org/methods/accounts/
|
||||
func (m *accountModule) accountGETHandler(c *gin.Context) {
|
||||
targetAcctID := c.Param(idKey)
|
||||
if targetAcctID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no account id specified"})
|
||||
return
|
||||
}
|
||||
|
||||
targetAccount := &model.Account{}
|
||||
if err := m.db.GetByID(targetAcctID, targetAccount); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Record not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
acctInfo, err := m.db.AccountToMastoPublic(targetAccount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, acctInfo)
|
||||
}
|
259
internal/apimodule/account/accountupdate.go
Normal file
259
internal/apimodule/account/accountupdate.go
Normal file
|
@ -0,0 +1,259 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
)
|
||||
|
||||
// accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings.
|
||||
// It should be served as a PATCH at /api/v1/accounts/update_credentials
|
||||
//
|
||||
// TODO: this can be optimized massively by building up a picture of what we want the new account
|
||||
// details to be, and then inserting it all in the database at once. As it is, we do queries one-by-one
|
||||
// which is not gonna make the database very happy when lots of requests are going through.
|
||||
// This way it would also be safer because the update won't happen until *all* the fields are validated.
|
||||
// Otherwise we risk doing a partial update and that's gonna cause probllleeemmmsss.
|
||||
func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler")
|
||||
authed, err := oauth.MustAuth(c, true, false, false, true)
|
||||
if err != nil {
|
||||
l.Debugf("couldn't auth: %s", err)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
l.Tracef("retrieved account %+v", authed.Account.ID)
|
||||
|
||||
l.Trace("parsing request form")
|
||||
form := &mastotypes.UpdateCredentialsRequest{}
|
||||
if err := c.ShouldBind(form); err != nil || form == nil {
|
||||
l.Debugf("could not parse form from request: %s", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// if everything on the form is nil, then nothing has been set and we shouldn't continue
|
||||
if form.Discoverable == nil && form.Bot == nil && form.DisplayName == nil && form.Note == nil && form.Avatar == nil && form.Header == nil && form.Locked == nil && form.Source == nil && form.FieldsAttributes == nil {
|
||||
l.Debugf("could not parse form from request")
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"})
|
||||
return
|
||||
}
|
||||
|
||||
if form.Discoverable != nil {
|
||||
if err := m.db.UpdateOneByID(authed.Account.ID, "discoverable", *form.Discoverable, &model.Account{}); err != nil {
|
||||
l.Debugf("error updating discoverable: %s", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if form.Bot != nil {
|
||||
if err := m.db.UpdateOneByID(authed.Account.ID, "bot", *form.Bot, &model.Account{}); err != nil {
|
||||
l.Debugf("error updating bot: %s", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if form.DisplayName != nil {
|
||||
if err := util.ValidateDisplayName(*form.DisplayName); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := m.db.UpdateOneByID(authed.Account.ID, "display_name", *form.DisplayName, &model.Account{}); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if form.Note != nil {
|
||||
if err := util.ValidateNote(*form.Note); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := m.db.UpdateOneByID(authed.Account.ID, "note", *form.Note, &model.Account{}); err != nil {
|
||||
l.Debugf("error updating note: %s", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if form.Avatar != nil && form.Avatar.Size != 0 {
|
||||
avatarInfo, err := m.UpdateAccountAvatar(form.Avatar, authed.Account.ID)
|
||||
if err != nil {
|
||||
l.Debugf("could not update avatar for account %s: %s", authed.Account.ID, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
l.Tracef("new avatar info for account %s is %+v", authed.Account.ID, avatarInfo)
|
||||
}
|
||||
|
||||
if form.Header != nil && form.Header.Size != 0 {
|
||||
headerInfo, err := m.UpdateAccountHeader(form.Header, authed.Account.ID)
|
||||
if err != nil {
|
||||
l.Debugf("could not update header for account %s: %s", authed.Account.ID, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
l.Tracef("new header info for account %s is %+v", authed.Account.ID, headerInfo)
|
||||
}
|
||||
|
||||
if form.Locked != nil {
|
||||
if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if form.Source != nil {
|
||||
if form.Source.Language != nil {
|
||||
if err := util.ValidateLanguage(*form.Source.Language); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := m.db.UpdateOneByID(authed.Account.ID, "language", *form.Source.Language, &model.Account{}); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if form.Source.Sensitive != nil {
|
||||
if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if form.Source.Privacy != nil {
|
||||
if err := util.ValidatePrivacy(*form.Source.Privacy); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := m.db.UpdateOneByID(authed.Account.ID, "privacy", *form.Source.Privacy, &model.Account{}); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if form.FieldsAttributes != nil {
|
||||
// // TODO: parse fields attributes nicely and update
|
||||
// }
|
||||
|
||||
// fetch the account with all updated values set
|
||||
updatedAccount := &model.Account{}
|
||||
if err := m.db.GetByID(authed.Account.ID, updatedAccount); err != nil {
|
||||
l.Debugf("could not fetch updated account %s: %s", authed.Account.ID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
acctSensitive, err := m.db.AccountToMastoSensitive(updatedAccount)
|
||||
if err != nil {
|
||||
l.Tracef("could not convert account into mastosensitive account: %s", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
|
||||
c.JSON(http.StatusOK, acctSensitive)
|
||||
}
|
||||
|
||||
/*
|
||||
HELPER FUNCTIONS
|
||||
*/
|
||||
|
||||
// TODO: try to combine the below two functions because this is a lot of code repetition.
|
||||
|
||||
// UpdateAccountAvatar does the dirty work of checking the avatar part of an account update form,
|
||||
// parsing and checking the image, and doing the necessary updates in the database for this to become
|
||||
// the account's new avatar image.
|
||||
func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) {
|
||||
var err error
|
||||
if int(avatar.Size) > m.config.MediaConfig.MaxImageSize {
|
||||
err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, m.config.MediaConfig.MaxImageSize)
|
||||
return nil, err
|
||||
}
|
||||
f, err := avatar.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read provided avatar: %s", err)
|
||||
}
|
||||
|
||||
// extract the bytes
|
||||
buf := new(bytes.Buffer)
|
||||
size, err := io.Copy(buf, f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read provided avatar: %s", err)
|
||||
}
|
||||
if size == 0 {
|
||||
return nil, errors.New("could not read provided avatar: size 0 bytes")
|
||||
}
|
||||
|
||||
// do the setting
|
||||
avatarInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "avatar")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error processing avatar: %s", err)
|
||||
}
|
||||
|
||||
return avatarInfo, f.Close()
|
||||
}
|
||||
|
||||
// UpdateAccountHeader does the dirty work of checking the header part of an account update form,
|
||||
// parsing and checking the image, and doing the necessary updates in the database for this to become
|
||||
// the account's new header image.
|
||||
func (m *accountModule) UpdateAccountHeader(header *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) {
|
||||
var err error
|
||||
if int(header.Size) > m.config.MediaConfig.MaxImageSize {
|
||||
err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, m.config.MediaConfig.MaxImageSize)
|
||||
return nil, err
|
||||
}
|
||||
f, err := header.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read provided header: %s", err)
|
||||
}
|
||||
|
||||
// extract the bytes
|
||||
buf := new(bytes.Buffer)
|
||||
size, err := io.Copy(buf, f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read provided header: %s", err)
|
||||
}
|
||||
if size == 0 {
|
||||
return nil, errors.New("could not read provided header: size 0 bytes")
|
||||
}
|
||||
|
||||
// do the setting
|
||||
headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "header")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error processing header: %s", err)
|
||||
}
|
||||
|
||||
return headerInfo, f.Close()
|
||||
}
|
298
internal/apimodule/account/accountupdate_test.go
Normal file
298
internal/apimodule/account/accountupdate_test.go
Normal file
|
@ -0,0 +1,298 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
"github.com/superseriousbusiness/oauth2/v4/models"
|
||||
oauthmodels "github.com/superseriousbusiness/oauth2/v4/models"
|
||||
)
|
||||
|
||||
type AccountUpdateTestSuite struct {
|
||||
suite.Suite
|
||||
config *config.Config
|
||||
log *logrus.Logger
|
||||
testAccountLocal *model.Account
|
||||
testApplication *model.Application
|
||||
testToken oauth2.TokenInfo
|
||||
mockOauthServer *oauth.MockServer
|
||||
mockStorage *storage.MockStorage
|
||||
mediaHandler media.MediaHandler
|
||||
db db.DB
|
||||
accountModule *accountModule
|
||||
newUserFormHappyPath url.Values
|
||||
}
|
||||
|
||||
/*
|
||||
TEST INFRASTRUCTURE
|
||||
*/
|
||||
|
||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||
func (suite *AccountUpdateTestSuite) SetupSuite() {
|
||||
// some of our subsequent entities need a log so create this here
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
suite.log = log
|
||||
|
||||
suite.testAccountLocal = &model.Account{
|
||||
ID: uuid.NewString(),
|
||||
Username: "test_user",
|
||||
}
|
||||
|
||||
// can use this test application throughout
|
||||
suite.testApplication = &model.Application{
|
||||
ID: "weeweeeeeeeeeeeeee",
|
||||
Name: "a test application",
|
||||
Website: "https://some-application-website.com",
|
||||
RedirectURI: "http://localhost:8080",
|
||||
ClientID: "a-known-client-id",
|
||||
ClientSecret: "some-secret",
|
||||
Scopes: "read",
|
||||
VapidKey: "aaaaaa-aaaaaaaa-aaaaaaaaaaa",
|
||||
}
|
||||
|
||||
// can use this test token throughout
|
||||
suite.testToken = &oauthmodels.Token{
|
||||
ClientID: "a-known-client-id",
|
||||
RedirectURI: "http://localhost:8080",
|
||||
Scope: "read",
|
||||
Code: "123456789",
|
||||
CodeCreateAt: time.Now(),
|
||||
CodeExpiresIn: time.Duration(10 * time.Minute),
|
||||
}
|
||||
|
||||
// Direct config to local postgres instance
|
||||
c := config.Empty()
|
||||
c.Protocol = "http"
|
||||
c.Host = "localhost"
|
||||
c.DBConfig = &config.DBConfig{
|
||||
Type: "postgres",
|
||||
Address: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "postgres",
|
||||
Database: "postgres",
|
||||
ApplicationName: "gotosocial",
|
||||
}
|
||||
c.MediaConfig = &config.MediaConfig{
|
||||
MaxImageSize: 2 << 20,
|
||||
}
|
||||
c.StorageConfig = &config.StorageConfig{
|
||||
Backend: "local",
|
||||
BasePath: "/tmp",
|
||||
ServeProtocol: "http",
|
||||
ServeHost: "localhost",
|
||||
ServeBasePath: "/fileserver/media",
|
||||
}
|
||||
suite.config = c
|
||||
|
||||
// use an actual database for this, because it's just easier than mocking one out
|
||||
database, err := db.New(context.Background(), c, log)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
suite.db = database
|
||||
|
||||
// we need to mock the oauth server because account creation needs it to create a new token
|
||||
suite.mockOauthServer = &oauth.MockServer{}
|
||||
suite.mockOauthServer.On("GenerateUserAccessToken", suite.testToken, suite.testApplication.ClientSecret, mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
|
||||
l := suite.log.WithField("func", "GenerateUserAccessToken")
|
||||
token := args.Get(0).(oauth2.TokenInfo)
|
||||
l.Infof("received token %+v", token)
|
||||
clientSecret := args.Get(1).(string)
|
||||
l.Infof("received clientSecret %+v", clientSecret)
|
||||
userID := args.Get(2).(string)
|
||||
l.Infof("received userID %+v", userID)
|
||||
}).Return(&models.Token{
|
||||
Code: "we're authorized now!",
|
||||
}, nil)
|
||||
|
||||
suite.mockStorage = &storage.MockStorage{}
|
||||
// We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage
|
||||
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
|
||||
|
||||
// set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar)
|
||||
suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log)
|
||||
|
||||
// and finally here's the thing we're actually testing!
|
||||
suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule)
|
||||
}
|
||||
|
||||
func (suite *AccountUpdateTestSuite) TearDownSuite() {
|
||||
if err := suite.db.Stop(context.Background()); err != nil {
|
||||
logrus.Panicf("error closing db connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTest creates a db connection and creates necessary tables before each test
|
||||
func (suite *AccountUpdateTestSuite) SetupTest() {
|
||||
// create all the tables we might need in thie suite
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Follow{},
|
||||
&model.FollowRequest{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
&model.EmailDomainBlock{},
|
||||
&model.MediaAttachment{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.CreateTable(m); err != nil {
|
||||
logrus.Panicf("db connection error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// form to submit for happy path account create requests -- this will be changed inside tests so it's better to set it before each test
|
||||
suite.newUserFormHappyPath = url.Values{
|
||||
"reason": []string{"a very good reason that's at least 40 characters i swear"},
|
||||
"username": []string{"test_user"},
|
||||
"email": []string{"user@example.org"},
|
||||
"password": []string{"very-strong-password"},
|
||||
"agreement": []string{"true"},
|
||||
"locale": []string{"en"},
|
||||
}
|
||||
|
||||
// same with accounts config
|
||||
suite.config.AccountsConfig = &config.AccountsConfig{
|
||||
OpenRegistration: true,
|
||||
RequireApproval: true,
|
||||
ReasonRequired: true,
|
||||
}
|
||||
}
|
||||
|
||||
// TearDownTest drops tables to make sure there's no data in the db
|
||||
func (suite *AccountUpdateTestSuite) TearDownTest() {
|
||||
|
||||
// remove all the tables we might have used so it's clear for the next test
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Follow{},
|
||||
&model.FollowRequest{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
&model.EmailDomainBlock{},
|
||||
&model.MediaAttachment{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
||||
logrus.Panicf("error dropping table: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
ACTUAL TESTS
|
||||
*/
|
||||
|
||||
/*
|
||||
TESTING: AccountUpdateCredentialsPATCHHandler
|
||||
*/
|
||||
|
||||
func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler() {
|
||||
|
||||
// put test local account in db
|
||||
err := suite.db.Put(suite.testAccountLocal)
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
// attach avatar to request form
|
||||
avatarFile, err := os.Open("../../media/test/test-jpeg.jpg")
|
||||
assert.NoError(suite.T(), err)
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
avatarPart, err := writer.CreateFormFile("avatar", "test-jpeg.jpg")
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
_, err = io.Copy(avatarPart, avatarFile)
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
err = avatarFile.Close()
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
// set display name to a new value
|
||||
displayNamePart, err := writer.CreateFormField("display_name")
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
_, err = io.Copy(displayNamePart, bytes.NewBufferString("test_user_wohoah"))
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
// set locked to true
|
||||
lockedPart, err := writer.CreateFormField("locked")
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
_, err = io.Copy(lockedPart, bytes.NewBufferString("true"))
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
// close the request writer, the form is now prepared
|
||||
err = writer.Close()
|
||||
assert.NoError(suite.T(), err)
|
||||
|
||||
// setup
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting
|
||||
ctx.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx)
|
||||
|
||||
// check response
|
||||
|
||||
// 1. we should have OK because our request was valid
|
||||
suite.EqualValues(http.StatusOK, recorder.Code)
|
||||
|
||||
// 2. we should have an error message in the result body
|
||||
result := recorder.Result()
|
||||
defer result.Body.Close()
|
||||
// TODO: implement proper checks here
|
||||
//
|
||||
// b, err := ioutil.ReadAll(result.Body)
|
||||
// assert.NoError(suite.T(), err)
|
||||
// assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b))
|
||||
}
|
||||
|
||||
func TestAccountUpdateTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(AccountUpdateTestSuite))
|
||||
}
|
50
internal/apimodule/account/accountverify.go
Normal file
50
internal/apimodule/account/accountverify.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
)
|
||||
|
||||
// accountVerifyGETHandler serves a user's account details to them IF they reached this
|
||||
// handler while in possession of a valid token, according to the oauth middleware.
|
||||
// It should be served as a GET at /api/v1/accounts/verify_credentials
|
||||
func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "accountVerifyGETHandler")
|
||||
authed, err := oauth.MustAuth(c, true, false, false, true)
|
||||
if err != nil {
|
||||
l.Debugf("couldn't auth: %s", err)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
l.Tracef("retrieved account %+v, converting to mastosensitive...", authed.Account.ID)
|
||||
acctSensitive, err := m.db.AccountToMastoSensitive(authed.Account)
|
||||
if err != nil {
|
||||
l.Tracef("could not convert account into mastosensitive account: %s", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
|
||||
c.JSON(http.StatusOK, acctSensitive)
|
||||
}
|
|
@ -17,21 +17,3 @@
|
|||
*/
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"github.com/gotosocial/gotosocial/internal/module"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
)
|
||||
|
||||
type accountModule struct {
|
||||
}
|
||||
|
||||
// New returns a new account module
|
||||
func New() module.ClientAPIModule {
|
||||
return &accountModule{}
|
||||
}
|
||||
|
||||
// Route attaches all routes from this module to the given router
|
||||
func (m *accountModule) Route(r router.Router) error {
|
||||
return nil
|
||||
}
|
|
@ -16,14 +16,18 @@
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
// Package module is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface.
|
||||
package module
|
||||
// Package apimodule is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface.
|
||||
package apimodule
|
||||
|
||||
import "github.com/gotosocial/gotosocial/internal/router"
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
)
|
||||
|
||||
// ClientAPIModule represents a chunk of code (usually contained in a single package) that adds a set
|
||||
// of functionalities and side effects to a router, by mapping routes and handlers onto it--in other words, a REST API ;)
|
||||
// A ClientAPIMpdule corresponds roughly to one main path of the gotosocial REST api, for example /api/v1/accounts/ or /oauth/
|
||||
type ClientAPIModule interface {
|
||||
Route(s router.Router) error
|
||||
CreateTables(db db.DB) error
|
||||
}
|
71
internal/apimodule/app/app.go
Normal file
71
internal/apimodule/app/app.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
)
|
||||
|
||||
const appsPath = "/api/v1/apps"
|
||||
|
||||
type appModule struct {
|
||||
server oauth.Server
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
// New returns a new auth module
|
||||
func New(srv oauth.Server, db db.DB, log *logrus.Logger) apimodule.ClientAPIModule {
|
||||
return &appModule{
|
||||
server: srv,
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Route satisfies the RESTAPIModule interface
|
||||
func (m *appModule) Route(s router.Router) error {
|
||||
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *appModule) CreateTables(db db.DB) error {
|
||||
models := []interface{}{
|
||||
&oauth.Client{},
|
||||
&oauth.Token{},
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Application{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
if err := db.CreateTable(m); err != nil {
|
||||
return fmt.Errorf("error creating table: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
21
internal/apimodule/app/app_test.go
Normal file
21
internal/apimodule/app/app_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package app
|
||||
|
||||
// TODO: write tests
|
113
internal/apimodule/app/appcreate.go
Normal file
113
internal/apimodule/app/appcreate.go
Normal file
|
@ -0,0 +1,113 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
)
|
||||
|
||||
// appsPOSTHandler should be served at https://example.org/api/v1/apps
|
||||
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
|
||||
func (m *appModule) appsPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AppsPOSTHandler")
|
||||
l.Trace("entering AppsPOSTHandler")
|
||||
|
||||
form := &mastotypes.ApplicationPOSTRequest{}
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// permitted length for most fields
|
||||
permittedLength := 64
|
||||
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
|
||||
permittedRedirect := 256
|
||||
|
||||
// check lengths of fields before proceeding so the user can't spam huge entries into the database
|
||||
if len(form.ClientName) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
if len(form.Website) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
if len(form.RedirectURIs) > permittedRedirect {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
|
||||
return
|
||||
}
|
||||
if len(form.Scopes) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
|
||||
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
|
||||
var scopes string
|
||||
if form.Scopes == "" {
|
||||
scopes = "read"
|
||||
} else {
|
||||
scopes = form.Scopes
|
||||
}
|
||||
|
||||
// generate new IDs for this application and its associated client
|
||||
clientID := uuid.NewString()
|
||||
clientSecret := uuid.NewString()
|
||||
vapidKey := uuid.NewString()
|
||||
|
||||
// generate the application to put in the database
|
||||
app := &model.Application{
|
||||
Name: form.ClientName,
|
||||
Website: form.Website,
|
||||
RedirectURI: form.RedirectURIs,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: scopes,
|
||||
VapidKey: vapidKey,
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := m.db.Put(app); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// now we need to model an oauth client from the application that the oauth library can use
|
||||
oc := &oauth.Client{
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
Domain: form.RedirectURIs,
|
||||
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := m.db.Put(oc); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
|
||||
c.JSON(http.StatusOK, app.ToMasto())
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
# oauth
|
||||
# auth
|
||||
|
||||
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.
|
||||
|
89
internal/apimodule/auth/auth.go
Normal file
89
internal/apimodule/auth/auth.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
// Package auth is a module that provides oauth functionality to a router.
|
||||
// It adds the following paths:
|
||||
// /auth/sign_in
|
||||
// /oauth/token
|
||||
// /oauth/authorize
|
||||
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
)
|
||||
|
||||
const (
|
||||
authSignInPath = "/auth/sign_in"
|
||||
oauthTokenPath = "/oauth/token"
|
||||
oauthAuthorizePath = "/oauth/authorize"
|
||||
)
|
||||
|
||||
type authModule struct {
|
||||
server oauth.Server
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
// New returns a new auth module
|
||||
func New(srv oauth.Server, db db.DB, log *logrus.Logger) apimodule.ClientAPIModule {
|
||||
return &authModule{
|
||||
server: srv,
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Route satisfies the RESTAPIModule interface
|
||||
func (m *authModule) Route(s router.Router) error {
|
||||
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
|
||||
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
|
||||
|
||||
s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler)
|
||||
|
||||
s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler)
|
||||
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
|
||||
|
||||
s.AttachMiddleware(m.oauthTokenMiddleware)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *authModule) CreateTables(db db.DB) error {
|
||||
models := []interface{}{
|
||||
&oauth.Client{},
|
||||
&oauth.Token{},
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Application{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
if err := db.CreateTable(m); err != nil {
|
||||
return fmt.Errorf("error creating table: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -16,7 +16,7 @@
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package oauth
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -25,30 +25,29 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type OauthTestSuite struct {
|
||||
type AuthTestSuite struct {
|
||||
suite.Suite
|
||||
tokenStore oauth2.TokenStore
|
||||
clientStore oauth2.ClientStore
|
||||
oauthServer oauth.Server
|
||||
db db.DB
|
||||
testAccount *gtsmodel.Account
|
||||
testApplication *gtsmodel.Application
|
||||
testUser *gtsmodel.User
|
||||
testClient *oauthClient
|
||||
testAccount *model.Account
|
||||
testApplication *model.Application
|
||||
testUser *model.User
|
||||
testClient *oauth.Client
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||
func (suite *OauthTestSuite) SetupSuite() {
|
||||
func (suite *AuthTestSuite) SetupSuite() {
|
||||
c := config.Empty()
|
||||
// we're running on localhost without https so set the protocol to http
|
||||
c.Protocol = "http"
|
||||
|
@ -76,21 +75,21 @@ func (suite *OauthTestSuite) SetupSuite() {
|
|||
|
||||
acctID := uuid.NewString()
|
||||
|
||||
suite.testAccount = >smodel.Account{
|
||||
suite.testAccount = &model.Account{
|
||||
ID: acctID,
|
||||
Username: "test_user",
|
||||
}
|
||||
suite.testUser = >smodel.User{
|
||||
suite.testUser = &model.User{
|
||||
EncryptedPassword: string(encryptedPassword),
|
||||
Email: "user@example.org",
|
||||
AccountID: acctID,
|
||||
}
|
||||
suite.testClient = &oauthClient{
|
||||
suite.testClient = &oauth.Client{
|
||||
ID: "a-known-client-id",
|
||||
Secret: "some-secret",
|
||||
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
|
||||
}
|
||||
suite.testApplication = >smodel.Application{
|
||||
suite.testApplication = &model.Application{
|
||||
Name: "a test application",
|
||||
Website: "https://some-application-website.com",
|
||||
RedirectURI: "http://localhost:8080",
|
||||
|
@ -102,7 +101,7 @@ func (suite *OauthTestSuite) SetupSuite() {
|
|||
}
|
||||
|
||||
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
||||
func (suite *OauthTestSuite) SetupTest() {
|
||||
func (suite *AuthTestSuite) SetupTest() {
|
||||
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
|
@ -114,11 +113,11 @@ func (suite *OauthTestSuite) SetupTest() {
|
|||
suite.db = db
|
||||
|
||||
models := []interface{}{
|
||||
&oauthClient{},
|
||||
&oauthToken{},
|
||||
>smodel.User{},
|
||||
>smodel.Account{},
|
||||
>smodel.Application{},
|
||||
&oauth.Client{},
|
||||
&oauth.Token{},
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Application{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
|
@ -127,8 +126,7 @@ func (suite *OauthTestSuite) SetupTest() {
|
|||
}
|
||||
}
|
||||
|
||||
suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New())
|
||||
suite.clientStore = newClientStore(suite.db)
|
||||
suite.oauthServer = oauth.New(suite.db, log)
|
||||
|
||||
if err := suite.db.Put(suite.testAccount); err != nil {
|
||||
logrus.Panicf("could not insert test account into db: %s", err)
|
||||
|
@ -146,13 +144,13 @@ func (suite *OauthTestSuite) SetupTest() {
|
|||
}
|
||||
|
||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||
func (suite *OauthTestSuite) TearDownTest() {
|
||||
func (suite *AuthTestSuite) TearDownTest() {
|
||||
models := []interface{}{
|
||||
&oauthClient{},
|
||||
&oauthToken{},
|
||||
>smodel.User{},
|
||||
>smodel.Account{},
|
||||
>smodel.Application{},
|
||||
&oauth.Client{},
|
||||
&oauth.Token{},
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Application{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
||||
|
@ -165,7 +163,7 @@ func (suite *OauthTestSuite) TearDownTest() {
|
|||
suite.db = nil
|
||||
}
|
||||
|
||||
func (suite *OauthTestSuite) TestAPIInitialize() {
|
||||
func (suite *AuthTestSuite) TestAPIInitialize() {
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
|
||||
|
@ -174,18 +172,18 @@ func (suite *OauthTestSuite) TestAPIInitialize() {
|
|||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||
}
|
||||
|
||||
api := New(suite.tokenStore, suite.clientStore, suite.db, log)
|
||||
api := New(suite.oauthServer, suite.db, log)
|
||||
if err := api.Route(r); err != nil {
|
||||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||
}
|
||||
|
||||
go r.Start()
|
||||
r.Start()
|
||||
time.Sleep(60 * time.Second)
|
||||
// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read
|
||||
// curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token
|
||||
// curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080
|
||||
if err := r.Stop(context.Background()); err != nil {
|
||||
suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(OauthTestSuite))
|
||||
func TestAuthTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(AuthTestSuite))
|
||||
}
|
204
internal/apimodule/auth/authorize.go
Normal file
204
internal/apimodule/auth/authorize.go
Normal file
|
@ -0,0 +1,204 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
)
|
||||
|
||||
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
|
||||
// The idea here is to present an oauth authorize page to the user, with a button
|
||||
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||
func (m *authModule) authorizeGETHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AuthorizeGETHandler")
|
||||
s := sessions.Default(c)
|
||||
|
||||
// UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
|
||||
// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
|
||||
userID, ok := s.Get("userid").(string)
|
||||
if !ok || userID == "" {
|
||||
l.Trace("userid was empty, parsing form then redirecting to sign in page")
|
||||
if err := parseAuthForm(c, l); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, authSignInPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// We can use the client_id on the session to retrieve info about the app associated with the client_id
|
||||
clientID, ok := s.Get("client_id").(string)
|
||||
if !ok || clientID == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
|
||||
return
|
||||
}
|
||||
app := &model.Application{
|
||||
ClientID: clientID,
|
||||
}
|
||||
if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
|
||||
return
|
||||
}
|
||||
|
||||
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
|
||||
user := &model.User{
|
||||
ID: userID,
|
||||
}
|
||||
if err := m.db.GetByID(user.ID, user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
acct := &model.Account{
|
||||
ID: user.AccountID,
|
||||
}
|
||||
|
||||
if err := m.db.GetByID(acct.ID, acct); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Finally we should also get the redirect and scope of this particular request, as stored in the session.
|
||||
redirect, ok := s.Get("redirect_uri").(string)
|
||||
if !ok || redirect == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"})
|
||||
return
|
||||
}
|
||||
scope, ok := s.Get("scope").(string)
|
||||
if !ok || scope == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"})
|
||||
return
|
||||
}
|
||||
|
||||
// the authorize template will display a form to the user where they can get some information
|
||||
// about the app that's trying to authorize, and the scope of the request.
|
||||
// They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler
|
||||
l.Trace("serving authorize html")
|
||||
c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
|
||||
"appname": app.Name,
|
||||
"appwebsite": app.Website,
|
||||
"redirect": redirect,
|
||||
"scope": scope,
|
||||
"user": acct.Username,
|
||||
})
|
||||
}
|
||||
|
||||
// authorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
|
||||
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
|
||||
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
|
||||
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||
func (m *authModule) authorizePOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AuthorizePOSTHandler")
|
||||
s := sessions.Default(c)
|
||||
|
||||
// At this point we know the user has said 'yes' to allowing the application and oauth client
|
||||
// work for them, so we can set the
|
||||
|
||||
// We need to retrieve the original form submitted to the authorizeGEThandler, and
|
||||
// recreate it on the request so that it can be used further by the oauth2 library.
|
||||
// So first fetch all the values from the session.
|
||||
forceLogin, ok := s.Get("force_login").(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
|
||||
return
|
||||
}
|
||||
responseType, ok := s.Get("response_type").(string)
|
||||
if !ok || responseType == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
|
||||
return
|
||||
}
|
||||
clientID, ok := s.Get("client_id").(string)
|
||||
if !ok || clientID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
|
||||
return
|
||||
}
|
||||
redirectURI, ok := s.Get("redirect_uri").(string)
|
||||
if !ok || redirectURI == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
|
||||
return
|
||||
}
|
||||
scope, ok := s.Get("scope").(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
|
||||
return
|
||||
}
|
||||
userID, ok := s.Get("userid").(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"})
|
||||
return
|
||||
}
|
||||
// we're done with the session so we can clear it now
|
||||
s.Clear()
|
||||
|
||||
// now set the values on the request
|
||||
values := url.Values{}
|
||||
values.Set("force_login", forceLogin)
|
||||
values.Set("response_type", responseType)
|
||||
values.Set("client_id", clientID)
|
||||
values.Set("redirect_uri", redirectURI)
|
||||
values.Set("scope", scope)
|
||||
values.Set("userid", userID)
|
||||
c.Request.Form = values
|
||||
l.Tracef("values on request set to %+v", c.Request.Form)
|
||||
|
||||
// and proceed with authorization using the oauth2 library
|
||||
if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
|
||||
// the values in the form into the session.
|
||||
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
|
||||
s := sessions.Default(c)
|
||||
|
||||
// first make sure they've filled out the authorize form with the required values
|
||||
form := &mastotypes.OAuthAuthorize{}
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
return err
|
||||
}
|
||||
l.Tracef("parsed form: %+v", form)
|
||||
|
||||
// these fields are *required* so check 'em
|
||||
if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
|
||||
return errors.New("missing one of: response_type, client_id or redirect_uri")
|
||||
}
|
||||
|
||||
// set default scope to read
|
||||
if form.Scope == "" {
|
||||
form.Scope = "read"
|
||||
}
|
||||
|
||||
// save these values from the form so we can use them elsewhere in the session
|
||||
s.Set("force_login", form.ForceLogin)
|
||||
s.Set("response_type", form.ResponseType)
|
||||
s.Set("client_id", form.ClientID)
|
||||
s.Set("redirect_uri", form.RedirectURI)
|
||||
s.Set("scope", form.Scope)
|
||||
return s.Save()
|
||||
}
|
76
internal/apimodule/auth/middleware.go
Normal file
76
internal/apimodule/auth/middleware.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
)
|
||||
|
||||
// oauthTokenMiddleware checks if the client has presented a valid oauth Bearer token.
|
||||
// If so, it will check the User that the token belongs to, and set that in the context of
|
||||
// the request. Then, it will look up the account for that user, and set that in the request too.
|
||||
// If user or account can't be found, then the handler won't *fail*, in case the server wants to allow
|
||||
// public requests that don't have a Bearer token set (eg., for public instance information and so on).
|
||||
func (m *authModule) oauthTokenMiddleware(c *gin.Context) {
|
||||
l := m.log.WithField("func", "ValidatePassword")
|
||||
l.Trace("entering OauthTokenMiddleware")
|
||||
|
||||
ti, err := m.server.ValidationBearerToken(c.Request)
|
||||
if err != nil {
|
||||
l.Trace("no valid token presented: continuing with unauthenticated request")
|
||||
return
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedToken, ti)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedToken, ti)
|
||||
|
||||
// check for user-level token
|
||||
if uid := ti.GetUserID(); uid != "" {
|
||||
l.Tracef("authenticated user %s with bearer token, scope is %s", uid, ti.GetScope())
|
||||
|
||||
// fetch user's and account for this user id
|
||||
user := &model.User{}
|
||||
if err := m.db.GetByID(uid, user); err != nil || user == nil {
|
||||
l.Warnf("no user found for validated uid %s", uid)
|
||||
return
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedUser, user)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
|
||||
|
||||
acct := &model.Account{}
|
||||
if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
|
||||
l.Warnf("no account found for validated user %s", uid)
|
||||
return
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedAccount, acct)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedAccount, acct)
|
||||
}
|
||||
|
||||
// check for application token
|
||||
if cid := ti.GetClientID(); cid != "" {
|
||||
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
|
||||
app := &model.Application{}
|
||||
if err := m.db.GetWhere("client_id", cid, app); err != nil {
|
||||
l.Tracef("no app found for client %s", cid)
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedApplication, app)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedApplication, app)
|
||||
}
|
||||
}
|
115
internal/apimodule/auth/signin.go
Normal file
115
internal/apimodule/auth/signin.go
Normal file
|
@ -0,0 +1,115 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type login struct {
|
||||
Email string `form:"username"`
|
||||
Password string `form:"password"`
|
||||
}
|
||||
|
||||
// signInGETHandler should be served at https://example.org/auth/sign_in.
|
||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
|
||||
func (m *authModule) signInGETHandler(c *gin.Context) {
|
||||
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
|
||||
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
|
||||
}
|
||||
|
||||
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
|
||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||
// The handler will then redirect to the auth handler served at /auth
|
||||
func (m *authModule) signInPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "SignInPOSTHandler")
|
||||
s := sessions.Default(c)
|
||||
form := &login{}
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
l.Tracef("parsed form: %+v", form)
|
||||
|
||||
userid, err := m.validatePassword(form.Email, form.Password)
|
||||
if err != nil {
|
||||
c.String(http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.Set("userid", userid)
|
||||
if err := s.Save(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
l.Trace("redirecting to auth page")
|
||||
c.Redirect(http.StatusFound, oauthAuthorizePath)
|
||||
}
|
||||
|
||||
// validatePassword takes an email address and a password.
|
||||
// The goal is to authenticate the password against the one for that email
|
||||
// address stored in the database. If OK, we return the userid (a uuid) for that user,
|
||||
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
|
||||
func (m *authModule) validatePassword(email string, password string) (userid string, err error) {
|
||||
l := m.log.WithField("func", "ValidatePassword")
|
||||
|
||||
// make sure an email/password was provided and bail if not
|
||||
if email == "" || password == "" {
|
||||
l.Debug("email or password was not provided")
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
||||
// first we select the user from the database based on email address, bail if no user found for that email
|
||||
gtsUser := &model.User{}
|
||||
|
||||
if err := m.db.GetWhere("email", email, gtsUser); err != nil {
|
||||
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
||||
// make sure a password is actually set and bail if not
|
||||
if gtsUser.EncryptedPassword == "" {
|
||||
l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email)
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
||||
// compare the provided password with the encrypted one from the db, bail if they don't match
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil {
|
||||
l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err)
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
||||
// If we've made it this far the email/password is correct, so we can just return the id of the user.
|
||||
userid = gtsUser.ID
|
||||
l.Tracef("returning (%s, %s)", userid, err)
|
||||
return
|
||||
}
|
||||
|
||||
// incorrectPassword is just a little helper function to use in the ValidatePassword function
|
||||
func incorrectPassword() (string, error) {
|
||||
return "", errors.New("password/email combination was incorrect")
|
||||
}
|
36
internal/apimodule/auth/token.go
Normal file
36
internal/apimodule/auth/token.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
|
||||
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
|
||||
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
|
||||
func (m *authModule) tokenPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "TokenPOSTHandler")
|
||||
l.Trace("entered TokenPOSTHandler")
|
||||
if err := m.server.HandleTokenRequest(c.Writer, c.Request); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
63
internal/apimodule/fileserver/fileserver.go
Normal file
63
internal/apimodule/fileserver/fileserver.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
package fileserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
)
|
||||
|
||||
// fileServer implements the RESTAPIModule interface.
|
||||
// The goal here is to serve requested media files if the gotosocial server is configured to use local storage.
|
||||
type fileServer struct {
|
||||
config *config.Config
|
||||
db db.DB
|
||||
storage storage.Storage
|
||||
log *logrus.Logger
|
||||
storageBase string
|
||||
}
|
||||
|
||||
// New returns a new fileServer module
|
||||
func New(config *config.Config, db db.DB, storage storage.Storage, log *logrus.Logger) apimodule.ClientAPIModule {
|
||||
|
||||
storageBase := config.StorageConfig.BasePath // TODO: do this properly
|
||||
|
||||
return &fileServer{
|
||||
config: config,
|
||||
db: db,
|
||||
storage: storage,
|
||||
log: log,
|
||||
storageBase: storageBase,
|
||||
}
|
||||
}
|
||||
|
||||
// Route satisfies the RESTAPIModule interface
|
||||
func (m *fileServer) Route(s router.Router) error {
|
||||
// s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *fileServer) CreateTables(db db.DB) error {
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Follow{},
|
||||
&model.FollowRequest{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
&model.EmailDomainBlock{},
|
||||
&model.MediaAttachment{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
if err := db.CreateTable(m); err != nil {
|
||||
return fmt.Errorf("error creating table: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
27
internal/apimodule/mock_ClientAPIModule.go
Normal file
27
internal/apimodule/mock_ClientAPIModule.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package apimodule
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
router "github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
)
|
||||
|
||||
// MockClientAPIModule is an autogenerated mock type for the ClientAPIModule type
|
||||
type MockClientAPIModule struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Route provides a mock function with given fields: s
|
||||
func (_m *MockClientAPIModule) Route(s router.Router) error {
|
||||
ret := _m.Called(s)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(router.Router) error); ok {
|
||||
r0 = rf(s)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
47
internal/cache/mock_Cache.go
vendored
Normal file
47
internal/cache/mock_Cache.go
vendored
Normal file
|
@ -0,0 +1,47 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockCache is an autogenerated mock type for the Cache type
|
||||
type MockCache struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Fetch provides a mock function with given fields: k
|
||||
func (_m *MockCache) Fetch(k string) (interface{}, error) {
|
||||
ret := _m.Called(k)
|
||||
|
||||
var r0 interface{}
|
||||
if rf, ok := ret.Get(0).(func(string) interface{}); ok {
|
||||
r0 = rf(k)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(interface{})
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(k)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Store provides a mock function with given fields: k, v
|
||||
func (_m *MockCache) Store(k string, v interface{}) error {
|
||||
ret := _m.Called(k, v)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, interface{}) error); ok {
|
||||
r0 = rf(k, v)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
29
internal/config/accounts.go
Normal file
29
internal/config/accounts.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package config
|
||||
|
||||
// AccountsConfig contains configuration to do with creating accounts, new registrations, and defaults.
|
||||
type AccountsConfig struct {
|
||||
// Do we want people to be able to just submit sign up requests, or do we want invite only?
|
||||
OpenRegistration bool `yaml:"openRegistration"`
|
||||
// Do sign up requests require approval from an admin/moderator?
|
||||
RequireApproval bool `yaml:"requireApproval"`
|
||||
// Do we require a reason for a sign up or is an empty string OK?
|
||||
ReasonRequired bool `yaml:"reasonRequired"`
|
||||
}
|
|
@ -33,26 +33,21 @@ type Config struct {
|
|||
Protocol string `yaml:"protocol"`
|
||||
DBConfig *DBConfig `yaml:"db"`
|
||||
TemplateConfig *TemplateConfig `yaml:"template"`
|
||||
AccountsConfig *AccountsConfig `yaml:"accounts"`
|
||||
MediaConfig *MediaConfig `yaml:"media"`
|
||||
StorageConfig *StorageConfig `yaml:"storage"`
|
||||
}
|
||||
|
||||
// FromFile returns a new config from a file, or an error if something goes amiss.
|
||||
func FromFile(path string) (*Config, error) {
|
||||
c, err := loadFromFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating config: %s", err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Default returns a new config with default values.
|
||||
// Not yet implemented.
|
||||
func Default() *Config {
|
||||
// TODO: find a way of doing this without code repetition, because having to
|
||||
// repeat all values here and elsewhere is annoying and gonna be prone to mistakes.
|
||||
return &Config{
|
||||
DBConfig: &DBConfig{},
|
||||
TemplateConfig: &TemplateConfig{},
|
||||
if path != "" {
|
||||
c, err := loadFromFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating config: %s", err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
return Empty(), nil
|
||||
}
|
||||
|
||||
// Empty just returns an empty config
|
||||
|
@ -60,6 +55,9 @@ func Empty() *Config {
|
|||
return &Config{
|
||||
DBConfig: &DBConfig{},
|
||||
TemplateConfig: &TemplateConfig{},
|
||||
AccountsConfig: &AccountsConfig{},
|
||||
MediaConfig: &MediaConfig{},
|
||||
StorageConfig: &StorageConfig{},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -136,11 +134,51 @@ func (c *Config) ParseCLIFlags(f KeyedFlags) {
|
|||
if c.TemplateConfig.BaseDir == "" || f.IsSet(fn.TemplateBaseDir) {
|
||||
c.TemplateConfig.BaseDir = f.String(fn.TemplateBaseDir)
|
||||
}
|
||||
|
||||
// accounts flags
|
||||
if f.IsSet(fn.AccountsOpenRegistration) {
|
||||
c.AccountsConfig.OpenRegistration = f.Bool(fn.AccountsOpenRegistration)
|
||||
}
|
||||
|
||||
if f.IsSet(fn.AccountsRequireApproval) {
|
||||
c.AccountsConfig.RequireApproval = f.Bool(fn.AccountsRequireApproval)
|
||||
}
|
||||
|
||||
// media flags
|
||||
if c.MediaConfig.MaxImageSize == 0 || f.IsSet(fn.MediaMaxImageSize) {
|
||||
c.MediaConfig.MaxImageSize = f.Int(fn.MediaMaxImageSize)
|
||||
}
|
||||
|
||||
if c.MediaConfig.MaxVideoSize == 0 || f.IsSet(fn.MediaMaxVideoSize) {
|
||||
c.MediaConfig.MaxVideoSize = f.Int(fn.MediaMaxVideoSize)
|
||||
}
|
||||
|
||||
// storage flags
|
||||
if c.StorageConfig.Backend == "" || f.IsSet(fn.StorageBackend) {
|
||||
c.StorageConfig.Backend = f.String(fn.StorageBackend)
|
||||
}
|
||||
|
||||
if c.StorageConfig.BasePath == "" || f.IsSet(fn.StorageBasePath) {
|
||||
c.StorageConfig.BasePath = f.String(fn.StorageBasePath)
|
||||
}
|
||||
|
||||
if c.StorageConfig.ServeProtocol == "" || f.IsSet(fn.StorageServeProtocol) {
|
||||
c.StorageConfig.ServeProtocol = f.String(fn.StorageServeProtocol)
|
||||
}
|
||||
|
||||
if c.StorageConfig.ServeHost == "" || f.IsSet(fn.StorageServeHost) {
|
||||
c.StorageConfig.ServeHost = f.String(fn.StorageServeHost)
|
||||
}
|
||||
|
||||
if c.StorageConfig.ServeBasePath == "" || f.IsSet(fn.StorageServeBasePath) {
|
||||
c.StorageConfig.ServeBasePath = f.String(fn.StorageServeBasePath)
|
||||
}
|
||||
}
|
||||
|
||||
// KeyedFlags is a wrapper for any type that can store keyed flags and give them back.
|
||||
// HINT: This works with a urfave cli context struct ;)
|
||||
type KeyedFlags interface {
|
||||
Bool(k string) bool
|
||||
String(k string) string
|
||||
Int(k string) int
|
||||
IsSet(k string) bool
|
||||
|
@ -154,13 +192,27 @@ type Flags struct {
|
|||
ConfigPath string
|
||||
Host string
|
||||
Protocol string
|
||||
DbType string
|
||||
DbAddress string
|
||||
DbPort string
|
||||
DbUser string
|
||||
DbPassword string
|
||||
DbDatabase string
|
||||
|
||||
DbType string
|
||||
DbAddress string
|
||||
DbPort string
|
||||
DbUser string
|
||||
DbPassword string
|
||||
DbDatabase string
|
||||
|
||||
TemplateBaseDir string
|
||||
|
||||
AccountsOpenRegistration string
|
||||
AccountsRequireApproval string
|
||||
|
||||
MediaMaxImageSize string
|
||||
MediaMaxVideoSize string
|
||||
|
||||
StorageBackend string
|
||||
StorageBasePath string
|
||||
StorageServeProtocol string
|
||||
StorageServeHost string
|
||||
StorageServeBasePath string
|
||||
}
|
||||
|
||||
// GetFlagNames returns a struct containing the names of the various flags used for
|
||||
|
@ -172,13 +224,27 @@ func GetFlagNames() Flags {
|
|||
ConfigPath: "config-path",
|
||||
Host: "host",
|
||||
Protocol: "protocol",
|
||||
DbType: "db-type",
|
||||
DbAddress: "db-address",
|
||||
DbPort: "db-port",
|
||||
DbUser: "db-user",
|
||||
DbPassword: "db-password",
|
||||
DbDatabase: "db-database",
|
||||
|
||||
DbType: "db-type",
|
||||
DbAddress: "db-address",
|
||||
DbPort: "db-port",
|
||||
DbUser: "db-user",
|
||||
DbPassword: "db-password",
|
||||
DbDatabase: "db-database",
|
||||
|
||||
TemplateBaseDir: "template-basedir",
|
||||
|
||||
AccountsOpenRegistration: "accounts-open-registration",
|
||||
AccountsRequireApproval: "accounts-require-approval",
|
||||
|
||||
MediaMaxImageSize: "media-max-image-size",
|
||||
MediaMaxVideoSize: "media-max-video-size",
|
||||
|
||||
StorageBackend: "storage-backend",
|
||||
StorageBasePath: "storage-base-path",
|
||||
StorageServeProtocol: "storage-serve-protocol",
|
||||
StorageServeHost: "storage-serve-host",
|
||||
StorageServeBasePath: "storage-serve-base-path",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -191,12 +257,26 @@ func GetEnvNames() Flags {
|
|||
ConfigPath: "GTS_CONFIG_PATH",
|
||||
Host: "GTS_HOST",
|
||||
Protocol: "GTS_PROTOCOL",
|
||||
DbType: "GTS_DB_TYPE",
|
||||
DbAddress: "GTS_DB_ADDRESS",
|
||||
DbPort: "GTS_DB_PORT",
|
||||
DbUser: "GTS_DB_USER",
|
||||
DbPassword: "GTS_DB_PASSWORD",
|
||||
DbDatabase: "GTS_DB_DATABASE",
|
||||
|
||||
DbType: "GTS_DB_TYPE",
|
||||
DbAddress: "GTS_DB_ADDRESS",
|
||||
DbPort: "GTS_DB_PORT",
|
||||
DbUser: "GTS_DB_USER",
|
||||
DbPassword: "GTS_DB_PASSWORD",
|
||||
DbDatabase: "GTS_DB_DATABASE",
|
||||
|
||||
TemplateBaseDir: "GTS_TEMPLATE_BASEDIR",
|
||||
|
||||
AccountsOpenRegistration: "GTS_ACCOUNTS_OPEN_REGISTRATION",
|
||||
AccountsRequireApproval: "GTS_ACCOUNTS_REQUIRE_APPROVAL",
|
||||
|
||||
MediaMaxImageSize: "GTS_MEDIA_MAX_IMAGE_SIZE",
|
||||
MediaMaxVideoSize: "GTS_MEDIA_MAX_VIDEO_SIZE",
|
||||
|
||||
StorageBackend: "GTS_STORAGE_BACKEND",
|
||||
StorageBasePath: "GTS_STORAGE_BASE_PATH",
|
||||
StorageServeProtocol: "GTS_STORAGE_SERVE_PROTOCOL",
|
||||
StorageServeHost: "GTS_STORAGE_SERVE_HOST",
|
||||
StorageServeBasePath: "GTS_STORAGE_SERVE_BASE_PATH",
|
||||
}
|
||||
}
|
||||
|
|
27
internal/config/media.go
Normal file
27
internal/config/media.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package config
|
||||
|
||||
// MediaConfig contains configuration for receiving and parsing media files and attachments
|
||||
type MediaConfig struct {
|
||||
// Max size of uploaded images in bytes
|
||||
MaxImageSize int `yaml:"maxImageSize"`
|
||||
// Max size of uploaded video in bytes
|
||||
MaxVideoSize int `yaml:"maxVideoSize"`
|
||||
}
|
66
internal/config/mock_KeyedFlags.go
Normal file
66
internal/config/mock_KeyedFlags.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package config
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockKeyedFlags is an autogenerated mock type for the KeyedFlags type
|
||||
type MockKeyedFlags struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Bool provides a mock function with given fields: k
|
||||
func (_m *MockKeyedFlags) Bool(k string) bool {
|
||||
ret := _m.Called(k)
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func(string) bool); ok {
|
||||
r0 = rf(k)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Int provides a mock function with given fields: k
|
||||
func (_m *MockKeyedFlags) Int(k string) int {
|
||||
ret := _m.Called(k)
|
||||
|
||||
var r0 int
|
||||
if rf, ok := ret.Get(0).(func(string) int); ok {
|
||||
r0 = rf(k)
|
||||
} else {
|
||||
r0 = ret.Get(0).(int)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// IsSet provides a mock function with given fields: k
|
||||
func (_m *MockKeyedFlags) IsSet(k string) bool {
|
||||
ret := _m.Called(k)
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func(string) bool); ok {
|
||||
r0 = rf(k)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// String provides a mock function with given fields: k
|
||||
func (_m *MockKeyedFlags) String(k string) string {
|
||||
ret := _m.Called(k)
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func(string) string); ok {
|
||||
r0 = rf(k)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
36
internal/config/storage.go
Normal file
36
internal/config/storage.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package config
|
||||
|
||||
// StorageConfig contains configuration for storage and serving of media files and attachments
|
||||
type StorageConfig struct {
|
||||
// Type of storage backend to use: currently only 'local' is supported.
|
||||
// TODO: add S3 support here.
|
||||
Backend string `yaml:"backend"`
|
||||
|
||||
// The base path for storing things. Should be an already-existing directory.
|
||||
BasePath string `yaml:"basePath"`
|
||||
|
||||
// Protocol to use when *serving* media files from storage
|
||||
ServeProtocol string `yaml:"serveProtocol"`
|
||||
// Host to use when *serving* media files from storage
|
||||
ServeHost string `yaml:"serveHost"`
|
||||
// Base path to use when *serving* media files from storage
|
||||
ServeBasePath string `yaml:"serveBasePath"`
|
||||
}
|
|
@ -21,9 +21,9 @@ package db
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gotosocial/gotosocial/internal/action"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/action"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
// Initialize will initialize the database given in the config for use with GoToSocial
|
||||
|
|
|
@ -21,53 +21,167 @@ package db
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/go-fed/activity/pub"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
)
|
||||
|
||||
const dbTypePostgres string = "POSTGRES"
|
||||
|
||||
// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query.
|
||||
type ErrNoEntries struct{}
|
||||
|
||||
func (e ErrNoEntries) Error() string {
|
||||
return "no entries"
|
||||
}
|
||||
|
||||
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
|
||||
// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
|
||||
// by whatever is returned from the database.
|
||||
type DB interface {
|
||||
// Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions.
|
||||
// See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database
|
||||
Federation() pub.Database
|
||||
|
||||
// CreateTable creates a table for the given interface
|
||||
/*
|
||||
BASIC DB FUNCTIONALITY
|
||||
*/
|
||||
|
||||
// CreateTable creates a table for the given interface.
|
||||
// For implementations that don't use tables, this can just return nil.
|
||||
CreateTable(i interface{}) error
|
||||
|
||||
// DropTable drops the table for the given interface
|
||||
// DropTable drops the table for the given interface.
|
||||
// For implementations that don't use tables, this can just return nil.
|
||||
DropTable(i interface{}) error
|
||||
|
||||
// Stop should stop and close the database connection cleanly, returning an error if this is not possible
|
||||
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
|
||||
// If the database implementation doesn't need to be stopped, this can just return nil.
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// IsHealthy should return nil if the database connection is healthy, or an error if not
|
||||
// IsHealthy should return nil if the database connection is healthy, or an error if not.
|
||||
IsHealthy(ctx context.Context) error
|
||||
|
||||
// GetByID gets one entry by its id.
|
||||
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
|
||||
// for other implementations (for example, in-memory) it might just be the key of a map.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetByID(id string, i interface{}) error
|
||||
|
||||
// GetWhere gets one entry where key = value
|
||||
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
|
||||
// name of the key to select from.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetWhere(key string, value interface{}, i interface{}) error
|
||||
|
||||
// GetAll gets all entries of interface type i
|
||||
// GetAll will try to get all entries of type i.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAll(i interface{}) error
|
||||
|
||||
// Put stores i
|
||||
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
Put(i interface{}) error
|
||||
|
||||
// Update by id updates i with id id
|
||||
// UpdateByID updates i with id id.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
UpdateByID(id string, i interface{}) error
|
||||
|
||||
// Delete by id removes i with id id
|
||||
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
|
||||
UpdateOneByID(id string, key string, value interface{}, i interface{}) error
|
||||
|
||||
// DeleteByID removes i with id id.
|
||||
// If i didn't exist anyway, then no error should be returned.
|
||||
DeleteByID(id string, i interface{}) error
|
||||
|
||||
// Delete where deletes i where key = value
|
||||
// DeleteWhere deletes i where key = value
|
||||
// If i didn't exist anyway, then no error should be returned.
|
||||
DeleteWhere(key string, value interface{}, i interface{}) error
|
||||
|
||||
/*
|
||||
HANDY SHORTCUTS
|
||||
*/
|
||||
|
||||
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
|
||||
// The given account pointer will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAccountByUserID(userID string, account *model.Account) error
|
||||
|
||||
// GetFollowRequestsForAccountID is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
|
||||
// The given slice 'followRequests' will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error
|
||||
|
||||
// GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following.
|
||||
// The given slice 'following' will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetFollowingByAccountID(accountID string, following *[]model.Follow) error
|
||||
|
||||
// GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
|
||||
// The given slice 'followers' will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetFollowersByAccountID(accountID string, followers *[]model.Follow) error
|
||||
|
||||
// GetStatusesByAccountID is a shortcut for the common action of fetching a list of statuses produced by accountID.
|
||||
// The given slice 'statuses' will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetStatusesByAccountID(accountID string, statuses *[]model.Status) error
|
||||
|
||||
// GetStatusesByTimeDescending is a shortcut for getting the most recent statuses. accountID is optional, if not provided
|
||||
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
|
||||
// be very memory intensive so you probably shouldn't do this!
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error
|
||||
|
||||
// GetLastStatusForAccountID simply gets the most recent status by the given account.
|
||||
// The given slice 'status' pointer will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetLastStatusForAccountID(accountID string, status *model.Status) error
|
||||
|
||||
// IsUsernameAvailable checks whether a given username is available on our domain.
|
||||
// Returns an error if the username is already taken, or something went wrong in the db.
|
||||
IsUsernameAvailable(username string) error
|
||||
|
||||
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
|
||||
// Return an error if:
|
||||
// A) the email is already associated with an account
|
||||
// B) we block signups from this email domain
|
||||
// C) something went wrong in the db
|
||||
IsEmailAvailable(email string) error
|
||||
|
||||
// NewSignup creates a new user in the database with the given parameters, with an *unconfirmed* email address.
|
||||
// By the time this function is called, it should be assumed that all the parameters have passed validation!
|
||||
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error)
|
||||
|
||||
// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
|
||||
SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error
|
||||
|
||||
// GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
|
||||
// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
|
||||
GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error
|
||||
|
||||
// GetHeaderForAccountID gets the current header for the given account ID.
|
||||
// The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
|
||||
GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error
|
||||
|
||||
/*
|
||||
USEFUL CONVERSION FUNCTIONS
|
||||
*/
|
||||
|
||||
// AccountToMastoSensitive takes a db model account as a param, and returns a populated mastotype account, or an error
|
||||
// if something goes wrong. The returned account should be ready to serialize on an API level, and may have sensitive fields,
|
||||
// so serve it only to an authorized user who should have permission to see it.
|
||||
AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error)
|
||||
|
||||
// AccountToMastoPublic takes a db model account as a param, and returns a populated mastotype account, or an error
|
||||
// if something goes wrong. The returned account should be ready to serialize on an API level, and may NOT have sensitive fields.
|
||||
// In other words, this is the public record that the server has of an account.
|
||||
AccountToMastoPublic(account *model.Account) (*mastotypes.Account, error)
|
||||
}
|
||||
|
||||
// New returns a new database service that satisfies the DB interface and, by extension,
|
||||
|
|
159
internal/db/federating_db.go
Normal file
159
internal/db/federating_db.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/go-fed/activity/pub"
|
||||
"github.com/go-fed/activity/streams"
|
||||
"github.com/go-fed/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
// FederatingDB uses the underlying DB interface to implement the go-fed pub.Database interface.
|
||||
// It doesn't care what the underlying implementation of the DB interface is, as long as it works.
|
||||
type federatingDB struct {
|
||||
locks *sync.Map
|
||||
db DB
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func newFederatingDB(db DB, config *config.Config) pub.Database {
|
||||
return &federatingDB{
|
||||
locks: new(sync.Map),
|
||||
db: db,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS
|
||||
*/
|
||||
func (f *federatingDB) Lock(ctx context.Context, id *url.URL) error {
|
||||
// Before any other Database methods are called, the relevant `id`
|
||||
// entries are locked to allow for fine-grained concurrency.
|
||||
|
||||
// Strategy: create a new lock, if stored, continue. Otherwise, lock the
|
||||
// existing mutex.
|
||||
mu := &sync.Mutex{}
|
||||
mu.Lock() // Optimistically lock if we do store it.
|
||||
i, loaded := f.locks.LoadOrStore(id.String(), mu)
|
||||
if loaded {
|
||||
mu = i.(*sync.Mutex)
|
||||
mu.Lock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Unlock(ctx context.Context, id *url.URL) error {
|
||||
// Once Go-Fed is done calling Database methods, the relevant `id`
|
||||
// entries are unlocked.
|
||||
|
||||
i, ok := f.locks.Load(id.String())
|
||||
if !ok {
|
||||
return errors.New("missing an id in unlock")
|
||||
}
|
||||
mu := i.(*sync.Mutex)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (owns bool, err error) {
|
||||
return id.Host == f.config.Host, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Exists(ctx context.Context, id *url.URL) (exists bool, err error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
|
||||
t, err := streams.NewTypeResolver()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.Resolve(ctx, asType); err != nil {
|
||||
return err
|
||||
}
|
||||
asType.GetTypeName()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *federatingDB) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||
return nil, nil
|
||||
}
|
21
internal/db/federating_db_test.go
Normal file
21
internal/db/federating_db_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
// TODO: write tests for pgfed
|
363
internal/db/mock_DB.go
Normal file
363
internal/db/mock_DB.go
Normal file
|
@ -0,0 +1,363 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
mastotypes "github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
|
||||
model "github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
|
||||
net "net"
|
||||
|
||||
pub "github.com/go-fed/activity/pub"
|
||||
)
|
||||
|
||||
// MockDB is an autogenerated mock type for the DB type
|
||||
type MockDB struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// AccountToMastoSensitive provides a mock function with given fields: account
|
||||
func (_m *MockDB) AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error) {
|
||||
ret := _m.Called(account)
|
||||
|
||||
var r0 *mastotypes.Account
|
||||
if rf, ok := ret.Get(0).(func(*model.Account) *mastotypes.Account); ok {
|
||||
r0 = rf(account)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*mastotypes.Account)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(*model.Account) error); ok {
|
||||
r1 = rf(account)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// CreateTable provides a mock function with given fields: i
|
||||
func (_m *MockDB) CreateTable(i interface{}) error {
|
||||
ret := _m.Called(i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// DeleteByID provides a mock function with given fields: id, i
|
||||
func (_m *MockDB) DeleteByID(id string, i interface{}) error {
|
||||
ret := _m.Called(id, i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, interface{}) error); ok {
|
||||
r0 = rf(id, i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// DeleteWhere provides a mock function with given fields: key, value, i
|
||||
func (_m *MockDB) DeleteWhere(key string, value interface{}, i interface{}) error {
|
||||
ret := _m.Called(key, value, i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok {
|
||||
r0 = rf(key, value, i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// DropTable provides a mock function with given fields: i
|
||||
func (_m *MockDB) DropTable(i interface{}) error {
|
||||
ret := _m.Called(i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Federation provides a mock function with given fields:
|
||||
func (_m *MockDB) Federation() pub.Database {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 pub.Database
|
||||
if rf, ok := ret.Get(0).(func() pub.Database); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(pub.Database)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetAccountByUserID provides a mock function with given fields: userID, account
|
||||
func (_m *MockDB) GetAccountByUserID(userID string, account *model.Account) error {
|
||||
ret := _m.Called(userID, account)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, *model.Account) error); ok {
|
||||
r0 = rf(userID, account)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetAll provides a mock function with given fields: i
|
||||
func (_m *MockDB) GetAll(i interface{}) error {
|
||||
ret := _m.Called(i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetByID provides a mock function with given fields: id, i
|
||||
func (_m *MockDB) GetByID(id string, i interface{}) error {
|
||||
ret := _m.Called(id, i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, interface{}) error); ok {
|
||||
r0 = rf(id, i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetFollowRequestsForAccountID provides a mock function with given fields: accountID, followRequests
|
||||
func (_m *MockDB) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error {
|
||||
ret := _m.Called(accountID, followRequests)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, *[]model.FollowRequest) error); ok {
|
||||
r0 = rf(accountID, followRequests)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetFollowersByAccountID provides a mock function with given fields: accountID, followers
|
||||
func (_m *MockDB) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
|
||||
ret := _m.Called(accountID, followers)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, *[]model.Follow) error); ok {
|
||||
r0 = rf(accountID, followers)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetFollowingByAccountID provides a mock function with given fields: accountID, following
|
||||
func (_m *MockDB) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
|
||||
ret := _m.Called(accountID, following)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, *[]model.Follow) error); ok {
|
||||
r0 = rf(accountID, following)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetLastStatusForAccountID provides a mock function with given fields: accountID, status
|
||||
func (_m *MockDB) GetLastStatusForAccountID(accountID string, status *model.Status) error {
|
||||
ret := _m.Called(accountID, status)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, *model.Status) error); ok {
|
||||
r0 = rf(accountID, status)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetStatusesByAccountID provides a mock function with given fields: accountID, statuses
|
||||
func (_m *MockDB) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
|
||||
ret := _m.Called(accountID, statuses)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, *[]model.Status) error); ok {
|
||||
r0 = rf(accountID, statuses)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetStatusesByTimeDescending provides a mock function with given fields: accountID, statuses, limit
|
||||
func (_m *MockDB) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
|
||||
ret := _m.Called(accountID, statuses, limit)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, *[]model.Status, int) error); ok {
|
||||
r0 = rf(accountID, statuses, limit)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetWhere provides a mock function with given fields: key, value, i
|
||||
func (_m *MockDB) GetWhere(key string, value interface{}, i interface{}) error {
|
||||
ret := _m.Called(key, value, i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok {
|
||||
r0 = rf(key, value, i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// IsEmailAvailable provides a mock function with given fields: email
|
||||
func (_m *MockDB) IsEmailAvailable(email string) error {
|
||||
ret := _m.Called(email)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string) error); ok {
|
||||
r0 = rf(email)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// IsHealthy provides a mock function with given fields: ctx
|
||||
func (_m *MockDB) IsHealthy(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// IsUsernameAvailable provides a mock function with given fields: username
|
||||
func (_m *MockDB) IsUsernameAvailable(username string) error {
|
||||
ret := _m.Called(username)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string) error); ok {
|
||||
r0 = rf(username)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewSignup provides a mock function with given fields: username, reason, requireApproval, email, password, signUpIP, locale, appID
|
||||
func (_m *MockDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) {
|
||||
ret := _m.Called(username, reason, requireApproval, email, password, signUpIP, locale, appID)
|
||||
|
||||
var r0 *model.User
|
||||
if rf, ok := ret.Get(0).(func(string, string, bool, string, string, net.IP, string, string) *model.User); ok {
|
||||
r0 = rf(username, reason, requireApproval, email, password, signUpIP, locale, appID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*model.User)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string, string, bool, string, string, net.IP, string, string) error); ok {
|
||||
r1 = rf(username, reason, requireApproval, email, password, signUpIP, locale, appID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Put provides a mock function with given fields: i
|
||||
func (_m *MockDB) Put(i interface{}) error {
|
||||
ret := _m.Called(i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Stop provides a mock function with given fields: ctx
|
||||
func (_m *MockDB) Stop(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// UpdateByID provides a mock function with given fields: id, i
|
||||
func (_m *MockDB) UpdateByID(id string, i interface{}) error {
|
||||
ret := _m.Called(id, i)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, interface{}) error); ok {
|
||||
r0 = rf(id, i)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
|
@ -16,12 +16,14 @@
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
|
||||
// Package model contains types used *internally* by GoToSocial and added/removed/selected from the database.
|
||||
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
|
||||
// The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/
|
||||
package gtsmodel
|
||||
// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir).
|
||||
// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
@ -37,20 +39,36 @@ type Account struct {
|
|||
// Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``
|
||||
Username string `pg:",notnull,unique:userdomain"` // username and domain should be unique *with* each other
|
||||
// Domain of the account, will be empty if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username.
|
||||
Domain string `pg:",unique:userdomain"` // username and domain
|
||||
Domain string `pg:",unique:userdomain"` // username and domain should be unique *with* each other
|
||||
|
||||
/*
|
||||
ACCOUNT METADATA
|
||||
*/
|
||||
|
||||
// Avatar image for this account
|
||||
Avatar
|
||||
// Header image for this account
|
||||
Header
|
||||
// File name of the avatar on local storage
|
||||
AvatarFileName string
|
||||
// Gif? png? jpeg?
|
||||
AvatarContentType string
|
||||
// Size of the avatar in bytes
|
||||
AvatarFileSize int
|
||||
// When was the avatar last updated?
|
||||
AvatarUpdatedAt time.Time `pg:"type:timestamp"`
|
||||
// Where can the avatar be retrieved?
|
||||
AvatarRemoteURL *url.URL `pg:"type:text"`
|
||||
// File name of the header on local storage
|
||||
HeaderFileName string
|
||||
// Gif? png? jpeg?
|
||||
HeaderContentType string
|
||||
// Size of the header in bytes
|
||||
HeaderFileSize int
|
||||
// When was the header last updated?
|
||||
HeaderUpdatedAt time.Time `pg:"type:timestamp"`
|
||||
// Where can the header be retrieved?
|
||||
HeaderRemoteURL *url.URL `pg:"type:text"`
|
||||
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.
|
||||
DisplayName string
|
||||
// a key/value map of fields that this account has added to their profile
|
||||
Fields map[string]string
|
||||
Fields []Field
|
||||
// A note that this account has on their profile (ie., the account's bio/description of themselves)
|
||||
Note string
|
||||
// Is this a memorial account, ie., has the user passed away?
|
||||
|
@ -63,15 +81,25 @@ type Account struct {
|
|||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// When should this account function until
|
||||
SubscriptionExpiresAt time.Time `pg:"type:timestamp"`
|
||||
// Does this account identify itself as a bot?
|
||||
Bot bool
|
||||
// What reason was given for signing up when this account was created?
|
||||
Reason string
|
||||
|
||||
/*
|
||||
PRIVACY SETTINGS
|
||||
USER AND PRIVACY PREFERENCES
|
||||
*/
|
||||
|
||||
// Does this account need an approval for new followers?
|
||||
Locked bool
|
||||
// Should this account be shown in the instance's profile directory?
|
||||
Discoverable bool
|
||||
// Default post privacy for this account
|
||||
Privacy string
|
||||
// Set posts from this account to sensitive by default?
|
||||
Sensitive bool
|
||||
// What language does this account post in?
|
||||
Language string
|
||||
|
||||
/*
|
||||
ACTIVITYPUB THINGS
|
||||
|
@ -81,8 +109,6 @@ type Account struct {
|
|||
URI string `pg:",unique"`
|
||||
// At which URL can we see the user account in a web browser?
|
||||
URL string `pg:",unique"`
|
||||
// RemoteURL where this account is located. Will be empty if this is a local account.
|
||||
RemoteURL string `pg:",unique"`
|
||||
// Last time this account was located using the webfinger API.
|
||||
LastWebfingeredAt time.Time `pg:"type:timestamp"`
|
||||
// Address of this account's activitypub inbox, for sending activity to
|
||||
|
@ -106,9 +132,9 @@ type Account struct {
|
|||
|
||||
Secret string
|
||||
// Privatekey for validating activitypub requests, will obviously only be defined for local accounts
|
||||
PrivateKey string
|
||||
PrivateKey *rsa.PrivateKey
|
||||
// Publickey for encoding activitypub requests, will be defined for both local and remote accounts
|
||||
PublicKey string
|
||||
PublicKey *rsa.PublicKey
|
||||
|
||||
/*
|
||||
ADMIN FIELDS
|
||||
|
@ -128,28 +154,11 @@ type Account struct {
|
|||
SuspensionOrigin int
|
||||
}
|
||||
|
||||
// Avatar represents the avatar for the account for display purposes
|
||||
type Avatar struct {
|
||||
// File name of the avatar on local storage
|
||||
AvatarFileName string
|
||||
// Gif? png? jpeg?
|
||||
AvatarContentType string
|
||||
AvatarFileSize int
|
||||
AvatarUpdatedAt *time.Time `pg:"type:timestamp"`
|
||||
// Where can we retrieve the avatar?
|
||||
AvatarRemoteURL *url.URL `pg:"type:text"`
|
||||
AvatarStorageSchemaVersion int
|
||||
}
|
||||
|
||||
// Header represents the header of the account for display purposes
|
||||
type Header struct {
|
||||
// File name of the header on local storage
|
||||
HeaderFileName string
|
||||
// Gif? png? jpeg?
|
||||
HeaderContentType string
|
||||
HeaderFileSize int
|
||||
HeaderUpdatedAt *time.Time `pg:"type:timestamp"`
|
||||
// Where can we retrieve the header?
|
||||
HeaderRemoteURL *url.URL `pg:"type:text"`
|
||||
HeaderStorageSchemaVersion int
|
||||
// Field represents a key value field on an account, for things like pronouns, website, etc.
|
||||
// VerifiedAt is optional, to be used only if Value is a URL to a webpage that contains the
|
||||
// username of the user.
|
||||
type Field struct {
|
||||
Name string
|
||||
Value string
|
||||
VerifiedAt time.Time `pg:"type:timestamp"`
|
||||
}
|
|
@ -16,9 +16,9 @@
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package gtsmodel
|
||||
package model
|
||||
|
||||
import "github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||
import "github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
|
||||
// Application represents an application that can perform actions on behalf of a user.
|
||||
// It is used to authorize tokens etc, and is associated with an oauth client id in the database.
|
||||
|
@ -41,8 +41,8 @@ type Application struct {
|
|||
VapidKey string
|
||||
}
|
||||
|
||||
// ToMastotype returns this application as a mastodon api type, ready for serialization
|
||||
func (a *Application) ToMastotype() *mastotypes.Application {
|
||||
// ToMasto returns this application as a mastodon api type, ready for serialization
|
||||
func (a *Application) ToMasto() *mastotypes.Application {
|
||||
return &mastotypes.Application{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
47
internal/db/model/domainblock.go
Normal file
47
internal/db/model/domainblock.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// DomainBlock represents a federation block against a particular domain, of varying severity.
|
||||
type DomainBlock struct {
|
||||
// ID of this block in the database
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
|
||||
// Domain to block. If ANY PART of the candidate domain contains this string, it will be blocked.
|
||||
// For example: 'example.org' also blocks 'gts.example.org'. '.com' blocks *any* '.com' domains.
|
||||
// TODO: implement wildcards here
|
||||
Domain string `pg:",notnull"`
|
||||
// When was this block created
|
||||
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// When was this block updated
|
||||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// Account ID of the creator of this block
|
||||
CreatedByAccountID string `pg:",notnull"`
|
||||
// TODO: define this
|
||||
Severity int
|
||||
// Reject media from this domain?
|
||||
RejectMedia bool
|
||||
// Reject reports from this domain?
|
||||
RejectReports bool
|
||||
// Private comment on this block, viewable to admins
|
||||
PrivateComment string
|
||||
// Public comment on this block, viewable (optionally) by everyone
|
||||
PublicComment string
|
||||
}
|
35
internal/db/model/emaildomainblock.go
Normal file
35
internal/db/model/emaildomainblock.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// EmailDomainBlock represents a domain that the server should automatically reject sign-up requests from.
|
||||
type EmailDomainBlock struct {
|
||||
// ID of this block in the database
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
|
||||
// Email domain to block. Eg. 'gmail.com' or 'hotmail.com'
|
||||
Domain string `pg:",notnull"`
|
||||
// When was this block created
|
||||
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// When was this block updated
|
||||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// Account ID of the creator of this block
|
||||
CreatedByAccountID string `pg:",notnull"`
|
||||
}
|
41
internal/db/model/follow.go
Normal file
41
internal/db/model/follow.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// Follow represents one account following another, and the metadata around that follow.
|
||||
type Follow struct {
|
||||
// id of this follow in the database
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
|
||||
// When was this follow created?
|
||||
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// When was this follow last updated?
|
||||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// Who does this follow belong to?
|
||||
AccountID string `pg:",unique:srctarget,notnull"`
|
||||
// Who does AccountID follow?
|
||||
TargetAccountID string `pg:",unique:srctarget,notnull"`
|
||||
// Does this follow also want to see reblogs and not just posts?
|
||||
ShowReblogs bool `pg:"default:true"`
|
||||
// What is the activitypub URI of this follow?
|
||||
URI string `pg:",unique"`
|
||||
// does the following account want to be notified when the followed account posts?
|
||||
Notify bool
|
||||
}
|
41
internal/db/model/followrequest.go
Normal file
41
internal/db/model/followrequest.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// FollowRequest represents one account requesting to follow another, and the metadata around that request.
|
||||
type FollowRequest struct {
|
||||
// id of this follow request in the database
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
|
||||
// When was this follow request created?
|
||||
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// When was this follow request last updated?
|
||||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// Who does this follow request originate from?
|
||||
AccountID string `pg:",unique:srctarget,notnull"`
|
||||
// Who is the target of this follow request?
|
||||
TargetAccountID string `pg:",unique:srctarget,notnull"`
|
||||
// Does this follow also want to see reblogs and not just posts?
|
||||
ShowReblogs bool `pg:"default:true"`
|
||||
// What is the activitypub URI of this follow request?
|
||||
URI string `pg:",unique"`
|
||||
// does the following account want to be notified when the followed account posts?
|
||||
Notify bool
|
||||
}
|
136
internal/db/model/mediaattachment.go
Normal file
136
internal/db/model/mediaattachment.go
Normal file
|
@ -0,0 +1,136 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// MediaAttachment represents a user-uploaded media attachment: an image/video/audio/gif that is
|
||||
// somewhere in storage and that can be retrieved and served by the router.
|
||||
type MediaAttachment struct {
|
||||
// ID of the attachment in the database
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
|
||||
// ID of the status to which this is attached
|
||||
StatusID string
|
||||
// Where can the attachment be retrieved on a remote server
|
||||
RemoteURL string
|
||||
// When was the attachment created
|
||||
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// When was the attachment last updated
|
||||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// Type of file (image/gif/audio/video)
|
||||
Type FileType `pg:",notnull"`
|
||||
// Metadata about the file
|
||||
FileMeta FileMeta
|
||||
// To which account does this attachment belong
|
||||
AccountID string `pg:",notnull"`
|
||||
// Description of the attachment (for screenreaders)
|
||||
Description string
|
||||
// To which scheduled status does this attachment belong
|
||||
ScheduledStatusID string
|
||||
// What is the generated blurhash of this attachment
|
||||
Blurhash string
|
||||
// What is the processing status of this attachment
|
||||
Processing ProcessingStatus
|
||||
// metadata for the whole file
|
||||
File File
|
||||
// small image thumbnail derived from a larger image, video, or audio file.
|
||||
Thumbnail Thumbnail
|
||||
// Is this attachment being used as an avatar?
|
||||
Avatar bool
|
||||
// Is this attachment being used as a header?
|
||||
Header bool
|
||||
}
|
||||
|
||||
// File refers to the metadata for the whole file
|
||||
type File struct {
|
||||
// What is the path of the file in storage.
|
||||
Path string
|
||||
// What is the MIME content type of the file.
|
||||
ContentType string
|
||||
// What is the size of the file in bytes.
|
||||
FileSize int
|
||||
// When was the file last updated.
|
||||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
}
|
||||
|
||||
// Thumbnail refers to a small image thumbnail derived from a larger image, video, or audio file.
|
||||
type Thumbnail struct {
|
||||
// What is the path of the file in storage
|
||||
Path string
|
||||
// What is the MIME content type of the file.
|
||||
ContentType string
|
||||
// What is the size of the file in bytes
|
||||
FileSize int
|
||||
// When was the file last updated
|
||||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||
// What is the remote URL of the thumbnail
|
||||
RemoteURL string
|
||||
}
|
||||
|
||||
// ProcessingStatus refers to how far along in the processing stage the attachment is.
|
||||
type ProcessingStatus int
|
||||
|
||||
const (
|
||||
// ProcessingStatusReceived: the attachment has been received and is awaiting processing. No thumbnail available yet.
|
||||
ProcessingStatusReceived ProcessingStatus = 0
|
||||
// ProcessingStatusProcessing: the attachment is currently being processed. Thumbnail is available but full media is not.
|
||||
ProcessingStatusProcessing ProcessingStatus = 1
|
||||
// ProcessingStatusProcessed: the attachment has been fully processed and is ready to be served.
|
||||
ProcessingStatusProcessed ProcessingStatus = 2
|
||||
// ProcessingStatusError: something went wrong processing the attachment and it won't be tried again--these can be deleted.
|
||||
ProcessingStatusError ProcessingStatus = 666
|
||||
)
|
||||
|
||||
// FileType refers to the file type of the media attaachment.
|
||||
type FileType string
|
||||
|
||||
const (
|
||||
// FileTypeImage is for jpegs and pngs
|
||||
FileTypeImage FileType = "image"
|
||||
// FileTypeGif is for native gifs and soundless videos that have been converted to gifs
|
||||
FileTypeGif FileType = "gif"
|
||||
// FileTypeAudio is for audio-only files (no video)
|
||||
FileTypeAudio FileType = "audio"
|
||||
// FileTypeVideo is for files with audio + visual
|
||||
FileTypeVideo FileType = "video"
|
||||
)
|
||||
|
||||
// FileMeta describes metadata about the actual contents of the file.
|
||||
type FileMeta struct {
|
||||
Original Original
|
||||
Small Small
|
||||
}
|
||||
|
||||
// Small implements SmallMeta and can be used for a thumbnail of any media type
|
||||
type Small struct {
|
||||
Width int
|
||||
Height int
|
||||
Size int
|
||||
Aspect float64
|
||||
}
|
||||
|
||||
// ImageOriginal implements OriginalMeta for still images
|
||||
type Original struct {
|
||||
Width int
|
||||
Height int
|
||||
Size int
|
||||
Aspect float64
|
||||
}
|
|
@ -16,7 +16,7 @@
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package gtsmodel
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
|
@ -16,7 +16,7 @@
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package gtsmodel
|
||||
package model
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
@ -33,7 +33,7 @@ type User struct {
|
|||
// id of this user in the local database; the end-user will never need to know this, it's strictly internal
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
|
||||
// confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported
|
||||
Email string `pg:",notnull,unique"`
|
||||
Email string `pg:"default:null,unique"`
|
||||
// The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet)
|
||||
AccountID string `pg:"default:'',notnull,unique"`
|
||||
// The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables
|
|
@ -1,137 +0,0 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/go-fed/activity/pub"
|
||||
"github.com/go-fed/activity/streams"
|
||||
"github.com/go-fed/activity/streams/vocab"
|
||||
"github.com/go-pg/pg/v10"
|
||||
)
|
||||
|
||||
type postgresFederation struct {
|
||||
locks *sync.Map
|
||||
conn *pg.DB
|
||||
}
|
||||
|
||||
func newPostgresFederation(conn *pg.DB) pub.Database {
|
||||
return &postgresFederation{
|
||||
locks: new(sync.Map),
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS
|
||||
*/
|
||||
func (pf *postgresFederation) Lock(ctx context.Context, id *url.URL) error {
|
||||
// Before any other Database methods are called, the relevant `id`
|
||||
// entries are locked to allow for fine-grained concurrency.
|
||||
|
||||
// Strategy: create a new lock, if stored, continue. Otherwise, lock the
|
||||
// existing mutex.
|
||||
mu := &sync.Mutex{}
|
||||
mu.Lock() // Optimistically lock if we do store it.
|
||||
i, loaded := pf.locks.LoadOrStore(id.String(), mu)
|
||||
if loaded {
|
||||
mu = i.(*sync.Mutex)
|
||||
mu.Lock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Unlock(ctx context.Context, id *url.URL) error {
|
||||
// Once Go-Fed is done calling Database methods, the relevant `id`
|
||||
// entries are unlocked.
|
||||
|
||||
i, ok := pf.locks.Load(id.String())
|
||||
if !ok {
|
||||
return errors.New("missing an id in unlock")
|
||||
}
|
||||
mu := i.(*sync.Mutex)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Owns(ctx context.Context, id *url.URL) (owns bool, err error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Exists(ctx context.Context, id *url.URL) (exists bool, err error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Create(ctx context.Context, asType vocab.Type) error {
|
||||
t, err := streams.NewTypeResolver()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.Resolve(ctx, asType); err != nil {
|
||||
return err
|
||||
}
|
||||
asType.GetTypeName()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Update(ctx context.Context, asType vocab.Type) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Delete(ctx context.Context, id *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (pf *postgresFederation) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||
return nil, nil
|
||||
}
|
|
@ -20,8 +20,12 @@ package db
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -30,14 +34,17 @@ import (
|
|||
"github.com/go-pg/pg/extra/pgdebug"
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// postgresService satisfies the DB interface
|
||||
type postgresService struct {
|
||||
config *config.DBConfig
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Entry
|
||||
cancel context.CancelFunc
|
||||
|
@ -46,7 +53,7 @@ type postgresService struct {
|
|||
|
||||
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
|
||||
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
|
||||
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) {
|
||||
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) {
|
||||
opts, err := derivePGOptions(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create postgres service: %s", err)
|
||||
|
@ -98,18 +105,18 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry
|
|||
return nil, errors.New("db connection timeout")
|
||||
}
|
||||
|
||||
// we can confidently return this useable postgres service now
|
||||
return &postgresService{
|
||||
config: c.DBConfig,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
federationDB: newPostgresFederation(conn),
|
||||
}, nil
|
||||
}
|
||||
ps := &postgresService{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
func (ps *postgresService) Federation() pub.Database {
|
||||
return ps.federationDB
|
||||
federatingDB := newFederatingDB(ps, c)
|
||||
ps.federationDB = federatingDB
|
||||
|
||||
// we can confidently return this useable postgres service now
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -168,9 +175,29 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
|
|||
}
|
||||
|
||||
/*
|
||||
EXTRA FUNCTIONS
|
||||
FEDERATION FUNCTIONALITY
|
||||
*/
|
||||
|
||||
func (ps *postgresService) Federation() pub.Database {
|
||||
return ps.federationDB
|
||||
}
|
||||
|
||||
/*
|
||||
BASIC DB FUNCTIONALITY
|
||||
*/
|
||||
|
||||
func (ps *postgresService) CreateTable(i interface{}) error {
|
||||
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
|
||||
IfNotExists: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (ps *postgresService) DropTable(i interface{}) error {
|
||||
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
|
||||
IfExists: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (ps *postgresService) Stop(ctx context.Context) error {
|
||||
ps.log.Info("closing db connection")
|
||||
if err := ps.conn.Close(); err != nil {
|
||||
|
@ -181,11 +208,15 @@ func (ps *postgresService) Stop(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) IsHealthy(ctx context.Context) error {
|
||||
return ps.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
func (ps *postgresService) CreateSchema(ctx context.Context) error {
|
||||
models := []interface{}{
|
||||
(*gtsmodel.Account)(nil),
|
||||
(*gtsmodel.Status)(nil),
|
||||
(*gtsmodel.User)(nil),
|
||||
(*model.Account)(nil),
|
||||
(*model.Status)(nil),
|
||||
(*model.User)(nil),
|
||||
}
|
||||
ps.log.Info("creating db schema")
|
||||
|
||||
|
@ -202,32 +233,35 @@ func (ps *postgresService) CreateSchema(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) IsHealthy(ctx context.Context) error {
|
||||
return ps.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
func (ps *postgresService) CreateTable(i interface{}) error {
|
||||
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
|
||||
IfNotExists: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (ps *postgresService) DropTable(i interface{}) error {
|
||||
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
|
||||
IfExists: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetByID(id string, i interface{}) error {
|
||||
return ps.conn.Model(i).Where("id = ?", id).Select()
|
||||
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
|
||||
return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select()
|
||||
if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAll(i interface{}) error {
|
||||
return ps.conn.Model(i).Select()
|
||||
if err := ps.conn.Model(i).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) Put(i interface{}) error {
|
||||
|
@ -236,16 +270,393 @@ func (ps *postgresService) Put(i interface{}) error {
|
|||
}
|
||||
|
||||
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
|
||||
_, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert()
|
||||
if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
|
||||
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
||||
_, err := ps.conn.Model(i).Where("id = ?", id).Delete()
|
||||
return err
|
||||
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
|
||||
_, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete()
|
||||
if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
HANDY SHORTCUTS
|
||||
*/
|
||||
|
||||
func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error {
|
||||
user := &model.User{
|
||||
ID: userID,
|
||||
}
|
||||
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error {
|
||||
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
|
||||
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
|
||||
if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
|
||||
if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
|
||||
q := ps.conn.Model(statuses).Order("created_at DESC")
|
||||
if limit != 0 {
|
||||
q = q.Limit(limit)
|
||||
}
|
||||
if accountID != "" {
|
||||
q = q.Where("account_id = ?", accountID)
|
||||
}
|
||||
if err := q.Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error {
|
||||
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (ps *postgresService) IsUsernameAvailable(username string) error {
|
||||
// if no error we fail because it means we found something
|
||||
// if error but it's not pg.ErrNoRows then we fail
|
||||
// if err is pg.ErrNoRows we're good, we found nothing so continue
|
||||
if err := ps.conn.Model(&model.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
|
||||
return fmt.Errorf("username %s already in use", username)
|
||||
} else if err != pg.ErrNoRows {
|
||||
return fmt.Errorf("db error: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) IsEmailAvailable(email string) error {
|
||||
// parse the domain from the email
|
||||
m, err := mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing email address %s: %s", email, err)
|
||||
}
|
||||
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
|
||||
|
||||
// check if the email domain is blocked
|
||||
if err := ps.conn.Model(&model.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
|
||||
// fail because we found something
|
||||
return fmt.Errorf("email domain %s is blocked", domain)
|
||||
} else if err != pg.ErrNoRows {
|
||||
// fail because we got an unexpected error
|
||||
return fmt.Errorf("db error: %s", err)
|
||||
}
|
||||
|
||||
// check if this email is associated with a user already
|
||||
if err := ps.conn.Model(&model.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
|
||||
// fail because we found something
|
||||
return fmt.Errorf("email %s already in use", email)
|
||||
} else if err != pg.ErrNoRows {
|
||||
// fail because we got an unexpected error
|
||||
return fmt.Errorf("db error: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
ps.log.Errorf("error creating new rsa key: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uris := util.GenerateURIs(username, ps.config.Protocol, ps.config.Host)
|
||||
|
||||
a := &model.Account{
|
||||
Username: username,
|
||||
DisplayName: username,
|
||||
Reason: reason,
|
||||
URL: uris.UserURL,
|
||||
PrivateKey: key,
|
||||
PublicKey: &key.PublicKey,
|
||||
ActorType: "Person",
|
||||
URI: uris.UserURI,
|
||||
InboxURL: uris.InboxURL,
|
||||
OutboxURL: uris.OutboxURL,
|
||||
FollowersURL: uris.FollowersURL,
|
||||
FeaturedCollectionURL: uris.CollectionURL,
|
||||
}
|
||||
if _, err = ps.conn.Model(a).Insert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error hashing password: %s", err)
|
||||
}
|
||||
u := &model.User{
|
||||
AccountID: a.ID,
|
||||
EncryptedPassword: string(pw),
|
||||
SignUpIP: signUpIP,
|
||||
Locale: locale,
|
||||
UnconfirmedEmail: email,
|
||||
CreatedByApplicationID: appID,
|
||||
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
|
||||
}
|
||||
if _, err = ps.conn.Model(u).Insert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error {
|
||||
_, err := ps.conn.Model(mediaAttachment).Insert()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error {
|
||||
if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error {
|
||||
if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
CONVERSION FUNCTIONS
|
||||
*/
|
||||
|
||||
// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API.
|
||||
// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here:
|
||||
// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user
|
||||
// that the account actually belongs to.
|
||||
func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
|
||||
// we can build this sensitive account easily by first getting the public account....
|
||||
mastoAccount, err := ps.AccountToMastoPublic(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// then adding the Source object to it...
|
||||
|
||||
// check pending follow requests aimed at this account
|
||||
fr := []model.FollowRequest{}
|
||||
if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil {
|
||||
if _, ok := err.(ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting follow requests: %s", err)
|
||||
}
|
||||
}
|
||||
var frc int
|
||||
if fr != nil {
|
||||
frc = len(fr)
|
||||
}
|
||||
|
||||
mastoAccount.Source = &mastotypes.Source{
|
||||
Privacy: a.Privacy,
|
||||
Sensitive: a.Sensitive,
|
||||
Language: a.Language,
|
||||
Note: a.Note,
|
||||
Fields: mastoAccount.Fields,
|
||||
FollowRequestsCount: frc,
|
||||
}
|
||||
|
||||
return mastoAccount, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) AccountToMastoPublic(a *model.Account) (*mastotypes.Account, error) {
|
||||
// count followers
|
||||
followers := []model.Follow{}
|
||||
if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
|
||||
if _, ok := err.(ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting followers: %s", err)
|
||||
}
|
||||
}
|
||||
var followersCount int
|
||||
if followers != nil {
|
||||
followersCount = len(followers)
|
||||
}
|
||||
|
||||
// count following
|
||||
following := []model.Follow{}
|
||||
if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil {
|
||||
if _, ok := err.(ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting following: %s", err)
|
||||
}
|
||||
}
|
||||
var followingCount int
|
||||
if following != nil {
|
||||
followingCount = len(following)
|
||||
}
|
||||
|
||||
// count statuses
|
||||
statuses := []model.Status{}
|
||||
if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil {
|
||||
if _, ok := err.(ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting last statuses: %s", err)
|
||||
}
|
||||
}
|
||||
var statusesCount int
|
||||
if statuses != nil {
|
||||
statusesCount = len(statuses)
|
||||
}
|
||||
|
||||
// check when the last status was
|
||||
lastStatus := &model.Status{}
|
||||
if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil {
|
||||
if _, ok := err.(ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting last status: %s", err)
|
||||
}
|
||||
}
|
||||
var lastStatusAt string
|
||||
if lastStatus != nil {
|
||||
lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// build the avatar and header URLs
|
||||
avi := &model.MediaAttachment{}
|
||||
if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil {
|
||||
if _, ok := err.(ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting avatar: %s", err)
|
||||
}
|
||||
}
|
||||
aviURL := avi.File.Path
|
||||
aviURLStatic := avi.Thumbnail.Path
|
||||
|
||||
header := &model.MediaAttachment{}
|
||||
if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil {
|
||||
if _, ok := err.(ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting header: %s", err)
|
||||
}
|
||||
}
|
||||
headerURL := header.File.Path
|
||||
headerURLStatic := header.Thumbnail.Path
|
||||
|
||||
// get the fields set on this account
|
||||
fields := []mastotypes.Field{}
|
||||
for _, f := range a.Fields {
|
||||
mField := mastotypes.Field{
|
||||
Name: f.Name,
|
||||
Value: f.Value,
|
||||
}
|
||||
if !f.VerifiedAt.IsZero() {
|
||||
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
|
||||
}
|
||||
fields = append(fields, mField)
|
||||
}
|
||||
|
||||
var acct string
|
||||
if a.Domain != "" {
|
||||
// this is a remote user
|
||||
acct = fmt.Sprintf("%s@%s", a.Username, a.Domain)
|
||||
} else {
|
||||
// this is a local user
|
||||
acct = a.Username
|
||||
}
|
||||
|
||||
return &mastotypes.Account{
|
||||
ID: a.ID,
|
||||
Username: a.Username,
|
||||
Acct: acct,
|
||||
DisplayName: a.DisplayName,
|
||||
Locked: a.Locked,
|
||||
Bot: a.Bot,
|
||||
CreatedAt: a.CreatedAt.Format(time.RFC3339),
|
||||
Note: a.Note,
|
||||
URL: a.URL,
|
||||
Avatar: aviURL,
|
||||
AvatarStatic: aviURLStatic,
|
||||
Header: headerURL,
|
||||
HeaderStatic: headerURLStatic,
|
||||
FollowersCount: followersCount,
|
||||
FollowingCount: followingCount,
|
||||
StatusesCount: statusesCount,
|
||||
LastStatusAt: lastStatusAt,
|
||||
Emojis: nil, // TODO: implement this
|
||||
Fields: fields,
|
||||
}, nil
|
||||
}
|
||||
|
|
21
internal/db/pg_test.go
Normal file
21
internal/db/pg_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
// TODO: write tests for postgres
|
96
internal/distributor/distributor.go
Normal file
96
internal/distributor/distributor.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package distributor
|
||||
|
||||
import (
|
||||
"github.com/go-fed/activity/pub"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Distributor should be passed to api modules (see internal/apimodule/...). It is used for
|
||||
// passing messages back and forth from the client API and the federating interface, via channels.
|
||||
// It also contains logic for filtering which messages should end up where.
|
||||
// It is designed to be used asynchronously: the client API and the federating API should just be able to
|
||||
// fire messages into the distributor and not wait for a reply before proceeding with other work. This allows
|
||||
// for clean distribution of messages without slowing down the client API and harming the user experience.
|
||||
type Distributor interface {
|
||||
// ClientAPIIn returns a channel for accepting messages that come from the gts client API.
|
||||
ClientAPIIn() chan interface{}
|
||||
// ClientAPIOut returns a channel for putting in messages that need to go to the gts client API.
|
||||
ClientAPIOut() chan interface{}
|
||||
// Start starts the Distributor, reading from its channels and passing messages back and forth.
|
||||
Start() error
|
||||
// Stop stops the distributor cleanly, finishing handling any remaining messages before closing down.
|
||||
Stop() error
|
||||
}
|
||||
|
||||
// distributor just implements the Distributor interface
|
||||
type distributor struct {
|
||||
federator pub.FederatingActor
|
||||
clientAPIIn chan interface{}
|
||||
clientAPIOut chan interface{}
|
||||
stop chan interface{}
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
// New returns a new Distributor that uses the given federator and logger
|
||||
func New(federator pub.FederatingActor, log *logrus.Logger) Distributor {
|
||||
return &distributor{
|
||||
federator: federator,
|
||||
clientAPIIn: make(chan interface{}, 100),
|
||||
clientAPIOut: make(chan interface{}, 100),
|
||||
stop: make(chan interface{}),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// ClientAPIIn returns a channel for accepting messages that come from the gts client API.
|
||||
func (d *distributor) ClientAPIIn() chan interface{} {
|
||||
return d.clientAPIIn
|
||||
}
|
||||
|
||||
// ClientAPIOut returns a channel for putting in messages that need to go to the gts client API.
|
||||
func (d *distributor) ClientAPIOut() chan interface{} {
|
||||
return d.clientAPIOut
|
||||
}
|
||||
|
||||
// Start starts the Distributor, reading from its channels and passing messages back and forth.
|
||||
func (d *distributor) Start() error {
|
||||
go func() {
|
||||
DistLoop:
|
||||
for {
|
||||
select {
|
||||
case clientMsgIn := <-d.clientAPIIn:
|
||||
d.log.Infof("received clientMsgIn: %+v", clientMsgIn)
|
||||
case clientMsgOut := <-d.clientAPIOut:
|
||||
d.log.Infof("received clientMsgOut: %+v", clientMsgOut)
|
||||
case <-d.stop:
|
||||
break DistLoop
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the distributor cleanly, finishing handling any remaining messages before closing down.
|
||||
// TODO: empty message buffer properly before stopping otherwise we'll lose federating messages.
|
||||
func (d *distributor) Stop() error {
|
||||
close(d.stop)
|
||||
return nil
|
||||
}
|
70
internal/distributor/mock_Distributor.go
Normal file
70
internal/distributor/mock_Distributor.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package distributor
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockDistributor is an autogenerated mock type for the Distributor type
|
||||
type MockDistributor struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// ClientAPIIn provides a mock function with given fields:
|
||||
func (_m *MockDistributor) ClientAPIIn() chan interface{} {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 chan interface{}
|
||||
if rf, ok := ret.Get(0).(func() chan interface{}); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(chan interface{})
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// ClientAPIOut provides a mock function with given fields:
|
||||
func (_m *MockDistributor) ClientAPIOut() chan interface{} {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 chan interface{}
|
||||
if rf, ok := ret.Get(0).(func() chan interface{}); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(chan interface{})
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields:
|
||||
func (_m *MockDistributor) Start() error {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Stop provides a mock function with given fields:
|
||||
func (_m *MockDistributor) Stop() error {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
|
@ -27,88 +27,93 @@ import (
|
|||
|
||||
"github.com/go-fed/activity/pub"
|
||||
"github.com/go-fed/activity/streams/vocab"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
)
|
||||
|
||||
// New returns a go-fed compatible federating actor
|
||||
func New(db db.DB) pub.FederatingActor {
|
||||
fa := &API{}
|
||||
return pub.NewFederatingActor(fa, fa, db.Federation(), fa)
|
||||
func New(db db.DB, log *logrus.Logger) pub.FederatingActor {
|
||||
f := &Federator{
|
||||
db: db,
|
||||
}
|
||||
return pub.NewFederatingActor(f, f, db.Federation(), f)
|
||||
}
|
||||
|
||||
// API implements several go-fed interfaces in one convenient location
|
||||
type API struct {
|
||||
// Federator implements several go-fed interfaces in one convenient location
|
||||
type Federator struct {
|
||||
db db.DB
|
||||
}
|
||||
|
||||
// AuthenticateGetInbox determines whether the request is for a GET call to the Actor's Inbox.
|
||||
func (fa *API) AuthenticateGetInbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
|
||||
func (f *Federator) AuthenticateGetInbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
|
||||
// TODO
|
||||
// use context.WithValue() and context.Value() to set and get values through here
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// AuthenticateGetOutbox determines whether the request is for a GET call to the Actor's Outbox.
|
||||
func (fa *API) AuthenticateGetOutbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
|
||||
func (f *Federator) AuthenticateGetOutbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
|
||||
// TODO
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// GetOutbox returns a proper paginated view of the Outbox for serving in a response.
|
||||
func (fa *API) GetOutbox(ctx context.Context, r *http.Request) (vocab.ActivityStreamsOrderedCollectionPage, error) {
|
||||
func (f *Federator) GetOutbox(ctx context.Context, r *http.Request) (vocab.ActivityStreamsOrderedCollectionPage, error) {
|
||||
// TODO
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// NewTransport returns a new pub.Transport for federating with peer software.
|
||||
func (fa *API) NewTransport(ctx context.Context, actorBoxIRI *url.URL, gofedAgent string) (pub.Transport, error) {
|
||||
func (f *Federator) NewTransport(ctx context.Context, actorBoxIRI *url.URL, gofedAgent string) (pub.Transport, error) {
|
||||
// TODO
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (fa *API) PostInboxRequestBodyHook(ctx context.Context, r *http.Request, activity pub.Activity) (context.Context, error) {
|
||||
func (f *Federator) PostInboxRequestBodyHook(ctx context.Context, r *http.Request, activity pub.Activity) (context.Context, error) {
|
||||
// TODO
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (fa *API) AuthenticatePostInbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
|
||||
func (f *Federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool, error) {
|
||||
// TODO
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (fa *API) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, error) {
|
||||
func (f *Federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, error) {
|
||||
// TODO
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (fa *API) FederatingCallbacks(ctx context.Context) (pub.FederatingWrappedCallbacks, []interface{}, error) {
|
||||
func (f *Federator) FederatingCallbacks(ctx context.Context) (pub.FederatingWrappedCallbacks, []interface{}, error) {
|
||||
// TODO
|
||||
return pub.FederatingWrappedCallbacks{}, nil, nil
|
||||
}
|
||||
|
||||
func (fa *API) DefaultCallback(ctx context.Context, activity pub.Activity) error {
|
||||
func (f *Federator) DefaultCallback(ctx context.Context, activity pub.Activity) error {
|
||||
// TODO
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fa *API) MaxInboxForwardingRecursionDepth(ctx context.Context) int {
|
||||
func (f *Federator) MaxInboxForwardingRecursionDepth(ctx context.Context) int {
|
||||
// TODO
|
||||
return 0
|
||||
}
|
||||
|
||||
func (fa *API) MaxDeliveryRecursionDepth(ctx context.Context) int {
|
||||
func (f *Federator) MaxDeliveryRecursionDepth(ctx context.Context) int {
|
||||
// TODO
|
||||
return 0
|
||||
}
|
||||
|
||||
func (fa *API) FilterForwarding(ctx context.Context, potentialRecipients []*url.URL, a pub.Activity) ([]*url.URL, error) {
|
||||
func (f *Federator) FilterForwarding(ctx context.Context, potentialRecipients []*url.URL, a pub.Activity) ([]*url.URL, error) {
|
||||
// TODO
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (fa *API) GetInbox(ctx context.Context, r *http.Request) (vocab.ActivityStreamsOrderedCollectionPage, error) {
|
||||
func (f *Federator) GetInbox(ctx context.Context, r *http.Request) (vocab.ActivityStreamsOrderedCollectionPage, error) {
|
||||
// TODO
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (fa *API) Now() time.Time {
|
||||
func (f *Federator) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
|
|
@ -25,10 +25,20 @@ import (
|
|||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/gotosocial/gotosocial/internal/action"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/action"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/apimodule"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/apimodule/account"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/apimodule/app"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/apimodule/auth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
)
|
||||
|
||||
// Run creates and starts a gotosocial server
|
||||
|
@ -38,9 +48,48 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr
|
|||
return fmt.Errorf("error creating dbservice: %s", err)
|
||||
}
|
||||
|
||||
// if err := dbService.CreateSchema(ctx); err != nil {
|
||||
// return fmt.Errorf("error creating dbschema: %s", err)
|
||||
// }
|
||||
router, err := router.New(c, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating router: %s", err)
|
||||
}
|
||||
|
||||
storageBackend, err := storage.NewInMem(c, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating storage backend: %s", err)
|
||||
}
|
||||
|
||||
// build backend handlers
|
||||
mediaHandler := media.New(c, dbService, storageBackend, log)
|
||||
oauthServer := oauth.New(dbService, log)
|
||||
|
||||
// build client api modules
|
||||
authModule := auth.New(oauthServer, dbService, log)
|
||||
accountModule := account.New(c, dbService, oauthServer, mediaHandler, log)
|
||||
appsModule := app.New(oauthServer, dbService, log)
|
||||
|
||||
apiModules := []apimodule.ClientAPIModule{
|
||||
authModule, // this one has to go first so the other modules use its middleware
|
||||
accountModule,
|
||||
appsModule,
|
||||
}
|
||||
|
||||
for _, m := range apiModules {
|
||||
if err := m.Route(router); err != nil {
|
||||
return fmt.Errorf("routing error: %s", err)
|
||||
}
|
||||
if err := m.CreateTables(dbService); err != nil {
|
||||
return fmt.Errorf("table creation error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
gts, err := New(dbService, &cache.MockCache{}, router, federation.New(dbService, log), c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating gotosocial service: %s", err)
|
||||
}
|
||||
|
||||
if err := gts.Start(ctx); err != nil {
|
||||
return fmt.Errorf("error starting gotosocial service: %s", err)
|
||||
}
|
||||
|
||||
// catch shutdown signals from the operating system
|
||||
sigs := make(chan os.Signal, 1)
|
||||
|
@ -49,8 +98,8 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr
|
|||
log.Infof("received signal %s, shutting down", sig)
|
||||
|
||||
// close down all running services in order
|
||||
if err := dbService.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("error closing dbservice: %s", err)
|
||||
if err := gts.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("error closing gotosocial service: %s", err)
|
||||
}
|
||||
|
||||
log.Info("done! exiting...")
|
||||
|
|
|
@ -22,17 +22,22 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/go-fed/activity/pub"
|
||||
"github.com/gotosocial/gotosocial/internal/cache"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
)
|
||||
|
||||
// Gotosocial is the 'main' function of the gotosocial server, and the place where everything hangs together.
|
||||
// The logic of stopping and starting the entire server is contained here.
|
||||
type Gotosocial interface {
|
||||
Start(context.Context) error
|
||||
Stop(context.Context) error
|
||||
}
|
||||
|
||||
// New returns a new gotosocial server, initialized with the given configuration.
|
||||
// An error will be returned the caller if something goes wrong during initialization
|
||||
// eg., no db or storage connection, port for router already in use, etc.
|
||||
func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) {
|
||||
return &gotosocial{
|
||||
db: db,
|
||||
|
@ -43,6 +48,7 @@ func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub
|
|||
}, nil
|
||||
}
|
||||
|
||||
// gotosocial fulfils the gotosocial interface.
|
||||
type gotosocial struct {
|
||||
db db.DB
|
||||
cache cache.Cache
|
||||
|
@ -51,10 +57,19 @@ type gotosocial struct {
|
|||
config *config.Config
|
||||
}
|
||||
|
||||
// Start starts up the gotosocial server. If something goes wrong
|
||||
// while starting the server, then an error will be returned.
|
||||
func (gts *gotosocial) Start(ctx context.Context) error {
|
||||
gts.apiRouter.Start()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gts *gotosocial) Stop(ctx context.Context) error {
|
||||
if err := gts.apiRouter.Stop(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := gts.db.Stop(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
28
internal/gotosocial/mock_Gotosocial.go
Normal file
28
internal/gotosocial/mock_Gotosocial.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package gotosocial
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockGotosocial is an autogenerated mock type for the Gotosocial type
|
||||
type MockGotosocial struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields: _a0
|
||||
func (_m *MockGotosocial) Start(_a0 context.Context) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
|
@ -18,6 +18,195 @@
|
|||
|
||||
package media
|
||||
|
||||
// API provides an interface for parsing, storing, and retrieving media objects like photos and videos
|
||||
type API interface {
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
)
|
||||
|
||||
// MediaHandler provides an interface for parsing, storing, and retrieving media objects like photos, videos, and gifs.
|
||||
type MediaHandler interface {
|
||||
// SetHeaderOrAvatarForAccountID takes a new header image for an account, checks it out, removes exif data from it,
|
||||
// puts it in whatever storage backend we're using, sets the relevant fields in the database for the new image,
|
||||
// and then returns information to the caller about the new header.
|
||||
SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error)
|
||||
}
|
||||
|
||||
type mediaHandler struct {
|
||||
config *config.Config
|
||||
db db.DB
|
||||
storage storage.Storage
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
func New(config *config.Config, database db.DB, storage storage.Storage, log *logrus.Logger) MediaHandler {
|
||||
return &mediaHandler{
|
||||
config: config,
|
||||
db: database,
|
||||
storage: storage,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// HeaderInfo wraps the urls at which a Header and a StaticHeader is available from the server.
|
||||
type HeaderInfo struct {
|
||||
// URL to the header
|
||||
Header string
|
||||
// Static version of the above (eg., a path to a still image if the header is a gif)
|
||||
HeaderStatic string
|
||||
}
|
||||
|
||||
/*
|
||||
INTERFACE FUNCTIONS
|
||||
*/
|
||||
|
||||
func (mh *mediaHandler) SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error) {
|
||||
l := mh.log.WithField("func", "SetHeaderForAccountID")
|
||||
|
||||
if headerOrAvi != "header" && headerOrAvi != "avatar" {
|
||||
return nil, errors.New("header or avatar not selected")
|
||||
}
|
||||
|
||||
// make sure we have an image we can handle
|
||||
contentType, err := parseContentType(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !supportedImageType(contentType) {
|
||||
return nil, fmt.Errorf("%s is not an accepted image type", contentType)
|
||||
}
|
||||
|
||||
if len(img) == 0 {
|
||||
return nil, fmt.Errorf("passed reader was of size 0")
|
||||
}
|
||||
l.Tracef("read %d bytes of file", len(img))
|
||||
|
||||
// process it
|
||||
ma, err := mh.processHeaderOrAvi(img, contentType, headerOrAvi, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error processing %s: %s", headerOrAvi, err)
|
||||
}
|
||||
|
||||
// set it in the database
|
||||
if err := mh.db.SetHeaderOrAvatarForAccountID(ma, accountID); err != nil {
|
||||
return nil, fmt.Errorf("error putting %s in database: %s", headerOrAvi, err)
|
||||
}
|
||||
|
||||
return ma, nil
|
||||
}
|
||||
|
||||
/*
|
||||
HELPER FUNCTIONS
|
||||
*/
|
||||
|
||||
func (mh *mediaHandler) processHeaderOrAvi(imageBytes []byte, contentType string, headerOrAvi string, accountID string) (*model.MediaAttachment, error) {
|
||||
var isHeader bool
|
||||
var isAvatar bool
|
||||
|
||||
switch headerOrAvi {
|
||||
case "header":
|
||||
isHeader = true
|
||||
case "avatar":
|
||||
isAvatar = true
|
||||
default:
|
||||
return nil, errors.New("header or avatar not selected")
|
||||
}
|
||||
|
||||
var clean []byte
|
||||
var err error
|
||||
|
||||
switch contentType {
|
||||
case "image/jpeg":
|
||||
if clean, err = purgeExif(imageBytes); err != nil {
|
||||
return nil, fmt.Errorf("error cleaning exif data: %s", err)
|
||||
}
|
||||
case "image/png":
|
||||
if clean, err = purgeExif(imageBytes); err != nil {
|
||||
return nil, fmt.Errorf("error cleaning exif data: %s", err)
|
||||
}
|
||||
case "image/gif":
|
||||
clean = imageBytes
|
||||
default:
|
||||
return nil, errors.New("media type unrecognized")
|
||||
}
|
||||
|
||||
original, err := deriveImage(clean, contentType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing image: %s", err)
|
||||
}
|
||||
|
||||
small, err := deriveThumbnail(clean, contentType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deriving thumbnail: %s", err)
|
||||
}
|
||||
|
||||
// now put it in storage, take a new uuid for the name of the file so we don't store any unnecessary info about it
|
||||
extension := strings.Split(contentType, "/")[1]
|
||||
newMediaID := uuid.NewString()
|
||||
|
||||
base := fmt.Sprintf("%s://%s%s", mh.config.StorageConfig.ServeProtocol, mh.config.StorageConfig.ServeHost, mh.config.StorageConfig.ServeBasePath)
|
||||
|
||||
// we store the original...
|
||||
originalPath := fmt.Sprintf("%s/%s/%s/original/%s.%s", base, accountID, headerOrAvi, newMediaID, extension)
|
||||
if err := mh.storage.StoreFileAt(originalPath, original.image); err != nil {
|
||||
return nil, fmt.Errorf("storage error: %s", err)
|
||||
}
|
||||
// and a thumbnail...
|
||||
smallPath := fmt.Sprintf("%s/%s/%s/small/%s.%s", base, accountID, headerOrAvi, newMediaID, extension)
|
||||
if err := mh.storage.StoreFileAt(smallPath, small.image); err != nil {
|
||||
return nil, fmt.Errorf("storage error: %s", err)
|
||||
}
|
||||
|
||||
ma := &model.MediaAttachment{
|
||||
ID: newMediaID,
|
||||
StatusID: "",
|
||||
RemoteURL: "",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
Type: model.FileTypeImage,
|
||||
FileMeta: model.FileMeta{
|
||||
Original: model.Original{
|
||||
Width: original.width,
|
||||
Height: original.height,
|
||||
Size: original.size,
|
||||
Aspect: original.aspect,
|
||||
},
|
||||
Small: model.Small{
|
||||
Width: small.width,
|
||||
Height: small.height,
|
||||
Size: small.size,
|
||||
Aspect: small.aspect,
|
||||
},
|
||||
},
|
||||
AccountID: accountID,
|
||||
Description: "",
|
||||
ScheduledStatusID: "",
|
||||
Blurhash: original.blurhash,
|
||||
Processing: 2,
|
||||
File: model.File{
|
||||
Path: originalPath,
|
||||
ContentType: contentType,
|
||||
FileSize: len(original.image),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
Thumbnail: model.Thumbnail{
|
||||
Path: smallPath,
|
||||
ContentType: contentType,
|
||||
FileSize: len(small.image),
|
||||
UpdatedAt: time.Now(),
|
||||
RemoteURL: "",
|
||||
},
|
||||
Avatar: isAvatar,
|
||||
Header: isHeader,
|
||||
}
|
||||
|
||||
return ma, nil
|
||||
}
|
||||
|
|
159
internal/media/media_test.go
Normal file
159
internal/media/media_test.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package media
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
)
|
||||
|
||||
type MediaTestSuite struct {
|
||||
suite.Suite
|
||||
config *config.Config
|
||||
log *logrus.Logger
|
||||
db db.DB
|
||||
mediaHandler *mediaHandler
|
||||
mockStorage *storage.MockStorage
|
||||
}
|
||||
|
||||
/*
|
||||
TEST INFRASTRUCTURE
|
||||
*/
|
||||
|
||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||
func (suite *MediaTestSuite) SetupSuite() {
|
||||
// some of our subsequent entities need a log so create this here
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
suite.log = log
|
||||
|
||||
// Direct config to local postgres instance
|
||||
c := config.Empty()
|
||||
c.Protocol = "http"
|
||||
c.Host = "localhost"
|
||||
c.DBConfig = &config.DBConfig{
|
||||
Type: "postgres",
|
||||
Address: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "postgres",
|
||||
Database: "postgres",
|
||||
ApplicationName: "gotosocial",
|
||||
}
|
||||
c.MediaConfig = &config.MediaConfig{
|
||||
MaxImageSize: 2 << 20,
|
||||
}
|
||||
c.StorageConfig = &config.StorageConfig{
|
||||
Backend: "local",
|
||||
BasePath: "/tmp",
|
||||
ServeProtocol: "http",
|
||||
ServeHost: "localhost",
|
||||
ServeBasePath: "/fileserver/media",
|
||||
}
|
||||
suite.config = c
|
||||
// use an actual database for this, because it's just easier than mocking one out
|
||||
database, err := db.New(context.Background(), c, log)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
suite.db = database
|
||||
|
||||
suite.mockStorage = &storage.MockStorage{}
|
||||
// We don't need storage to do anything for these tests, so just simulate a success and do nothing
|
||||
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
|
||||
|
||||
// and finally here's the thing we're actually testing!
|
||||
suite.mediaHandler = &mediaHandler{
|
||||
config: suite.config,
|
||||
db: suite.db,
|
||||
storage: suite.mockStorage,
|
||||
log: log,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (suite *MediaTestSuite) TearDownSuite() {
|
||||
if err := suite.db.Stop(context.Background()); err != nil {
|
||||
logrus.Panicf("error closing db connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTest creates a db connection and creates necessary tables before each test
|
||||
func (suite *MediaTestSuite) SetupTest() {
|
||||
// create all the tables we might need in thie suite
|
||||
models := []interface{}{
|
||||
&model.Account{},
|
||||
&model.MediaAttachment{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.CreateTable(m); err != nil {
|
||||
logrus.Panicf("db connection error: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TearDownTest drops tables to make sure there's no data in the db
|
||||
func (suite *MediaTestSuite) TearDownTest() {
|
||||
|
||||
// remove all the tables we might have used so it's clear for the next test
|
||||
models := []interface{}{
|
||||
&model.Account{},
|
||||
&model.MediaAttachment{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
||||
logrus.Panicf("error dropping table: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
ACTUAL TESTS
|
||||
*/
|
||||
|
||||
func (suite *MediaTestSuite) TestSetHeaderOrAvatarForAccountID() {
|
||||
// load test image
|
||||
f, err := ioutil.ReadFile("./test/test-jpeg.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
|
||||
ma, err := suite.mediaHandler.SetHeaderOrAvatarForAccountID(f, "weeeeeee", "header")
|
||||
assert.Nil(suite.T(), err)
|
||||
suite.log.Debugf("%+v", ma)
|
||||
|
||||
// attachment should have....
|
||||
assert.Equal(suite.T(), "weeeeeee", ma.AccountID)
|
||||
assert.Equal(suite.T(), "LjCZnlvyRkRn_NvzRjWF?urqV@f9", ma.Blurhash)
|
||||
//TODO: add more checks here, cba right now!
|
||||
}
|
||||
|
||||
// TODO: add tests for sad path, gif, png....
|
||||
|
||||
func TestMediaTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaTestSuite))
|
||||
}
|
36
internal/media/mock_MediaHandler.go
Normal file
36
internal/media/mock_MediaHandler.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package media
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
model "github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
)
|
||||
|
||||
// MockMediaHandler is an autogenerated mock type for the MediaHandler type
|
||||
type MockMediaHandler struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// SetHeaderOrAvatarForAccountID provides a mock function with given fields: img, accountID, headerOrAvi
|
||||
func (_m *MockMediaHandler) SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error) {
|
||||
ret := _m.Called(img, accountID, headerOrAvi)
|
||||
|
||||
var r0 *model.MediaAttachment
|
||||
if rf, ok := ret.Get(0).(func([]byte, string, string) *model.MediaAttachment); ok {
|
||||
r0 = rf(img, accountID, headerOrAvi)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*model.MediaAttachment)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func([]byte, string, string) error); ok {
|
||||
r1 = rf(img, accountID, headerOrAvi)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
1
internal/media/test/test-corrupted.jpg
Normal file
1
internal/media/test/test-corrupted.jpg
Normal file
File diff suppressed because one or more lines are too long
BIN
internal/media/test/test-jpeg-blurhash.jpg
Normal file
BIN
internal/media/test/test-jpeg-blurhash.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 8.6 KiB |
BIN
internal/media/test/test-jpeg-processed.jpg
Normal file
BIN
internal/media/test/test-jpeg-processed.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 293 KiB |
BIN
internal/media/test/test-jpeg-thumbnail.jpg
Normal file
BIN
internal/media/test/test-jpeg-thumbnail.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.6 KiB |
BIN
internal/media/test/test-jpeg.jpg
Normal file
BIN
internal/media/test/test-jpeg.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 263 KiB |
BIN
internal/media/test/test-with-exif.jpg
Normal file
BIN
internal/media/test/test-with-exif.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
BIN
internal/media/test/test-without-exif.jpg
Normal file
BIN
internal/media/test/test-without-exif.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
192
internal/media/util.go
Normal file
192
internal/media/util.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package media
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/gif"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
|
||||
"github.com/buckket/go-blurhash"
|
||||
"github.com/h2non/filetype"
|
||||
"github.com/nfnt/resize"
|
||||
"github.com/superseriousbusiness/exifremove/pkg/exifremove"
|
||||
)
|
||||
|
||||
// parseContentType parses the MIME content type from a file, returning it as a string in the form (eg., "image/jpeg").
|
||||
// Returns an error if the content type is not something we can process.
|
||||
func parseContentType(content []byte) (string, error) {
|
||||
head := make([]byte, 261)
|
||||
_, err := bytes.NewReader(content).Read(head)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not read first magic bytes of file: %s", err)
|
||||
}
|
||||
|
||||
kind, err := filetype.Match(head)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if kind == filetype.Unknown {
|
||||
return "", errors.New("filetype unknown")
|
||||
}
|
||||
|
||||
return kind.MIME.Value, nil
|
||||
}
|
||||
|
||||
// supportedImageType checks mime type of an image against a slice of accepted types,
|
||||
// and returns True if the mime type is accepted.
|
||||
func supportedImageType(mimeType string) bool {
|
||||
acceptedImageTypes := []string{
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/png",
|
||||
}
|
||||
for _, accepted := range acceptedImageTypes {
|
||||
if mimeType == accepted {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// purgeExif is a little wrapper for the action of removing exif data from an image.
|
||||
// Only pass pngs or jpegs to this function.
|
||||
func purgeExif(b []byte) ([]byte, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, errors.New("passed image was not valid")
|
||||
}
|
||||
|
||||
clean, err := exifremove.Remove(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not purge exif from image: %s", err)
|
||||
}
|
||||
if len(clean) == 0 {
|
||||
return nil, errors.New("purged image was not valid")
|
||||
}
|
||||
return clean, nil
|
||||
}
|
||||
|
||||
func deriveImage(b []byte, extension string) (*imageAndMeta, error) {
|
||||
var i image.Image
|
||||
var err error
|
||||
|
||||
switch extension {
|
||||
case "image/jpeg":
|
||||
i, err = jpeg.Decode(bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "image/png":
|
||||
i, err = png.Decode(bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "image/gif":
|
||||
i, err = gif.Decode(bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("extension %s not recognised", extension)
|
||||
}
|
||||
|
||||
width := i.Bounds().Size().X
|
||||
height := i.Bounds().Size().Y
|
||||
size := width * height
|
||||
aspect := float64(width) / float64(height)
|
||||
bh, err := blurhash.Encode(4, 3, i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating blurhash: %s", err)
|
||||
}
|
||||
|
||||
out := &bytes.Buffer{}
|
||||
if err := jpeg.Encode(out, i, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &imageAndMeta{
|
||||
image: out.Bytes(),
|
||||
width: width,
|
||||
height: height,
|
||||
size: size,
|
||||
aspect: aspect,
|
||||
blurhash: bh,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// deriveThumbnailFromImage returns a byte slice and metadata for a 256-pixel-width thumbnail
|
||||
// of a given jpeg, png, or gif, or an error if something goes wrong.
|
||||
//
|
||||
// Note that the aspect ratio of the image will be retained,
|
||||
// so it will not necessarily be a square.
|
||||
func deriveThumbnail(b []byte, extension string) (*imageAndMeta, error) {
|
||||
var i image.Image
|
||||
var err error
|
||||
|
||||
switch extension {
|
||||
case "image/jpeg":
|
||||
i, err = jpeg.Decode(bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "image/png":
|
||||
i, err = png.Decode(bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "image/gif":
|
||||
i, err = gif.Decode(bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("extension %s not recognised", extension)
|
||||
}
|
||||
|
||||
thumb := resize.Thumbnail(256, 256, i, resize.NearestNeighbor)
|
||||
width := thumb.Bounds().Size().X
|
||||
height := thumb.Bounds().Size().Y
|
||||
size := width * height
|
||||
aspect := float64(width) / float64(height)
|
||||
|
||||
out := &bytes.Buffer{}
|
||||
if err := jpeg.Encode(out, thumb, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &imageAndMeta{
|
||||
image: out.Bytes(),
|
||||
width: width,
|
||||
height: height,
|
||||
size: size,
|
||||
aspect: aspect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type imageAndMeta struct {
|
||||
image []byte
|
||||
width int
|
||||
height int
|
||||
size int
|
||||
aspect float64
|
||||
blurhash string
|
||||
}
|
147
internal/media/util_test.go
Normal file
147
internal/media/util_test.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package media
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type MediaUtilTestSuite struct {
|
||||
suite.Suite
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
/*
|
||||
TEST INFRASTRUCTURE
|
||||
*/
|
||||
|
||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||
func (suite *MediaUtilTestSuite) SetupSuite() {
|
||||
// some of our subsequent entities need a log so create this here
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
suite.log = log
|
||||
}
|
||||
|
||||
func (suite *MediaUtilTestSuite) TearDownSuite() {
|
||||
|
||||
}
|
||||
|
||||
// SetupTest creates a db connection and creates necessary tables before each test
|
||||
func (suite *MediaUtilTestSuite) SetupTest() {
|
||||
|
||||
}
|
||||
|
||||
// TearDownTest drops tables to make sure there's no data in the db
|
||||
func (suite *MediaUtilTestSuite) TearDownTest() {
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
ACTUAL TESTS
|
||||
*/
|
||||
|
||||
func (suite *MediaUtilTestSuite) TestParseContentTypeOK() {
|
||||
f, err := ioutil.ReadFile("./test/test-jpeg.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
ct, err := parseContentType(f)
|
||||
assert.Nil(suite.T(), err)
|
||||
assert.Equal(suite.T(), "image/jpeg", ct)
|
||||
}
|
||||
|
||||
func (suite *MediaUtilTestSuite) TestParseContentTypeNotOK() {
|
||||
f, err := ioutil.ReadFile("./test/test-corrupted.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
ct, err := parseContentType(f)
|
||||
assert.NotNil(suite.T(), err)
|
||||
assert.Equal(suite.T(), "", ct)
|
||||
assert.Equal(suite.T(), "filetype unknown", err.Error())
|
||||
}
|
||||
|
||||
func (suite *MediaUtilTestSuite) TestRemoveEXIF() {
|
||||
// load and validate image
|
||||
b, err := ioutil.ReadFile("./test/test-with-exif.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
|
||||
// clean it up and validate the clean version
|
||||
clean, err := purgeExif(b)
|
||||
assert.Nil(suite.T(), err)
|
||||
|
||||
// compare it to our stored sample
|
||||
sampleBytes, err := ioutil.ReadFile("./test/test-without-exif.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
assert.EqualValues(suite.T(), sampleBytes, clean)
|
||||
}
|
||||
|
||||
func (suite *MediaUtilTestSuite) TestDeriveImageFromJPEG() {
|
||||
// load image
|
||||
b, err := ioutil.ReadFile("./test/test-jpeg.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
|
||||
// clean it up and validate the clean version
|
||||
imageAndMeta, err := deriveImage(b, "image/jpeg")
|
||||
assert.Nil(suite.T(), err)
|
||||
|
||||
assert.Equal(suite.T(), 1920, imageAndMeta.width)
|
||||
assert.Equal(suite.T(), 1080, imageAndMeta.height)
|
||||
assert.Equal(suite.T(), 1.7777777777777777, imageAndMeta.aspect)
|
||||
assert.Equal(suite.T(), 2073600, imageAndMeta.size)
|
||||
assert.Equal(suite.T(), "LjCZnlvyRkRn_NvzRjWF?urqV@f9", imageAndMeta.blurhash)
|
||||
|
||||
// assert that the final image is what we would expect
|
||||
sampleBytes, err := ioutil.ReadFile("./test/test-jpeg-processed.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
assert.EqualValues(suite.T(), sampleBytes, imageAndMeta.image)
|
||||
}
|
||||
|
||||
func (suite *MediaUtilTestSuite) TestDeriveThumbnailFromJPEG() {
|
||||
// load image
|
||||
b, err := ioutil.ReadFile("./test/test-jpeg.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
|
||||
// clean it up and validate the clean version
|
||||
imageAndMeta, err := deriveThumbnail(b, "image/jpeg")
|
||||
assert.Nil(suite.T(), err)
|
||||
|
||||
assert.Equal(suite.T(), 256, imageAndMeta.width)
|
||||
assert.Equal(suite.T(), 144, imageAndMeta.height)
|
||||
assert.Equal(suite.T(), 1.7777777777777777, imageAndMeta.aspect)
|
||||
assert.Equal(suite.T(), 36864, imageAndMeta.size)
|
||||
|
||||
sampleBytes, err := ioutil.ReadFile("./test/test-jpeg-thumbnail.jpg")
|
||||
assert.Nil(suite.T(), err)
|
||||
assert.EqualValues(suite.T(), sampleBytes, imageAndMeta.image)
|
||||
}
|
||||
|
||||
func (suite *MediaUtilTestSuite) TestSupportedImageTypes() {
|
||||
ok := supportedImageType("image/jpeg")
|
||||
assert.True(suite.T(), ok)
|
||||
|
||||
ok = supportedImageType("image/bmp")
|
||||
assert.False(suite.T(), ok)
|
||||
}
|
||||
|
||||
func TestMediaUtilTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(MediaUtilTestSuite))
|
||||
}
|
|
@ -1,510 +0,0 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
// Package oauth is a module that provides oauth functionality to a router.
|
||||
// It adds the following paths:
|
||||
// /api/v1/apps
|
||||
// /auth/sign_in
|
||||
// /oauth/token
|
||||
// /oauth/authorize
|
||||
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
||||
"github.com/gotosocial/gotosocial/internal/module"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
"github.com/gotosocial/oauth2/v4/errors"
|
||||
"github.com/gotosocial/oauth2/v4/manage"
|
||||
"github.com/gotosocial/oauth2/v4/server"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
appsPath = "/api/v1/apps"
|
||||
authSignInPath = "/auth/sign_in"
|
||||
oauthTokenPath = "/oauth/token"
|
||||
oauthAuthorizePath = "/oauth/authorize"
|
||||
)
|
||||
|
||||
// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface
|
||||
type oauthModule struct {
|
||||
oauthManager *manage.Manager
|
||||
oauthServer *server.Server
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
type login struct {
|
||||
Email string `form:"username"`
|
||||
Password string `form:"password"`
|
||||
}
|
||||
|
||||
// New returns a new oauth module
|
||||
func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule {
|
||||
manager := manage.NewDefaultManager()
|
||||
manager.MapTokenStorage(ts)
|
||||
manager.MapClientStorage(cs)
|
||||
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
||||
sc := &server.Config{
|
||||
TokenType: "Bearer",
|
||||
// Must follow the spec.
|
||||
AllowGetAccessRequest: false,
|
||||
// Support only the non-implicit flow.
|
||||
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
||||
// Allow:
|
||||
// - Authorization Code (for first & third parties)
|
||||
AllowedGrantTypes: []oauth2.GrantType{
|
||||
oauth2.AuthorizationCode,
|
||||
},
|
||||
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
|
||||
}
|
||||
|
||||
srv := server.NewServer(sc, manager)
|
||||
srv.SetInternalErrorHandler(func(err error) *errors.Response {
|
||||
log.Errorf("internal oauth error: %s", err)
|
||||
return nil
|
||||
})
|
||||
|
||||
srv.SetResponseErrorHandler(func(re *errors.Response) {
|
||||
log.Errorf("internal response error: %s", re.Error)
|
||||
})
|
||||
|
||||
m := &oauthModule{
|
||||
oauthManager: manager,
|
||||
oauthServer: srv,
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
|
||||
m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler)
|
||||
m.oauthServer.SetClientInfoHandler(server.ClientFormHandler)
|
||||
return m
|
||||
}
|
||||
|
||||
// Route satisfies the RESTAPIModule interface
|
||||
func (m *oauthModule) Route(s router.Router) error {
|
||||
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
|
||||
|
||||
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
|
||||
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
|
||||
|
||||
s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler)
|
||||
|
||||
s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler)
|
||||
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
|
||||
|
||||
s.AttachMiddleware(m.oauthTokenMiddleware)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
MAIN HANDLERS -- serve these through a server/router
|
||||
*/
|
||||
|
||||
// appsPOSTHandler should be served at https://example.org/api/v1/apps
|
||||
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
|
||||
func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AppsPOSTHandler")
|
||||
l.Trace("entering AppsPOSTHandler")
|
||||
|
||||
form := &mastotypes.ApplicationPOSTRequest{}
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// permitted length for most fields
|
||||
permittedLength := 64
|
||||
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
|
||||
permittedRedirect := 256
|
||||
|
||||
// check lengths of fields before proceeding so the user can't spam huge entries into the database
|
||||
if len(form.ClientName) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
if len(form.Website) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
if len(form.RedirectURIs) > permittedRedirect {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
|
||||
return
|
||||
}
|
||||
if len(form.Scopes) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
|
||||
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
|
||||
var scopes string
|
||||
if form.Scopes == "" {
|
||||
scopes = "read"
|
||||
} else {
|
||||
scopes = form.Scopes
|
||||
}
|
||||
|
||||
// generate new IDs for this application and its associated client
|
||||
clientID := uuid.NewString()
|
||||
clientSecret := uuid.NewString()
|
||||
vapidKey := uuid.NewString()
|
||||
|
||||
// generate the application to put in the database
|
||||
app := >smodel.Application{
|
||||
Name: form.ClientName,
|
||||
Website: form.Website,
|
||||
RedirectURI: form.RedirectURIs,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: scopes,
|
||||
VapidKey: vapidKey,
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := m.db.Put(app); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// now we need to model an oauth client from the application that the oauth library can use
|
||||
oc := &oauthClient{
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
Domain: form.RedirectURIs,
|
||||
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := m.db.Put(oc); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
|
||||
c.JSON(http.StatusOK, app.ToMastotype())
|
||||
}
|
||||
|
||||
// signInGETHandler should be served at https://example.org/auth/sign_in.
|
||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
|
||||
func (m *oauthModule) signInGETHandler(c *gin.Context) {
|
||||
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
|
||||
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
|
||||
}
|
||||
|
||||
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
|
||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||
// The handler will then redirect to the auth handler served at /auth
|
||||
func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "SignInPOSTHandler")
|
||||
s := sessions.Default(c)
|
||||
form := &login{}
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
l.Tracef("parsed form: %+v", form)
|
||||
|
||||
userid, err := m.validatePassword(form.Email, form.Password)
|
||||
if err != nil {
|
||||
c.String(http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.Set("userid", userid)
|
||||
if err := s.Save(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
l.Trace("redirecting to auth page")
|
||||
c.Redirect(http.StatusFound, oauthAuthorizePath)
|
||||
}
|
||||
|
||||
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
|
||||
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
|
||||
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
|
||||
func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "TokenPOSTHandler")
|
||||
l.Trace("entered TokenPOSTHandler")
|
||||
if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
|
||||
// The idea here is to present an oauth authorize page to the user, with a button
|
||||
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||
func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AuthorizeGETHandler")
|
||||
s := sessions.Default(c)
|
||||
|
||||
// UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
|
||||
// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
|
||||
userID, ok := s.Get("userid").(string)
|
||||
if !ok || userID == "" {
|
||||
l.Trace("userid was empty, parsing form then redirecting to sign in page")
|
||||
if err := parseAuthForm(c, l); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, authSignInPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// We can use the client_id on the session to retrieve info about the app associated with the client_id
|
||||
clientID, ok := s.Get("client_id").(string)
|
||||
if !ok || clientID == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
|
||||
return
|
||||
}
|
||||
app := >smodel.Application{
|
||||
ClientID: clientID,
|
||||
}
|
||||
if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
|
||||
return
|
||||
}
|
||||
|
||||
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
|
||||
user := >smodel.User{
|
||||
ID: userID,
|
||||
}
|
||||
if err := m.db.GetByID(user.ID, user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
acct := >smodel.Account{
|
||||
ID: user.AccountID,
|
||||
}
|
||||
|
||||
if err := m.db.GetByID(acct.ID, acct); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Finally we should also get the redirect and scope of this particular request, as stored in the session.
|
||||
redirect, ok := s.Get("redirect_uri").(string)
|
||||
if !ok || redirect == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"})
|
||||
return
|
||||
}
|
||||
scope, ok := s.Get("scope").(string)
|
||||
if !ok || scope == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"})
|
||||
return
|
||||
}
|
||||
|
||||
// the authorize template will display a form to the user where they can get some information
|
||||
// about the app that's trying to authorize, and the scope of the request.
|
||||
// They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler
|
||||
l.Trace("serving authorize html")
|
||||
c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
|
||||
"appname": app.Name,
|
||||
"appwebsite": app.Website,
|
||||
"redirect": redirect,
|
||||
"scope": scope,
|
||||
"user": acct.Username,
|
||||
})
|
||||
}
|
||||
|
||||
// authorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
|
||||
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
|
||||
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
|
||||
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||
func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AuthorizePOSTHandler")
|
||||
s := sessions.Default(c)
|
||||
|
||||
// At this point we know the user has said 'yes' to allowing the application and oauth client
|
||||
// work for them, so we can set the
|
||||
|
||||
// We need to retrieve the original form submitted to the authorizeGEThandler, and
|
||||
// recreate it on the request so that it can be used further by the oauth2 library.
|
||||
// So first fetch all the values from the session.
|
||||
forceLogin, ok := s.Get("force_login").(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
|
||||
return
|
||||
}
|
||||
responseType, ok := s.Get("response_type").(string)
|
||||
if !ok || responseType == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
|
||||
return
|
||||
}
|
||||
clientID, ok := s.Get("client_id").(string)
|
||||
if !ok || clientID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
|
||||
return
|
||||
}
|
||||
redirectURI, ok := s.Get("redirect_uri").(string)
|
||||
if !ok || redirectURI == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
|
||||
return
|
||||
}
|
||||
scope, ok := s.Get("scope").(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
|
||||
return
|
||||
}
|
||||
userID, ok := s.Get("userid").(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"})
|
||||
return
|
||||
}
|
||||
// we're done with the session so we can clear it now
|
||||
s.Clear()
|
||||
|
||||
// now set the values on the request
|
||||
values := url.Values{}
|
||||
values.Set("force_login", forceLogin)
|
||||
values.Set("response_type", responseType)
|
||||
values.Set("client_id", clientID)
|
||||
values.Set("redirect_uri", redirectURI)
|
||||
values.Set("scope", scope)
|
||||
values.Set("userid", userID)
|
||||
c.Request.Form = values
|
||||
l.Tracef("values on request set to %+v", c.Request.Form)
|
||||
|
||||
// and proceed with authorization using the oauth2 library
|
||||
if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
MIDDLEWARE
|
||||
*/
|
||||
|
||||
// oauthTokenMiddleware
|
||||
func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
|
||||
l := m.log.WithField("func", "ValidatePassword")
|
||||
l.Trace("entering OauthTokenMiddleware")
|
||||
if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil {
|
||||
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
|
||||
c.Set("authenticated_user", ti.GetUserID())
|
||||
|
||||
} else {
|
||||
l.Trace("continuing with unauthenticated request")
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server or used inside handler funcs
|
||||
*/
|
||||
|
||||
// validatePassword takes an email address and a password.
|
||||
// The goal is to authenticate the password against the one for that email
|
||||
// address stored in the database. If OK, we return the userid (a uuid) for that user,
|
||||
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
|
||||
func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) {
|
||||
l := m.log.WithField("func", "ValidatePassword")
|
||||
|
||||
// make sure an email/password was provided and bail if not
|
||||
if email == "" || password == "" {
|
||||
l.Debug("email or password was not provided")
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
||||
// first we select the user from the database based on email address, bail if no user found for that email
|
||||
gtsUser := >smodel.User{}
|
||||
|
||||
if err := m.db.GetWhere("email", email, gtsUser); err != nil {
|
||||
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
||||
// make sure a password is actually set and bail if not
|
||||
if gtsUser.EncryptedPassword == "" {
|
||||
l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email)
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
||||
// compare the provided password with the encrypted one from the db, bail if they don't match
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil {
|
||||
l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err)
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
||||
// If we've made it this far the email/password is correct, so we can just return the id of the user.
|
||||
userid = gtsUser.ID
|
||||
l.Tracef("returning (%s, %s)", userid, err)
|
||||
return
|
||||
}
|
||||
|
||||
// incorrectPassword is just a little helper function to use in the ValidatePassword function
|
||||
func incorrectPassword() (string, error) {
|
||||
return "", errors.New("password/email combination was incorrect")
|
||||
}
|
||||
|
||||
// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form,
|
||||
// or redirects to the /auth/sign_in page, if this key is not present.
|
||||
func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
|
||||
l := m.log.WithField("func", "UserAuthorizationHandler")
|
||||
userID = r.FormValue("userid")
|
||||
if userID == "" {
|
||||
return "", errors.New("userid was empty, redirecting to sign in page")
|
||||
}
|
||||
l.Tracef("returning userID %s", userID)
|
||||
return userID, err
|
||||
}
|
||||
|
||||
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
|
||||
// the values in the form into the session.
|
||||
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
|
||||
s := sessions.Default(c)
|
||||
|
||||
// first make sure they've filled out the authorize form with the required values
|
||||
form := &mastotypes.OAuthAuthorize{}
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
return err
|
||||
}
|
||||
l.Tracef("parsed form: %+v", form)
|
||||
|
||||
// these fields are *required* so check 'em
|
||||
if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
|
||||
return errors.New("missing one of: response_type, client_id or redirect_uri")
|
||||
}
|
||||
|
||||
// set default scope to read
|
||||
if form.Scope == "" {
|
||||
form.Scope = "read"
|
||||
}
|
||||
|
||||
// save these values from the form so we can use them elsewhere in the session
|
||||
s.Set("force_login", form.ForceLogin)
|
||||
s.Set("response_type", form.ResponseType)
|
||||
s.Set("client_id", form.ClientID)
|
||||
s.Set("redirect_uri", form.RedirectURI)
|
||||
s.Set("scope", form.Scope)
|
||||
return s.Save()
|
||||
}
|
|
@ -20,11 +20,10 @@ package oauth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
"github.com/gotosocial/oauth2/v4/models"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
"github.com/superseriousbusiness/oauth2/v4/models"
|
||||
)
|
||||
|
||||
type clientStore struct {
|
||||
|
@ -39,17 +38,17 @@ func newClientStore(db db.DB) oauth2.ClientStore {
|
|||
}
|
||||
|
||||
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
|
||||
poc := &oauthClient{
|
||||
poc := &Client{
|
||||
ID: clientID,
|
||||
}
|
||||
if err := cs.db.GetByID(clientID, poc); err != nil {
|
||||
return nil, fmt.Errorf("database error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
|
||||
}
|
||||
|
||||
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
|
||||
poc := &oauthClient{
|
||||
poc := &Client{
|
||||
ID: cli.GetID(),
|
||||
Secret: cli.GetSecret(),
|
||||
Domain: cli.GetDomain(),
|
||||
|
@ -59,13 +58,13 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo
|
|||
}
|
||||
|
||||
func (cs *clientStore) Delete(ctx context.Context, id string) error {
|
||||
poc := &oauthClient{
|
||||
poc := &Client{
|
||||
ID: id,
|
||||
}
|
||||
return cs.db.DeleteByID(id, poc)
|
||||
}
|
||||
|
||||
type oauthClient struct {
|
||||
type Client struct {
|
||||
ID string
|
||||
Secret string
|
||||
Domain string
|
|
@ -21,11 +21,11 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/oauth2/v4/models"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/oauth2/v4/models"
|
||||
)
|
||||
|
||||
type PgClientStoreTestSuite struct {
|
||||
|
@ -69,7 +69,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
|
|||
suite.db = db
|
||||
|
||||
models := []interface{}{
|
||||
&oauthClient{},
|
||||
&Client{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
|
@ -82,7 +82,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
|
|||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||
func (suite *PgClientStoreTestSuite) TearDownTest() {
|
||||
models := []interface{}{
|
||||
&oauthClient{},
|
||||
&Client{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
||||
|
@ -136,7 +136,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
|
|||
// try to get the deleted client; we should get an error
|
||||
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
|
||||
suite.Assert().Nil(deletedClient)
|
||||
suite.Assert().NotNil(err)
|
||||
suite.Assert().EqualValues(db.ErrNoEntries{}, err)
|
||||
}
|
||||
|
||||
func TestPgClientStoreTestSuite(t *testing.T) {
|
89
internal/oauth/mock_Server.go
Normal file
89
internal/oauth/mock_Server.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package oauth
|
||||
|
||||
import (
|
||||
http "net/http"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
oauth2 "github.com/superseriousbusiness/oauth2/v4"
|
||||
)
|
||||
|
||||
// MockServer is an autogenerated mock type for the Server type
|
||||
type MockServer struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// GenerateUserAccessToken provides a mock function with given fields: ti, clientSecret, userID
|
||||
func (_m *MockServer) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (oauth2.TokenInfo, error) {
|
||||
ret := _m.Called(ti, clientSecret, userID)
|
||||
|
||||
var r0 oauth2.TokenInfo
|
||||
if rf, ok := ret.Get(0).(func(oauth2.TokenInfo, string, string) oauth2.TokenInfo); ok {
|
||||
r0 = rf(ti, clientSecret, userID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(oauth2.TokenInfo)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(oauth2.TokenInfo, string, string) error); ok {
|
||||
r1 = rf(ti, clientSecret, userID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// HandleAuthorizeRequest provides a mock function with given fields: w, r
|
||||
func (_m *MockServer) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
ret := _m.Called(w, r)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(http.ResponseWriter, *http.Request) error); ok {
|
||||
r0 = rf(w, r)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// HandleTokenRequest provides a mock function with given fields: w, r
|
||||
func (_m *MockServer) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
ret := _m.Called(w, r)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(http.ResponseWriter, *http.Request) error); ok {
|
||||
r0 = rf(w, r)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// ValidationBearerToken provides a mock function with given fields: r
|
||||
func (_m *MockServer) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
|
||||
ret := _m.Called(r)
|
||||
|
||||
var r0 oauth2.TokenInfo
|
||||
if rf, ok := ret.Get(0).(func(*http.Request) oauth2.TokenInfo); ok {
|
||||
r0 = rf(r)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(oauth2.TokenInfo)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(*http.Request) error); ok {
|
||||
r1 = rf(r)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
21
internal/oauth/oauth_test.go
Normal file
21
internal/oauth/oauth_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package oauth
|
||||
|
||||
// TODO: write tests
|
254
internal/oauth/server.go
Normal file
254
internal/oauth/server.go
Normal file
|
@ -0,0 +1,254 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
"github.com/superseriousbusiness/oauth2/v4/errors"
|
||||
"github.com/superseriousbusiness/oauth2/v4/manage"
|
||||
"github.com/superseriousbusiness/oauth2/v4/server"
|
||||
)
|
||||
|
||||
const (
|
||||
SessionAuthorizedToken = "authorized_token"
|
||||
// SessionAuthorizedUser is the key set in the gin context for the id of
|
||||
// a User who has successfully passed Bearer token authorization.
|
||||
// The interface returned from grabbing this key should be parsed as a *gtsmodel.User
|
||||
SessionAuthorizedUser = "authorized_user"
|
||||
// SessionAuthorizedAccount is the key set in the gin context for the Account
|
||||
// of a User who has successfully passed Bearer token authorization.
|
||||
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
|
||||
SessionAuthorizedAccount = "authorized_account"
|
||||
// SessionAuthorizedAccount is the key set in the gin context for the Application
|
||||
// of a Client who has successfully passed Bearer token authorization.
|
||||
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application
|
||||
SessionAuthorizedApplication = "authorized_app"
|
||||
)
|
||||
|
||||
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
|
||||
type Server interface {
|
||||
HandleTokenRequest(w http.ResponseWriter, r *http.Request) error
|
||||
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error
|
||||
ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error)
|
||||
GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error)
|
||||
}
|
||||
|
||||
// s fulfils the Server interface using the underlying oauth2 server
|
||||
type s struct {
|
||||
server *server.Server
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
type Authed struct {
|
||||
Token oauth2.TokenInfo
|
||||
Application *model.Application
|
||||
User *model.User
|
||||
Account *model.Account
|
||||
}
|
||||
|
||||
// GetAuthed is a convenience function for returning an Authed struct from a gin context.
|
||||
// In essence, it tries to extract a token, application, user, and account from the context,
|
||||
// and then sets them on a struct for convenience.
|
||||
//
|
||||
// If any are not present in the context, they will be set to nil on the returned Authed struct.
|
||||
//
|
||||
// If *ALL* are not present, then nil and an error will be returned.
|
||||
//
|
||||
// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed).
|
||||
func GetAuthed(c *gin.Context) (*Authed, error) {
|
||||
ctx := c.Copy()
|
||||
a := &Authed{}
|
||||
var i interface{}
|
||||
var ok bool
|
||||
|
||||
i, ok = ctx.Get(SessionAuthorizedToken)
|
||||
if ok {
|
||||
parsed, ok := i.(oauth2.TokenInfo)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse token from session context")
|
||||
}
|
||||
a.Token = parsed
|
||||
}
|
||||
|
||||
i, ok = ctx.Get(SessionAuthorizedApplication)
|
||||
if ok {
|
||||
parsed, ok := i.(*model.Application)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse application from session context")
|
||||
}
|
||||
a.Application = parsed
|
||||
}
|
||||
|
||||
i, ok = ctx.Get(SessionAuthorizedUser)
|
||||
if ok {
|
||||
parsed, ok := i.(*model.User)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse user from session context")
|
||||
}
|
||||
a.User = parsed
|
||||
}
|
||||
|
||||
i, ok = ctx.Get(SessionAuthorizedAccount)
|
||||
if ok {
|
||||
parsed, ok := i.(*model.Account)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse account from session context")
|
||||
}
|
||||
a.Account = parsed
|
||||
}
|
||||
|
||||
if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil {
|
||||
return nil, errors.New("not authorized")
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// MustAuth is like GetAuthed, but will fail if one of the requirements is not met.
|
||||
func MustAuth(c *gin.Context, requireToken bool, requireApp bool, requireUser bool, requireAccount bool) (*Authed, error) {
|
||||
a, err := GetAuthed(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if requireToken && a.Token == nil {
|
||||
return nil, errors.New("token not supplied")
|
||||
}
|
||||
if requireApp && a.Application == nil {
|
||||
return nil, errors.New("application not supplied")
|
||||
}
|
||||
if requireUser && a.User == nil {
|
||||
return nil, errors.New("user not supplied")
|
||||
}
|
||||
if requireAccount && a.Account == nil {
|
||||
return nil, errors.New("account not supplied")
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
|
||||
func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
return s.server.HandleTokenRequest(w, r)
|
||||
}
|
||||
|
||||
// HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function
|
||||
func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
return s.server.HandleAuthorizeRequest(w, r)
|
||||
}
|
||||
|
||||
// ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function
|
||||
func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
|
||||
return s.server.ValidationBearerToken(r)
|
||||
}
|
||||
|
||||
// GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level
|
||||
// bearer token *without* requiring that user to log in. This is useful when we
|
||||
// need to create a token for new users who haven't validated their email or logged in yet.
|
||||
//
|
||||
// The ti parameter refers to an existing Application token that was used to make the upstream
|
||||
// request. This token needs to be validated and exist in database in order to create a new token.
|
||||
func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (oauth2.TokenInfo, error) {
|
||||
|
||||
authToken, err := s.server.Manager.GenerateAuthToken(context.Background(), oauth2.Code, &oauth2.TokenGenerateRequest{
|
||||
ClientID: ti.GetClientID(),
|
||||
ClientSecret: clientSecret,
|
||||
UserID: userID,
|
||||
RedirectURI: ti.GetRedirectURI(),
|
||||
Scope: ti.GetScope(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating auth token: %s", err)
|
||||
}
|
||||
if authToken == nil {
|
||||
return nil, errors.New("generated auth token was empty")
|
||||
}
|
||||
s.log.Tracef("obtained auth token: %+v", authToken)
|
||||
|
||||
accessToken, err := s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, &oauth2.TokenGenerateRequest{
|
||||
ClientID: authToken.GetClientID(),
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURI: authToken.GetRedirectURI(),
|
||||
Scope: authToken.GetScope(),
|
||||
Code: authToken.GetCode(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating user-level access token: %s", err)
|
||||
}
|
||||
if accessToken == nil {
|
||||
return nil, errors.New("generated user-level access token was empty")
|
||||
}
|
||||
s.log.Tracef("obtained user-level access token: %+v", accessToken)
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func New(database db.DB, log *logrus.Logger) Server {
|
||||
ts := newTokenStore(context.Background(), database, log)
|
||||
cs := newClientStore(database)
|
||||
|
||||
manager := manage.NewDefaultManager()
|
||||
manager.MapTokenStorage(ts)
|
||||
manager.MapClientStorage(cs)
|
||||
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
||||
sc := &server.Config{
|
||||
TokenType: "Bearer",
|
||||
// Must follow the spec.
|
||||
AllowGetAccessRequest: false,
|
||||
// Support only the non-implicit flow.
|
||||
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
||||
// Allow:
|
||||
// - Authorization Code (for first & third parties)
|
||||
// - Client Credentials (for applications)
|
||||
AllowedGrantTypes: []oauth2.GrantType{
|
||||
oauth2.AuthorizationCode,
|
||||
oauth2.ClientCredentials,
|
||||
},
|
||||
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
|
||||
}
|
||||
|
||||
srv := server.NewServer(sc, manager)
|
||||
srv.SetInternalErrorHandler(func(err error) *errors.Response {
|
||||
log.Errorf("internal oauth error: %s", err)
|
||||
return nil
|
||||
})
|
||||
|
||||
srv.SetResponseErrorHandler(func(re *errors.Response) {
|
||||
log.Errorf("internal response error: %s", re.Error)
|
||||
})
|
||||
|
||||
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
userID := r.FormValue("userid")
|
||||
if userID == "" {
|
||||
return "", errors.New("userid was empty")
|
||||
}
|
||||
return userID, nil
|
||||
})
|
||||
srv.SetClientInfoHandler(server.ClientFormHandler)
|
||||
return &s{
|
||||
server: srv,
|
||||
log: log,
|
||||
}
|
||||
}
|
|
@ -24,10 +24,10 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
"github.com/gotosocial/oauth2/v4/models"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
"github.com/superseriousbusiness/oauth2/v4/models"
|
||||
)
|
||||
|
||||
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
|
||||
|
@ -70,7 +70,7 @@ func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.Tok
|
|||
func (pts *tokenStore) sweep() error {
|
||||
// select *all* tokens from the db
|
||||
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
|
||||
tokens := new([]*oauthToken)
|
||||
tokens := new([]*Token)
|
||||
if err := pts.db.GetAll(tokens); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ func (pts *tokenStore) sweep() error {
|
|||
}
|
||||
|
||||
// Create creates and store the new token information.
|
||||
// For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34
|
||||
// For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34
|
||||
func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
|
||||
t, ok := info.(*models.Token)
|
||||
if !ok {
|
||||
|
@ -106,22 +106,25 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error
|
|||
|
||||
// RemoveByCode deletes a token from the DB based on the Code field
|
||||
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
|
||||
return pts.db.DeleteWhere("code", code, &oauthToken{})
|
||||
return pts.db.DeleteWhere("code", code, &Token{})
|
||||
}
|
||||
|
||||
// RemoveByAccess deletes a token from the DB based on the Access field
|
||||
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
|
||||
return pts.db.DeleteWhere("access", access, &oauthToken{})
|
||||
return pts.db.DeleteWhere("access", access, &Token{})
|
||||
}
|
||||
|
||||
// RemoveByRefresh deletes a token from the DB based on the Refresh field
|
||||
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
|
||||
return pts.db.DeleteWhere("refresh", refresh, &oauthToken{})
|
||||
return pts.db.DeleteWhere("refresh", refresh, &Token{})
|
||||
}
|
||||
|
||||
// GetByCode selects a token from the DB based on the Code field
|
||||
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
||||
pgt := &oauthToken{
|
||||
if code == "" {
|
||||
return nil, nil
|
||||
}
|
||||
pgt := &Token{
|
||||
Code: code,
|
||||
}
|
||||
if err := pts.db.GetWhere("code", code, pgt); err != nil {
|
||||
|
@ -132,7 +135,10 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token
|
|||
|
||||
// GetByAccess selects a token from the DB based on the Access field
|
||||
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
||||
pgt := &oauthToken{
|
||||
if access == "" {
|
||||
return nil, nil
|
||||
}
|
||||
pgt := &Token{
|
||||
Access: access,
|
||||
}
|
||||
if err := pts.db.GetWhere("access", access, pgt); err != nil {
|
||||
|
@ -143,7 +149,10 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T
|
|||
|
||||
// GetByRefresh selects a token from the DB based on the Refresh field
|
||||
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
||||
pgt := &oauthToken{
|
||||
if refresh == "" {
|
||||
return nil, nil
|
||||
}
|
||||
pgt := &Token{
|
||||
Refresh: refresh,
|
||||
}
|
||||
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
|
||||
|
@ -156,17 +165,17 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
|
|||
The following models are basically helpers for the postgres token store implementation, they should only be used internally.
|
||||
*/
|
||||
|
||||
// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
|
||||
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
|
||||
//
|
||||
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
|
||||
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
|
||||
// then periodically sweep out tokens when that time has passed.
|
||||
//
|
||||
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
|
||||
// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
|
||||
// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
|
||||
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22
|
||||
// and implemented here: https://github.com/superseriousbusiness/oauth2/blob/master/models/token.go.
|
||||
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
|
||||
// and pgTokenToOauthToken can be used for that.
|
||||
type oauthToken struct {
|
||||
type Token struct {
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
||||
ClientID string
|
||||
UserID string
|
||||
|
@ -186,7 +195,7 @@ type oauthToken struct {
|
|||
}
|
||||
|
||||
// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
|
||||
func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
||||
func oauthTokenToPGToken(tkn *models.Token) *Token {
|
||||
now := time.Now()
|
||||
|
||||
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
|
||||
|
@ -208,7 +217,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
|||
rea = now.Add(tkn.RefreshExpiresIn)
|
||||
}
|
||||
|
||||
return &oauthToken{
|
||||
return &Token{
|
||||
ClientID: tkn.ClientID,
|
||||
UserID: tkn.UserID,
|
||||
RedirectURI: tkn.RedirectURI,
|
||||
|
@ -228,7 +237,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
|||
}
|
||||
|
||||
// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
|
||||
func pgTokenToOauthToken(pgt *oauthToken) *models.Token {
|
||||
func pgTokenToOauthToken(pgt *Token) *models.Token {
|
||||
now := time.Now()
|
||||
|
||||
return &models.Token{
|
21
internal/oauth/tokenstore_test.go
Normal file
21
internal/oauth/tokenstore_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package oauth
|
||||
|
||||
// TODO: write tests
|
44
internal/router/mock_Router.go
Normal file
44
internal/router/mock_Router.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package router
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
gin "github.com/gin-gonic/gin"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockRouter is an autogenerated mock type for the Router type
|
||||
type MockRouter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// AttachHandler provides a mock function with given fields: method, path, f
|
||||
func (_m *MockRouter) AttachHandler(method string, path string, f gin.HandlerFunc) {
|
||||
_m.Called(method, path, f)
|
||||
}
|
||||
|
||||
// AttachMiddleware provides a mock function with given fields: handler
|
||||
func (_m *MockRouter) AttachMiddleware(handler gin.HandlerFunc) {
|
||||
_m.Called(handler)
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields:
|
||||
func (_m *MockRouter) Start() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// Stop provides a mock function with given fields: ctx
|
||||
func (_m *MockRouter) Stop(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
|
@ -19,62 +19,66 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/memstore"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
// Router provides the REST interface for gotosocial, using gin.
|
||||
type Router interface {
|
||||
// Attach a gin handler to the router with the given method and path
|
||||
AttachHandler(method string, path string, handler gin.HandlerFunc)
|
||||
AttachHandler(method string, path string, f gin.HandlerFunc)
|
||||
// Attach a gin middleware to the router that will be used globally
|
||||
AttachMiddleware(handler gin.HandlerFunc)
|
||||
// Start the router
|
||||
Start()
|
||||
// Stop the router
|
||||
Stop()
|
||||
Stop(ctx context.Context) error
|
||||
}
|
||||
|
||||
// router fulfils the Router interface using gin and logrus
|
||||
type router struct {
|
||||
logger *logrus.Logger
|
||||
engine *gin.Engine
|
||||
srv *http.Server
|
||||
}
|
||||
|
||||
// Start starts the router nicely
|
||||
func (s *router) Start() {
|
||||
// todo: start gracefully
|
||||
if err := s.engine.Run(); err != nil {
|
||||
s.logger.Panicf("server error: %s", err)
|
||||
}
|
||||
func (r *router) Start() {
|
||||
go func() {
|
||||
if err := r.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
r.logger.Fatalf("listen: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop shuts down the router nicely
|
||||
func (s *router) Stop() {
|
||||
// todo: shut down gracefully
|
||||
func (r *router) Stop(ctx context.Context) error {
|
||||
return r.srv.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path.
|
||||
// If the path is set to ANY, then the handlerfunc will be used for ALL methods at its given path.
|
||||
func (s *router) AttachHandler(method string, path string, handler gin.HandlerFunc) {
|
||||
func (r *router) AttachHandler(method string, path string, handler gin.HandlerFunc) {
|
||||
if method == "ANY" {
|
||||
s.engine.Any(path, handler)
|
||||
r.engine.Any(path, handler)
|
||||
} else {
|
||||
s.engine.Handle(method, path, handler)
|
||||
r.engine.Handle(method, path, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// AttachMiddleware attaches a gin middleware to the router that will be used globally
|
||||
func (s *router) AttachMiddleware(middleware gin.HandlerFunc) {
|
||||
s.engine.Use(middleware)
|
||||
func (r *router) AttachMiddleware(middleware gin.HandlerFunc) {
|
||||
r.engine.Use(middleware)
|
||||
}
|
||||
|
||||
// New returns a new Router with the specified configuration, using the given logrus logger.
|
||||
|
@ -100,6 +104,10 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) {
|
|||
return &router{
|
||||
logger: logger,
|
||||
engine: engine,
|
||||
srv: &http.Server{
|
||||
Addr: ":8080",
|
||||
Handler: engine,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
31
internal/storage/inmem.go
Normal file
31
internal/storage/inmem.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
func NewInMem(c *config.Config, log *logrus.Logger) (Storage, error) {
|
||||
return &inMemStorage{
|
||||
stored: make(map[string][]byte),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type inMemStorage struct {
|
||||
stored map[string][]byte
|
||||
}
|
||||
|
||||
func (s *inMemStorage) StoreFileAt(path string, data []byte) error {
|
||||
s.stored[path] = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *inMemStorage) RetrieveFileFrom(path string) ([]byte, error) {
|
||||
d, ok := s.stored[path]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no data found at path %s", path)
|
||||
}
|
||||
return d, nil
|
||||
}
|
21
internal/storage/local.go
Normal file
21
internal/storage/local.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
func NewLocal(c *config.Config, log *logrus.Logger) (Storage, error) {
|
||||
return &localStorage{}, nil
|
||||
}
|
||||
|
||||
type localStorage struct {
|
||||
}
|
||||
|
||||
func (s *localStorage) StoreFileAt(path string, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *localStorage) RetrieveFileFrom(path string) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
47
internal/storage/mock_Storage.go
Normal file
47
internal/storage/mock_Storage.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
// Code generated by mockery v2.7.4. DO NOT EDIT.
|
||||
|
||||
package storage
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockStorage is an autogenerated mock type for the Storage type
|
||||
type MockStorage struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// RetrieveFileFrom provides a mock function with given fields: path
|
||||
func (_m *MockStorage) RetrieveFileFrom(path string) ([]byte, error) {
|
||||
ret := _m.Called(path)
|
||||
|
||||
var r0 []byte
|
||||
if rf, ok := ret.Get(0).(func(string) []byte); ok {
|
||||
r0 = rf(path)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(path)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// StoreFileAt provides a mock function with given fields: path, data
|
||||
func (_m *MockStorage) StoreFileAt(path string, data []byte) error {
|
||||
ret := _m.Called(path, data)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, []byte) error); ok {
|
||||
r0 = rf(path, data)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
24
internal/storage/storage.go
Normal file
24
internal/storage/storage.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package storage
|
||||
|
||||
type Storage interface {
|
||||
StoreFileAt(path string, data []byte) error
|
||||
RetrieveFileFrom(path string) ([]byte, error)
|
||||
}
|
32
internal/util/parse.go
Normal file
32
internal/util/parse.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package util
|
||||
|
||||
import "fmt"
|
||||
|
||||
type URIs struct {
|
||||
HostURL string
|
||||
UserURL string
|
||||
UserURI string
|
||||
InboxURL string
|
||||
OutboxURL string
|
||||
FollowersURL string
|
||||
CollectionURL string
|
||||
}
|
||||
|
||||
func GenerateURIs(username string, protocol string, host string) *URIs {
|
||||
hostURL := fmt.Sprintf("%s://%s", protocol, host)
|
||||
userURL := fmt.Sprintf("%s/@%s", hostURL, username)
|
||||
userURI := fmt.Sprintf("%s/users/%s", hostURL, username)
|
||||
inboxURL := fmt.Sprintf("%s/inbox", userURI)
|
||||
outboxURL := fmt.Sprintf("%s/outbox", userURI)
|
||||
followersURL := fmt.Sprintf("%s/followers", userURI)
|
||||
collectionURL := fmt.Sprintf("%s/collections/featured", userURI)
|
||||
return &URIs{
|
||||
HostURL: hostURL,
|
||||
UserURL: userURL,
|
||||
UserURI: userURI,
|
||||
InboxURL: inboxURL,
|
||||
OutboxURL: outboxURL,
|
||||
FollowersURL: followersURL,
|
||||
CollectionURL: collectionURL,
|
||||
}
|
||||
}
|
144
internal/util/validation.go
Normal file
144
internal/util/validation.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
|
||||
pwv "github.com/wagslane/go-password-validator"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
const (
|
||||
// MinimumPasswordEntropy dictates password strength. See https://github.com/wagslane/go-password-validator
|
||||
MinimumPasswordEntropy = 60
|
||||
// MinimumReasonLength is the length of chars we expect as a bare minimum effort
|
||||
MinimumReasonLength = 40
|
||||
// MaximumReasonLength is the maximum amount of chars we're happy to accept
|
||||
MaximumReasonLength = 500
|
||||
// MaximumEmailLength is the maximum length of an email address we're happy to accept
|
||||
MaximumEmailLength = 256
|
||||
// MaximumUsernameLength is the maximum length of a username we're happy to accept
|
||||
MaximumUsernameLength = 64
|
||||
// MaximumPasswordLength is the maximum length of a password we're happy to accept
|
||||
MaximumPasswordLength = 64
|
||||
// NewUsernameRegexString is string representation of the regular expression for validating usernames
|
||||
NewUsernameRegexString = `^[a-z0-9_]+$`
|
||||
)
|
||||
|
||||
var (
|
||||
// NewUsernameRegex is the compiled regex for validating new usernames
|
||||
NewUsernameRegex = regexp.MustCompile(NewUsernameRegexString)
|
||||
)
|
||||
|
||||
// ValidateNewPassword returns an error if the given password is not sufficiently strong, or nil if it's ok.
|
||||
func ValidateNewPassword(password string) error {
|
||||
if password == "" {
|
||||
return errors.New("no password provided")
|
||||
}
|
||||
|
||||
if len(password) > MaximumPasswordLength {
|
||||
return fmt.Errorf("password should be no more than %d chars", MaximumPasswordLength)
|
||||
}
|
||||
|
||||
return pwv.Validate(password, MinimumPasswordEntropy)
|
||||
}
|
||||
|
||||
// ValidateUsername makes sure that a given username is valid (ie., letters, numbers, underscores, check length).
|
||||
// Returns an error if not.
|
||||
func ValidateUsername(username string) error {
|
||||
if username == "" {
|
||||
return errors.New("no username provided")
|
||||
}
|
||||
|
||||
if len(username) > MaximumUsernameLength {
|
||||
return fmt.Errorf("username should be no more than %d chars but '%s' was %d", MaximumUsernameLength, username, len(username))
|
||||
}
|
||||
|
||||
if !NewUsernameRegex.MatchString(username) {
|
||||
return fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateEmail makes sure that a given email address is a valid address.
|
||||
// Returns an error if not.
|
||||
func ValidateEmail(email string) error {
|
||||
if email == "" {
|
||||
return errors.New("no email provided")
|
||||
}
|
||||
|
||||
if len(email) > MaximumEmailLength {
|
||||
return fmt.Errorf("email address should be no more than %d chars but '%s' was %d", MaximumEmailLength, email, len(email))
|
||||
}
|
||||
|
||||
_, err := mail.ParseAddress(email)
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateLanguage checks that the given language string is a 2- or 3-letter ISO 639 code.
|
||||
// Returns an error if the language cannot be parsed. See: https://pkg.go.dev/golang.org/x/text/language
|
||||
func ValidateLanguage(lang string) error {
|
||||
if lang == "" {
|
||||
return errors.New("no language provided")
|
||||
}
|
||||
_, err := language.ParseBase(lang)
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateSignUpReason checks that a sufficient reason is given for a server signup request
|
||||
func ValidateSignUpReason(reason string, reasonRequired bool) error {
|
||||
if !reasonRequired {
|
||||
// we don't care!
|
||||
// we're not going to do anything with this text anyway if no reason is required
|
||||
return nil
|
||||
}
|
||||
|
||||
if reason == "" {
|
||||
return errors.New("no reason provided")
|
||||
}
|
||||
|
||||
if len(reason) < MinimumReasonLength {
|
||||
return fmt.Errorf("reason should be at least %d chars but '%s' was %d", MinimumReasonLength, reason, len(reason))
|
||||
}
|
||||
|
||||
if len(reason) > MaximumReasonLength {
|
||||
return fmt.Errorf("reason should be no more than %d chars but given reason was %d", MaximumReasonLength, len(reason))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateDisplayName(displayName string) error {
|
||||
// TODO: add some validation logic here -- length, characters, etc
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateNote(note string) error {
|
||||
// TODO: add some validation logic here -- length, characters, etc
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidatePrivacy(privacy string) error {
|
||||
// TODO: add some validation logic here -- length, characters, etc
|
||||
return nil
|
||||
}
|
288
internal/util/validation_test.go
Normal file
288
internal/util/validation_test.go
Normal file
|
@ -0,0 +1,288 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ValidationTestSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (suite *ValidationTestSuite) TestCheckPasswordStrength() {
|
||||
empty := ""
|
||||
terriblePassword := "password"
|
||||
weakPassword := "OKPassword"
|
||||
shortPassword := "Ok12"
|
||||
specialPassword := "Ok12%"
|
||||
longPassword := "thisisafuckinglongpasswordbutnospecialchars"
|
||||
tooLong := "Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Quisque a enim nibh. Vestibulum bibendum leo ac porttitor auctor."
|
||||
strongPassword := "3dX5@Zc%mV*W2MBNEy$@"
|
||||
var err error
|
||||
|
||||
err = ValidateNewPassword(empty)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("no password provided"), err)
|
||||
}
|
||||
|
||||
err = ValidateNewPassword(terriblePassword)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("insecure password, try including more special characters, using uppercase letters, using numbers or using a longer password"), err)
|
||||
}
|
||||
|
||||
err = ValidateNewPassword(weakPassword)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("insecure password, try including more special characters, using numbers or using a longer password"), err)
|
||||
}
|
||||
|
||||
err = ValidateNewPassword(shortPassword)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("insecure password, try including more special characters or using a longer password"), err)
|
||||
}
|
||||
|
||||
err = ValidateNewPassword(specialPassword)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("insecure password, try including more special characters or using a longer password"), err)
|
||||
}
|
||||
|
||||
err = ValidateNewPassword(longPassword)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateNewPassword(tooLong)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("password should be no more than 64 chars"), err)
|
||||
}
|
||||
|
||||
err = ValidateNewPassword(strongPassword)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *ValidationTestSuite) TestValidateUsername() {
|
||||
empty := ""
|
||||
tooLong := "holycrapthisisthelongestusernameiveeverseeninmylifethatstoomuchman"
|
||||
withSpaces := "this username has spaces in it"
|
||||
weirdChars := "thisusername&&&&&&&istooweird!!"
|
||||
leadingSpace := " see_that_leading_space"
|
||||
trailingSpace := "thisusername_ends_with_a_space "
|
||||
newlines := "this_is\n_almost_ok"
|
||||
goodUsername := "this_is_a_good_username"
|
||||
var err error
|
||||
|
||||
err = ValidateUsername(empty)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("no username provided"), err)
|
||||
}
|
||||
|
||||
err = ValidateUsername(tooLong)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), fmt.Errorf("username should be no more than 64 chars but '%s' was 66", tooLong), err)
|
||||
}
|
||||
|
||||
err = ValidateUsername(withSpaces)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", withSpaces), err)
|
||||
}
|
||||
|
||||
err = ValidateUsername(weirdChars)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", weirdChars), err)
|
||||
}
|
||||
|
||||
err = ValidateUsername(leadingSpace)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", leadingSpace), err)
|
||||
}
|
||||
|
||||
err = ValidateUsername(trailingSpace)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", trailingSpace), err)
|
||||
}
|
||||
|
||||
err = ValidateUsername(newlines)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", newlines), err)
|
||||
}
|
||||
|
||||
err = ValidateUsername(goodUsername)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *ValidationTestSuite) TestValidateEmail() {
|
||||
empty := ""
|
||||
notAnEmailAddress := "this-is-no-email-address!"
|
||||
almostAnEmailAddress := "@thisisalmostan@email.address"
|
||||
aWebsite := "https://thisisawebsite.com"
|
||||
tooLong := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaahhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhggggggggggggggggggggggggggggggggggggggghhhhhhhhhhhhhhhhhggggggggggggggggggggghhhhhhhhhhhhhhhhhhhhhhhhhhhhhh@gmail.com"
|
||||
emailAddress := "thisis.actually@anemail.address"
|
||||
var err error
|
||||
|
||||
err = ValidateEmail(empty)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("no email provided"), err)
|
||||
}
|
||||
|
||||
err = ValidateEmail(notAnEmailAddress)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("mail: missing '@' or angle-addr"), err)
|
||||
}
|
||||
|
||||
err = ValidateEmail(almostAnEmailAddress)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("mail: no angle-addr"), err)
|
||||
}
|
||||
|
||||
err = ValidateEmail(aWebsite)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("mail: missing '@' or angle-addr"), err)
|
||||
}
|
||||
|
||||
err = ValidateEmail(tooLong)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), fmt.Errorf("email address should be no more than 256 chars but '%s' was 286", tooLong), err)
|
||||
}
|
||||
|
||||
err = ValidateEmail(emailAddress)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *ValidationTestSuite) TestValidateLanguage() {
|
||||
empty := ""
|
||||
notALanguage := "this isn't a language at all!"
|
||||
english := "en"
|
||||
capitalEnglish := "EN"
|
||||
arabic3Letters := "ara"
|
||||
mixedCapsEnglish := "eN"
|
||||
englishUS := "en-us"
|
||||
dutch := "nl"
|
||||
german := "de"
|
||||
var err error
|
||||
|
||||
err = ValidateLanguage(empty)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("no language provided"), err)
|
||||
}
|
||||
|
||||
err = ValidateLanguage(notALanguage)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("language: tag is not well-formed"), err)
|
||||
}
|
||||
|
||||
err = ValidateLanguage(english)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateLanguage(capitalEnglish)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateLanguage(arabic3Letters)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateLanguage(mixedCapsEnglish)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateLanguage(englishUS)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("language: tag is not well-formed"), err)
|
||||
}
|
||||
|
||||
err = ValidateLanguage(dutch)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateLanguage(german)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *ValidationTestSuite) TestValidateReason() {
|
||||
empty := ""
|
||||
badReason := "because"
|
||||
goodReason := "to smash the state and destroy capitalism ultimately and completely"
|
||||
tooLong := "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Mauris auctor mollis viverra. Maecenas maximus mollis sem, nec fermentum velit consectetur non. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Quisque a enim nibh. Vestibulum bibendum leo ac porttitor auctor. Curabitur velit tellus, facilisis vitae lorem a, ullamcorper efficitur leo. Sed a auctor tortor. Sed ut finibus ante, sit amet laoreet sapien. Donec ullamcorper tellus a nibh sodales vulputate. Donec id dolor eu odio mollis bibendum. Pellentesque habitant morbi tristique senectus et netus at."
|
||||
var err error
|
||||
|
||||
// check with no reason required
|
||||
err = ValidateSignUpReason(empty, false)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateSignUpReason(badReason, false)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateSignUpReason(tooLong, false)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
err = ValidateSignUpReason(goodReason, false)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
|
||||
// check with reason required
|
||||
err = ValidateSignUpReason(empty, true)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("no reason provided"), err)
|
||||
}
|
||||
|
||||
err = ValidateSignUpReason(badReason, true)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("reason should be at least 40 chars but 'because' was 7"), err)
|
||||
}
|
||||
|
||||
err = ValidateSignUpReason(tooLong, true)
|
||||
if assert.Error(suite.T(), err) {
|
||||
assert.Equal(suite.T(), errors.New("reason should be no more than 500 chars but given reason was 600"), err)
|
||||
}
|
||||
|
||||
err = ValidateSignUpReason(goodReason, true)
|
||||
if assert.NoError(suite.T(), err) {
|
||||
assert.Equal(suite.T(), nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(ValidationTestSuite))
|
||||
}
|
|
@ -18,6 +18,8 @@
|
|||
|
||||
package mastotypes
|
||||
|
||||
import "mime/multipart"
|
||||
|
||||
// Account represents a mastodon-api Account object, as described here: https://docs.joinmastodon.org/entities/account/
|
||||
type Account struct {
|
||||
// The account id
|
||||
|
@ -31,7 +33,7 @@ type Account struct {
|
|||
// Whether the account manually approves follow requests.
|
||||
Locked bool `json:"locked"`
|
||||
// Whether the account has opted into discovery features such as the profile directory.
|
||||
Discoverable bool `json:"discoverable"`
|
||||
Discoverable bool `json:"discoverable,omitempty"`
|
||||
// A presentational flag. Indicates that the account may perform automated actions, may not be monitored, or identifies as a robot.
|
||||
Bot bool `json:"bot"`
|
||||
// When the account was created. (ISO 8601 Datetime)
|
||||
|
@ -61,9 +63,69 @@ type Account struct {
|
|||
// Additional metadata attached to a profile as name-value pairs.
|
||||
Fields []Field `json:"fields"`
|
||||
// An extra entity returned when an account is suspended.
|
||||
Suspended bool `json:"suspended"`
|
||||
Suspended bool `json:"suspended,omitempty"`
|
||||
// When a timed mute will expire, if applicable. (ISO 8601 Datetime)
|
||||
MuteExpiresAt string `json:"mute_expires_at"`
|
||||
MuteExpiresAt string `json:"mute_expires_at,omitempty"`
|
||||
// An extra entity to be used with API methods to verify credentials and update credentials.
|
||||
Source *Source `json:"source"`
|
||||
}
|
||||
|
||||
// AccountCreateRequest represents the form submitted during a POST request to /api/v1/accounts.
|
||||
// See https://docs.joinmastodon.org/methods/accounts/
|
||||
type AccountCreateRequest struct {
|
||||
// Text that will be reviewed by moderators if registrations require manual approval.
|
||||
Reason string `form:"reason"`
|
||||
// The desired username for the account
|
||||
Username string `form:"username" binding:"required"`
|
||||
// The email address to be used for login
|
||||
Email string `form:"email" binding:"required"`
|
||||
// The password to be used for login
|
||||
Password string `form:"password" binding:"required"`
|
||||
// Whether the user agrees to the local rules, terms, and policies.
|
||||
// These should be presented to the user in order to allow them to consent before setting this parameter to TRUE.
|
||||
Agreement bool `form:"agreement" binding:"required"`
|
||||
// The language of the confirmation email that will be sent
|
||||
Locale string `form:"locale" binding:"required"`
|
||||
}
|
||||
|
||||
// UpdateCredentialsRequest represents the form submitted during a PATCH request to /api/v1/accounts/update_credentials.
|
||||
// See https://docs.joinmastodon.org/methods/accounts/
|
||||
type UpdateCredentialsRequest struct {
|
||||
// Whether the account should be shown in the profile directory.
|
||||
Discoverable *bool `form:"discoverable"`
|
||||
// Whether the account has a bot flag.
|
||||
Bot *bool `form:"bot"`
|
||||
// The display name to use for the profile.
|
||||
DisplayName *string `form:"display_name"`
|
||||
// The account bio.
|
||||
Note *string `form:"note"`
|
||||
// Avatar image encoded using multipart/form-data
|
||||
Avatar *multipart.FileHeader `form:"avatar"`
|
||||
// Header image encoded using multipart/form-data
|
||||
Header *multipart.FileHeader `form:"header"`
|
||||
// Whether manual approval of follow requests is required.
|
||||
Locked *bool `form:"locked"`
|
||||
// New Source values for this account
|
||||
Source *UpdateSource `form:"source"`
|
||||
// Profile metadata name and value
|
||||
FieldsAttributes *[]UpdateField `form:"fields_attributes"`
|
||||
}
|
||||
|
||||
// UpdateSource is to be used specifically in an UpdateCredentialsRequest.
|
||||
type UpdateSource struct {
|
||||
// Default post privacy for authored statuses.
|
||||
Privacy *string `form:"privacy"`
|
||||
// Whether to mark authored statuses as sensitive by default.
|
||||
Sensitive *bool `form:"sensitive"`
|
||||
// Default language to use for authored statuses. (ISO 6391)
|
||||
Language *string `form:"language"`
|
||||
}
|
||||
|
||||
// UpdateField is to be used specifically in an UpdateCredentialsRequest.
|
||||
// By default, max 4 fields and 255 characters per property/value.
|
||||
type UpdateField struct {
|
||||
// Name of the field
|
||||
Name *string `form:"name"`
|
||||
// Value of the field
|
||||
Value *string `form:"value"`
|
||||
}
|
||||
|
|
|
@ -43,11 +43,11 @@ type Application struct {
|
|||
// And here: https://docs.joinmastodon.org/client/token/
|
||||
type ApplicationPOSTRequest struct {
|
||||
// A name for your application
|
||||
ClientName string `form:"client_name"`
|
||||
ClientName string `form:"client_name" binding:"required"`
|
||||
// Where the user should be redirected after authorization.
|
||||
// To display the authorization code to the user instead of redirecting
|
||||
// to a web page, use urn:ietf:wg:oauth:2.0:oob in this parameter.
|
||||
RedirectURIs string `form:"redirect_uris"`
|
||||
RedirectURIs string `form:"redirect_uris" binding:"required"`
|
||||
// Space separated list of scopes. If none is provided, defaults to read.
|
||||
Scopes string `form:"scopes"`
|
||||
// A URL to the homepage of your app
|
||||
|
|
|
@ -28,7 +28,6 @@ type Field struct {
|
|||
Value string `json:"value"`
|
||||
|
||||
// OPTIONAL
|
||||
|
||||
// Timestamp of when the server verified a URL value for a rel="me” link. String (ISO 8601 Datetime) if value is a verified URL
|
||||
VerifiedAt string `json:"verified_at,omitempty"`
|
||||
}
|
||||
|
|
|
@ -18,5 +18,24 @@
|
|||
|
||||
package mastotypes
|
||||
|
||||
// Source represents display or publishing preferences of user's own account.
|
||||
// Returned as an additional entity when verifying and updated credentials, as an attribute of Account.
|
||||
// See https://docs.joinmastodon.org/entities/source/
|
||||
type Source struct {
|
||||
// The default post privacy to be used for new statuses.
|
||||
// public = Public post
|
||||
// unlisted = Unlisted post
|
||||
// private = Followers-only post
|
||||
// direct = Direct post
|
||||
Privacy string `json:"privacy,omitempty"`
|
||||
// Whether new statuses should be marked sensitive by default.
|
||||
Sensitive bool `json:"sensitive,omitempty"`
|
||||
// The default posting language for new statuses.
|
||||
Language string `json:"language,omitempty"`
|
||||
// Profile bio.
|
||||
Note string `json:"note"`
|
||||
// Metadata about the account.
|
||||
Fields []Field `json:"fields"`
|
||||
// The number of pending follow requests.
|
||||
FollowRequestsCount int `json:"follow_requests_count,omitempty"`
|
||||
}
|
||||
|
|
|
@ -18,5 +18,6 @@
|
|||
|
||||
package mastotypes
|
||||
|
||||
// Tag represents a hashtag used within the content of a status. See https://docs.joinmastodon.org/entities/tag/
|
||||
type Tag struct {
|
||||
}
|
||||
|
|
31
pkg/mastotypes/token.go
Normal file
31
pkg/mastotypes/token.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package mastotypes
|
||||
|
||||
// Token represents an OAuth token used for authenticating with the API and performing actions.. See https://docs.joinmastodon.org/entities/token/
|
||||
type Token struct {
|
||||
// An OAuth token to be used for authorization.
|
||||
AccessToken string `json:"access_token"`
|
||||
// The OAuth token type. Mastodon uses Bearer tokens.
|
||||
TokenType string `json:"token_type"`
|
||||
// The OAuth scopes granted by this token, space-separated.
|
||||
Scope string `json:"scope"`
|
||||
// When the token was generated. (UNIX timestamp seconds)
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
33
scripts/auth_flow.sh
Executable file
33
scripts/auth_flow.sh
Executable file
|
@ -0,0 +1,33 @@
|
|||
#!/bin/sh
|
||||
|
||||
set -eux
|
||||
|
||||
SERVER_URL="http://localhost:8080"
|
||||
REDIRECT_URI="${SERVER_URL}"
|
||||
CLIENT_NAME="Test Application Name"
|
||||
|
||||
REGISTRATION_REASON="Testing whether or not this dang diggity thing works!"
|
||||
REGISTRATION_EMAIL="test@example.org"
|
||||
REGISTRATION_USERNAME="test_user"
|
||||
REGISTRATION_PASSWORD="very safe password 123"
|
||||
REGISTRATION_AGREEMENT="true"
|
||||
REGISTRATION_LOCALE="en"
|
||||
|
||||
# Step 1: create the app to register the new account
|
||||
CREATE_APP_RESPONSE=$(curl --fail -s -X POST -F "client_name=${CLIENT_NAME}" -F "redirect_uris=${REDIRECT_URI}" "${SERVER_URL}/api/v1/apps")
|
||||
CLIENT_ID=$(echo "${CREATE_APP_RESPONSE}" | jq -r .client_id)
|
||||
CLIENT_SECRET=$(echo "${CREATE_APP_RESPONSE}" | jq -r .client_secret)
|
||||
echo "Obtained client_id: ${CLIENT_ID} and client_secret: ${CLIENT_SECRET}"
|
||||
|
||||
# Step 2: obtain a code for that app
|
||||
APP_CODE_RESPONSE=$(curl --fail -s -X POST -F "scope=read" -F "grant_type=client_credentials" -F "client_id=${CLIENT_ID}" -F "client_secret=${CLIENT_SECRET}" -F "redirect_uri=${REDIRECT_URI}" "${SERVER_URL}/oauth/token")
|
||||
APP_ACCESS_TOKEN=$(echo "${APP_CODE_RESPONSE}" | jq -r .access_token)
|
||||
echo "Obtained app access token: ${APP_ACCESS_TOKEN}"
|
||||
|
||||
# Step 3: use the code to register a new account
|
||||
ACCOUNT_REGISTER_RESPONSE=$(curl --fail -s -H "Authorization: Bearer ${APP_ACCESS_TOKEN}" -F "reason=${REGISTRATION_REASON}" -F "email=${REGISTRATION_EMAIL}" -F "username=${REGISTRATION_USERNAME}" -F "password=${REGISTRATION_PASSWORD}" -F "agreement=${REGISTRATION_AGREEMENT}" -F "locale=${REGISTRATION_LOCALE}" "${SERVER_URL}/api/v1/accounts")
|
||||
USER_ACCESS_TOKEN=$(echo "${ACCOUNT_REGISTER_RESPONSE}" | jq -r .access_token)
|
||||
echo "Obtained user access token: ${USER_ACCESS_TOKEN}"
|
||||
|
||||
# # Step 4: verify the returned access token
|
||||
curl -s -H "Authorization: Bearer ${USER_ACCESS_TOKEN}" "${SERVER_URL}/api/v1/accounts/verify_credentials" | jq
|
Loading…
Reference in a new issue