mirror of
https://github.com/matrix-org/dendrite
synced 2025-01-19 00:13:59 +00:00
Add context to the account database (#232)
This commit is contained in:
parent
5ada8872bb
commit
e28ee27605
21 changed files with 267 additions and 138 deletions
|
@ -15,6 +15,7 @@
|
||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -70,17 +71,22 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) insertAccountData(localpart string, roomID string, dataType string, content string) (err error) {
|
func (s *accountDataStatements) insertAccountData(
|
||||||
_, err = s.insertAccountDataStmt.Exec(localpart, roomID, dataType, content)
|
ctx context.Context, localpart, roomID, dataType, content string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := s.insertAccountDataStmt
|
||||||
|
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountData(localpart string) (
|
func (s *accountDataStatements) selectAccountData(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (
|
||||||
global []gomatrixserverlib.ClientEvent,
|
global []gomatrixserverlib.ClientEvent,
|
||||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
rows, err := s.selectAccountDataStmt.Query(localpart)
|
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -93,7 +99,7 @@ func (s *accountDataStatements) selectAccountData(localpart string) (
|
||||||
var dataType string
|
var dataType string
|
||||||
var content []byte
|
var content []byte
|
||||||
|
|
||||||
if err = rows.Scan(&roomID, &dataType, &content); err != nil && err != sql.ErrNoRows {
|
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,11 +119,12 @@ func (s *accountDataStatements) selectAccountData(localpart string) (
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountDataByType(
|
func (s *accountDataStatements) selectAccountDataByType(
|
||||||
localpart string, roomID string, dataType string,
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
) (data []gomatrixserverlib.ClientEvent, err error) {
|
) (data []gomatrixserverlib.ClientEvent, err error) {
|
||||||
data = []gomatrixserverlib.ClientEvent{}
|
data = []gomatrixserverlib.ClientEvent{}
|
||||||
|
|
||||||
rows, err := s.selectAccountDataByTypeStmt.Query(localpart, roomID, dataType)
|
stmt := s.selectAccountDataByTypeStmt
|
||||||
|
rows, err := stmt.QueryContext(ctx, localpart, roomID, dataType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
@ -76,26 +77,34 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) insertAccount(localpart, hash string) (acc *authtypes.Account, err error) {
|
func (s *accountsStatements) insertAccount(
|
||||||
|
ctx context.Context, localpart, hash string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
if _, err = s.insertAccountStmt.Exec(localpart, createdTimeMS, hash); err == nil {
|
stmt := s.insertAccountStmt
|
||||||
acc = &authtypes.Account{
|
if _, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &authtypes.Account{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
UserID: makeUserID(localpart, s.serverName),
|
UserID: makeUserID(localpart, s.serverName),
|
||||||
ServerName: s.serverName,
|
ServerName: s.serverName,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
func (s *accountsStatements) selectPasswordHash(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (hash string, err error) {
|
||||||
|
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectPasswordHash(localpart string) (hash string, err error) {
|
func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
err = s.selectPasswordHashStmt.QueryRow(localpart).Scan(&hash)
|
ctx context.Context, localpart string,
|
||||||
return
|
) (*authtypes.Account, error) {
|
||||||
}
|
|
||||||
|
|
||||||
func (s *accountsStatements) selectAccountByLocalpart(localpart string) (*authtypes.Account, error) {
|
|
||||||
var acc authtypes.Account
|
var acc authtypes.Account
|
||||||
err := s.selectAccountByLocalpartStmt.QueryRow(localpart).Scan(&acc.Localpart)
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
|
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
acc.UserID = makeUserID(localpart, s.serverName)
|
acc.UserID = makeUserID(localpart, s.serverName)
|
||||||
acc.ServerName = s.serverName
|
acc.ServerName = s.serverName
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
@ -80,18 +81,27 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string, txn *sql.Tx) (err error) {
|
func (s *membershipStatements) insertMembership(
|
||||||
_, err = txn.Stmt(s.insertMembershipStmt).Exec(localpart, roomID, eventID)
|
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := txn.Stmt(s.insertMembershipStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, localpart, roomID, eventID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) (err error) {
|
func (s *membershipStatements) deleteMembershipsByEventIDs(
|
||||||
_, err = txn.Stmt(s.deleteMembershipsByEventIDsStmt).Exec(pq.StringArray(eventIDs))
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) {
|
func (s *membershipStatements) selectMembershipsByLocalpart(
|
||||||
rows, err := s.selectMembershipsByLocalpartStmt.Query(localpart)
|
ctx context.Context, localpart string,
|
||||||
|
) (memberships []authtypes.Membership, err error) {
|
||||||
|
stmt := s.selectMembershipsByLocalpartStmt
|
||||||
|
rows, err := stmt.QueryContext(ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -111,7 +121,11 @@ func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (m
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) updateMembershipByEventID(oldEventID string, newEventID string) (err error) {
|
func (s *membershipStatements) updateMembershipByEventID(
|
||||||
_, err = s.updateMembershipByEventIDStmt.Exec(oldEventID, newEventID)
|
ctx context.Context, oldEventID string, newEventID string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.updateMembershipByEventIDStmt.ExecContext(
|
||||||
|
ctx, oldEventID, newEventID,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
@ -71,23 +72,36 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) insertProfile(localpart string) (err error) {
|
func (s *profilesStatements) insertProfile(
|
||||||
_, err = s.insertProfileStmt.Exec(localpart, "", "")
|
ctx context.Context, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) selectProfileByLocalpart(localpart string) (*authtypes.Profile, error) {
|
func (s *profilesStatements) selectProfileByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
err := s.selectProfileByLocalpartStmt.QueryRow(localpart).Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL)
|
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||||
return &profile, err
|
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &profile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) setAvatarURL(localpart string, avatarURL string) (err error) {
|
func (s *profilesStatements) setAvatarURL(
|
||||||
_, err = s.setAvatarURLStmt.Exec(avatarURL, localpart)
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) setDisplayName(localpart string, displayName string) (err error) {
|
func (s *profilesStatements) setDisplayName(
|
||||||
_, err = s.setDisplayNameStmt.Exec(displayName, localpart)
|
ctx context.Context, localpart string, displayName string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
@ -74,46 +75,56 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
func (d *Database) GetAccountByPassword(localpart, plaintextPassword string) (*authtypes.Account, error) {
|
func (d *Database) GetAccountByPassword(
|
||||||
hash, err := d.accounts.selectPasswordHash(localpart)
|
ctx context.Context, localpart, plaintextPassword string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
|
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d.accounts.selectAccountByLocalpart(localpart)
|
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
func (d *Database) GetProfileByLocalpart(localpart string) (*authtypes.Profile, error) {
|
func (d *Database) GetProfileByLocalpart(
|
||||||
return d.profiles.selectProfileByLocalpart(localpart)
|
ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Profile, error) {
|
||||||
|
return d.profiles.selectProfileByLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetAvatarURL(localpart string, avatarURL string) error {
|
func (d *Database) SetAvatarURL(
|
||||||
return d.profiles.setAvatarURL(localpart, avatarURL)
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
|
) error {
|
||||||
|
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetDisplayName(localpart string, displayName string) error {
|
func (d *Database) SetDisplayName(
|
||||||
return d.profiles.setDisplayName(localpart, displayName)
|
ctx context.Context, localpart string, displayName string,
|
||||||
|
) error {
|
||||||
|
return d.profiles.setDisplayName(ctx, localpart, displayName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account.
|
// for this account. If no password is supplied, the account will be a passwordless account.
|
||||||
func (d *Database) CreateAccount(localpart, plaintextPassword string) (*authtypes.Account, error) {
|
func (d *Database) CreateAccount(
|
||||||
|
ctx context.Context, localpart, plaintextPassword string,
|
||||||
|
) (*authtypes.Account, error) {
|
||||||
hash, err := hashPassword(plaintextPassword)
|
hash, err := hashPassword(plaintextPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := d.profiles.insertProfile(localpart); err != nil {
|
if err := d.profiles.insertProfile(ctx, localpart); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d.accounts.insertAccount(localpart, hash)
|
return d.accounts.insertAccount(ctx, localpart, hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PartitionOffsets implements common.PartitionStorer
|
// PartitionOffsets implements common.PartitionStorer
|
||||||
|
@ -131,15 +142,19 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
|
||||||
// is still in the room.
|
// is still in the room.
|
||||||
// If a membership already exists between the user and the room, or of the
|
// If a membership already exists between the user and the room, or of the
|
||||||
// insert fails, returns the SQL error
|
// insert fails, returns the SQL error
|
||||||
func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error {
|
func (d *Database) saveMembership(
|
||||||
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
|
ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
|
||||||
|
) error {
|
||||||
|
return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeMembershipsByEventIDs removes the memberships of which the `join` membership
|
// removeMembershipsByEventIDs removes the memberships of which the `join` membership
|
||||||
// event ID is included in a given array of events IDs
|
// event ID is included in a given array of events IDs
|
||||||
// If the removal fails, or if there is no membership to remove, returns an error
|
// If the removal fails, or if there is no membership to remove, returns an error
|
||||||
func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
|
func (d *Database) removeMembershipsByEventIDs(
|
||||||
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
|
) error {
|
||||||
|
return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateMemberships adds the "join" membership events included in a given state
|
// UpdateMemberships adds the "join" membership events included in a given state
|
||||||
|
@ -147,14 +162,16 @@ func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) e
|
||||||
// IDs. All of the process is run in a transaction, which commits only once/if every
|
// IDs. All of the process is run in a transaction, which commits only once/if every
|
||||||
// insertion and deletion has been successfully processed.
|
// insertion and deletion has been successfully processed.
|
||||||
// Returns a SQL error if there was an issue with any part of the process
|
// Returns a SQL error if there was an issue with any part of the process
|
||||||
func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error {
|
func (d *Database) UpdateMemberships(
|
||||||
|
ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
|
||||||
|
) error {
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil {
|
if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, event := range eventsToAdd {
|
for _, event := range eventsToAdd {
|
||||||
if err := d.newMembership(event, txn); err != nil {
|
if err := d.newMembership(ctx, txn, event); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -167,8 +184,10 @@ func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsT
|
||||||
// the rooms a user matching a given localpart is a member of
|
// the rooms a user matching a given localpart is a member of
|
||||||
// If no membership match the given localpart, returns an empty array
|
// If no membership match the given localpart, returns an empty array
|
||||||
// If there was an issue during the retrieval, returns the SQL error
|
// If there was an issue during the retrieval, returns the SQL error
|
||||||
func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) {
|
func (d *Database) GetMembershipsByLocalpart(
|
||||||
return d.memberships.selectMembershipsByLocalpart(localpart)
|
ctx context.Context, localpart string,
|
||||||
|
) (memberships []authtypes.Membership, err error) {
|
||||||
|
return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newMembership will save a new membership in the database, with a flag on whether
|
// newMembership will save a new membership in the database, with a flag on whether
|
||||||
|
@ -178,7 +197,9 @@ func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []au
|
||||||
// values, does nothing.
|
// values, does nothing.
|
||||||
// If the event isn't a "join" membership event, does nothing
|
// If the event isn't a "join" membership event, does nothing
|
||||||
// If an error occurred, returns it
|
// If an error occurred, returns it
|
||||||
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
|
func (d *Database) newMembership(
|
||||||
|
ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
|
||||||
|
) error {
|
||||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||||
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -199,7 +220,7 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error
|
||||||
|
|
||||||
// Only "join" membership events can be considered as new memberships
|
// Only "join" membership events can be considered as new memberships
|
||||||
if membership == "join" {
|
if membership == "join" {
|
||||||
if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil {
|
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -212,27 +233,33 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error
|
||||||
// If an account data already exists for a given set (user, room, data type), it will
|
// If an account data already exists for a given set (user, room, data type), it will
|
||||||
// update the corresponding row with the new content
|
// update the corresponding row with the new content
|
||||||
// Returns a SQL error if there was an issue with the insertion/update
|
// Returns a SQL error if there was an issue with the insertion/update
|
||||||
func (d *Database) SaveAccountData(localpart string, roomID string, dataType string, content string) error {
|
func (d *Database) SaveAccountData(
|
||||||
return d.accountDatas.insertAccountData(localpart, roomID, dataType, content)
|
ctx context.Context, localpart, roomID, dataType, content string,
|
||||||
|
) error {
|
||||||
|
return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountData returns account data related to a given localpart
|
// GetAccountData returns account data related to a given localpart
|
||||||
// If no account data could be found, returns an empty arrays
|
// If no account data could be found, returns an empty arrays
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
func (d *Database) GetAccountData(localpart string) (
|
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||||
global []gomatrixserverlib.ClientEvent,
|
global []gomatrixserverlib.ClientEvent,
|
||||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
return d.accountDatas.selectAccountData(localpart)
|
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountDataByType returns account data matching a given
|
// GetAccountDataByType returns account data matching a given
|
||||||
// localpart, room ID and type.
|
// localpart, room ID and type.
|
||||||
// If no account data could be found, returns an empty array
|
// If no account data could be found, returns an empty array
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
func (d *Database) GetAccountDataByType(localpart string, roomID string, dataType string) (data []gomatrixserverlib.ClientEvent, err error) {
|
func (d *Database) GetAccountDataByType(
|
||||||
return d.accountDatas.selectAccountDataByType(localpart, roomID, dataType)
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
|
) (data []gomatrixserverlib.ClientEvent, err error) {
|
||||||
|
return d.accountDatas.selectAccountDataByType(
|
||||||
|
ctx, localpart, roomID, dataType,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func hashPassword(plaintext string) (hash string, err error) {
|
func hashPassword(plaintext string) (hash string, err error) {
|
||||||
|
@ -248,9 +275,13 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use")
|
||||||
// and a local Matrix user (identified by the user's ID's local part).
|
// and a local Matrix user (identified by the user's ID's local part).
|
||||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, medium string) (err error) {
|
func (d *Database) SaveThreePIDAssociation(
|
||||||
|
ctx context.Context, threepid, localpart, medium string,
|
||||||
|
) (err error) {
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
user, err := d.threepids.selectLocalpartForThreePID(txn, threepid, medium)
|
user, err := d.threepids.selectLocalpartForThreePID(
|
||||||
|
ctx, txn, threepid, medium,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -259,7 +290,7 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, me
|
||||||
return Err3PIDInUse
|
return Err3PIDInUse
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.threepids.insertThreePID(txn, threepid, medium, localpart)
|
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,8 +298,10 @@ func (d *Database) SaveThreePIDAssociation(threepid string, localpart string, me
|
||||||
// identifier.
|
// identifier.
|
||||||
// If no association exists involving this third-party identifier, returns nothing.
|
// If no association exists involving this third-party identifier, returns nothing.
|
||||||
// If there was a problem talking to the database, returns an error.
|
// If there was a problem talking to the database, returns an error.
|
||||||
func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (err error) {
|
func (d *Database) RemoveThreePIDAssociation(
|
||||||
return d.threepids.deleteThreePID(threepid, medium)
|
ctx context.Context, threepid string, medium string,
|
||||||
|
) (err error) {
|
||||||
|
return d.threepids.deleteThreePID(ctx, threepid, medium)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
||||||
|
@ -276,14 +309,18 @@ func (d *Database) RemoveThreePIDAssociation(threepid string, medium string) (er
|
||||||
// If no association involves the given third-party idenfitier, returns an empty
|
// If no association involves the given third-party idenfitier, returns an empty
|
||||||
// string.
|
// string.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func (d *Database) GetLocalpartForThreePID(threepid string, medium string) (localpart string, err error) {
|
func (d *Database) GetLocalpartForThreePID(
|
||||||
return d.threepids.selectLocalpartForThreePID(nil, threepid, medium)
|
ctx context.Context, threepid string, medium string,
|
||||||
|
) (localpart string, err error) {
|
||||||
|
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||||
// a given local user.
|
// a given local user.
|
||||||
// If no association is known for this user, returns an empty slice.
|
// If no association is known for this user, returns an empty slice.
|
||||||
// Returns an error if there was an issue talking to the database.
|
// Returns an error if there was an issue talking to the database.
|
||||||
func (d *Database) GetThreePIDsForLocalpart(localpart string) (threepids []authtypes.ThreePID, err error) {
|
func (d *Database) GetThreePIDsForLocalpart(
|
||||||
return d.threepids.selectThreePIDsForLocalpart(localpart)
|
ctx context.Context, localpart string,
|
||||||
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
|
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,8 +15,11 @@
|
||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -76,22 +79,21 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) selectLocalpartForThreePID(txn *sql.Tx, threepid string, medium string) (localpart string, err error) {
|
func (s *threepidStatements) selectLocalpartForThreePID(
|
||||||
var stmt *sql.Stmt
|
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||||
if txn != nil {
|
) (localpart string, err error) {
|
||||||
stmt = txn.Stmt(s.selectLocalpartForThreePIDStmt)
|
stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||||
} else {
|
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
||||||
stmt = s.selectLocalpartForThreePIDStmt
|
|
||||||
}
|
|
||||||
err = stmt.QueryRow(threepid, medium).Scan(&localpart)
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (threepids []authtypes.ThreePID, err error) {
|
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
||||||
rows, err := s.selectThreePIDsForLocalpartStmt.Query(localpart)
|
ctx context.Context, localpart string,
|
||||||
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
|
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -103,18 +105,25 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(localpart string) (thre
|
||||||
if err = rows.Scan(&threepid, &medium); err != nil {
|
if err = rows.Scan(&threepid, &medium); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
threepids = append(threepids, authtypes.ThreePID{threepid, medium})
|
threepids = append(threepids, authtypes.ThreePID{
|
||||||
|
Address: threepid,
|
||||||
|
Medium: medium,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) insertThreePID(txn *sql.Tx, threepid string, medium string, localpart string) (err error) {
|
func (s *threepidStatements) insertThreePID(
|
||||||
_, err = txn.Stmt(s.insertThreePIDStmt).Exec(threepid, medium, localpart)
|
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
stmt := common.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) deleteThreePID(threepid string, medium string) (err error) {
|
func (s *threepidStatements) deleteThreePID(
|
||||||
_, err = s.deleteThreePIDStmt.Exec(threepid, medium)
|
ctx context.Context, threepid string, medium string) (err error) {
|
||||||
|
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,7 +96,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
|
if err := s.db.UpdateMemberships(context.TODO(), events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,9 @@ func SaveAccountData(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := accountDB.SaveAccountData(localpart, roomID, dataType, string(body)); err != nil {
|
if err := accountDB.SaveAccountData(
|
||||||
|
req.Context(), localpart, roomID, dataType, string(body),
|
||||||
|
); err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ func Login(
|
||||||
|
|
||||||
util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request")
|
util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request")
|
||||||
|
|
||||||
acc, err := accountDB.GetAccountByPassword(r.User, r.Password)
|
acc, err := accountDB.GetAccountByPassword(req.Context(), r.User, r.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
||||||
// but that would leak the existence of the user.
|
// but that would leak the existence of the user.
|
||||||
|
|
|
@ -60,7 +60,7 @@ func GetProfile(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := accountDB.GetProfileByLocalpart(localpart)
|
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
@ -83,7 +83,7 @@ func GetAvatarURL(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := accountDB.GetProfileByLocalpart(localpart)
|
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
@ -127,16 +127,16 @@ func SetAvatarURL(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oldProfile, err := accountDB.GetProfileByLocalpart(localpart)
|
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = accountDB.SetAvatarURL(localpart, r.AvatarURL); err != nil {
|
if err = accountDB.SetAvatarURL(req.Context(), localpart, r.AvatarURL); err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
memberships, err := accountDB.GetMembershipsByLocalpart(localpart)
|
memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
@ -175,7 +175,7 @@ func GetDisplayName(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := accountDB.GetProfileByLocalpart(localpart)
|
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
@ -219,16 +219,16 @@ func SetDisplayName(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oldProfile, err := accountDB.GetProfileByLocalpart(localpart)
|
oldProfile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = accountDB.SetDisplayName(localpart, r.DisplayName); err != nil {
|
if err = accountDB.SetDisplayName(req.Context(), localpart, r.DisplayName); err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
memberships, err := accountDB.GetMembershipsByLocalpart(localpart)
|
memberships, err := accountDB.GetMembershipsByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Check if the 3PID is already in use locally
|
// Check if the 3PID is already in use locally
|
||||||
localpart, err := accountDB.GetLocalpartForThreePID(body.Email, "email")
|
localpart, err := accountDB.GetLocalpartForThreePID(req.Context(), body.Email, "email")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg conf
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SID, err = threepid.CreateSession(body, cfg)
|
resp.SID, err = threepid.CreateSession(req.Context(), body, cfg)
|
||||||
if err == threepid.ErrNotTrusted {
|
if err == threepid.ErrNotTrusted {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 400,
|
Code: 400,
|
||||||
|
@ -91,7 +91,7 @@ func CheckAndSave3PIDAssociation(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the association has been validated
|
// Check if the association has been validated
|
||||||
verified, address, medium, err := threepid.CheckAssociation(body.Creds, cfg)
|
verified, address, medium, err := threepid.CheckAssociation(req.Context(), body.Creds, cfg)
|
||||||
if err == threepid.ErrNotTrusted {
|
if err == threepid.ErrNotTrusted {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 400,
|
Code: 400,
|
||||||
|
@ -130,7 +130,7 @@ func CheckAndSave3PIDAssociation(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = accountDB.SaveThreePIDAssociation(address, localpart, medium); err != nil {
|
if err = accountDB.SaveThreePIDAssociation(req.Context(), address, localpart, medium); err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,7 +149,7 @@ func GetAssociated3PIDs(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
threepids, err := accountDB.GetThreePIDsForLocalpart(localpart)
|
threepids, err := accountDB.GetThreePIDsForLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
@ -167,7 +167,7 @@ func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONRespon
|
||||||
return *reqErr
|
return *reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := accountDB.RemoveThreePIDAssociation(body.Address, body.Medium); err != nil {
|
if err := accountDB.RemoveThreePIDAssociation(req.Context(), body.Address, body.Medium); err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -102,7 +102,7 @@ func CheckAndProcessInvite(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lookupRes, storeInviteRes, err := queryIDServer(db, cfg, device, body, roomID)
|
lookupRes, storeInviteRes, err := queryIDServer(ctx, db, cfg, device, body, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -134,6 +134,7 @@ func CheckAndProcessInvite(
|
||||||
// Returns a representation of the response for both cases.
|
// Returns a representation of the response for both cases.
|
||||||
// Returns an error if a check or a request failed.
|
// Returns an error if a check or a request failed.
|
||||||
func queryIDServer(
|
func queryIDServer(
|
||||||
|
ctx context.Context,
|
||||||
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
|
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
|
||||||
body *MembershipRequest, roomID string,
|
body *MembershipRequest, roomID string,
|
||||||
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
|
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
|
||||||
|
@ -142,7 +143,7 @@ func queryIDServer(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup the 3PID
|
// Lookup the 3PID
|
||||||
lookupRes, err = queryIDServerLookup(body)
|
lookupRes, err = queryIDServerLookup(ctx, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -150,7 +151,7 @@ func queryIDServer(
|
||||||
if lookupRes.MXID == "" {
|
if lookupRes.MXID == "" {
|
||||||
// No Matrix ID matches with the given 3PID, ask the server to store the
|
// No Matrix ID matches with the given 3PID, ask the server to store the
|
||||||
// invite and return a token
|
// invite and return a token
|
||||||
storeInviteRes, err = queryIDServerStoreInvite(db, cfg, device, body, roomID)
|
storeInviteRes, err = queryIDServerStoreInvite(ctx, db, cfg, device, body, roomID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,11 +162,11 @@ func queryIDServer(
|
||||||
if lookupRes.NotBefore > now || now > lookupRes.NotAfter {
|
if lookupRes.NotBefore > now || now > lookupRes.NotAfter {
|
||||||
// If the current timestamp isn't in the time frame in which the association
|
// If the current timestamp isn't in the time frame in which the association
|
||||||
// is known to be valid, re-run the query
|
// is known to be valid, re-run the query
|
||||||
return queryIDServer(db, cfg, device, body, roomID)
|
return queryIDServer(ctx, db, cfg, device, body, roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the request signatures and send an error if one isn't valid
|
// Check the request signatures and send an error if one isn't valid
|
||||||
if err = checkIDServerSignatures(body, lookupRes); err != nil {
|
if err = checkIDServerSignatures(ctx, body, lookupRes); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,10 +176,14 @@ func queryIDServer(
|
||||||
// queryIDServerLookup sends a response to the identity server on /_matrix/identity/api/v1/lookup
|
// queryIDServerLookup sends a response to the identity server on /_matrix/identity/api/v1/lookup
|
||||||
// and returns the response as a structure.
|
// and returns the response as a structure.
|
||||||
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
||||||
func queryIDServerLookup(body *MembershipRequest) (*idServerLookupResponse, error) {
|
func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServerLookupResponse, error) {
|
||||||
address := url.QueryEscape(body.Address)
|
address := url.QueryEscape(body.Address)
|
||||||
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/lookup?medium=%s&address=%s", body.IDServer, body.Medium, address)
|
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/lookup?medium=%s&address=%s", body.IDServer, body.Medium, address)
|
||||||
resp, err := http.Get(url)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -198,6 +203,7 @@ func queryIDServerLookup(body *MembershipRequest) (*idServerLookupResponse, erro
|
||||||
// and returns the response as a structure.
|
// and returns the response as a structure.
|
||||||
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
||||||
func queryIDServerStoreInvite(
|
func queryIDServerStoreInvite(
|
||||||
|
ctx context.Context,
|
||||||
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
|
db *accounts.Database, cfg config.Dendrite, device *authtypes.Device,
|
||||||
body *MembershipRequest, roomID string,
|
body *MembershipRequest, roomID string,
|
||||||
) (*idServerStoreInviteResponse, error) {
|
) (*idServerStoreInviteResponse, error) {
|
||||||
|
@ -209,7 +215,7 @@ func queryIDServerStoreInvite(
|
||||||
|
|
||||||
var profile *authtypes.Profile
|
var profile *authtypes.Profile
|
||||||
if serverName == cfg.Matrix.ServerName {
|
if serverName == cfg.Matrix.ServerName {
|
||||||
profile, err = db.GetProfileByLocalpart(localpart)
|
profile, err = db.GetProfileByLocalpart(ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -239,7 +245,7 @@ func queryIDServerStoreInvite(
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req.WithContext(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -259,9 +265,13 @@ func queryIDServerStoreInvite(
|
||||||
// We assume that the ID server is trusted at this point.
|
// We assume that the ID server is trusted at this point.
|
||||||
// Returns an error if the request couldn't be sent, if its body couldn't be parsed
|
// Returns an error if the request couldn't be sent, if its body couldn't be parsed
|
||||||
// or if the key couldn't be decoded from base64.
|
// or if the key couldn't be decoded from base64.
|
||||||
func queryIDServerPubKey(idServerName string, keyID string) ([]byte, error) {
|
func queryIDServerPubKey(ctx context.Context, idServerName string, keyID string) ([]byte, error) {
|
||||||
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/%s", idServerName, keyID)
|
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/%s", idServerName, keyID)
|
||||||
resp, err := http.Get(url)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -286,7 +296,9 @@ func queryIDServerPubKey(idServerName string, keyID string) ([]byte, error) {
|
||||||
// We assume that the ID server is trusted at this point.
|
// We assume that the ID server is trusted at this point.
|
||||||
// Returns nil if all the verifications succeeded.
|
// Returns nil if all the verifications succeeded.
|
||||||
// Returns an error if something failed in the process.
|
// Returns an error if something failed in the process.
|
||||||
func checkIDServerSignatures(body *MembershipRequest, res *idServerLookupResponse) error {
|
func checkIDServerSignatures(
|
||||||
|
ctx context.Context, body *MembershipRequest, res *idServerLookupResponse,
|
||||||
|
) error {
|
||||||
// Mashall the body so we can give it to VerifyJSON
|
// Mashall the body so we can give it to VerifyJSON
|
||||||
marshalledBody, err := json.Marshal(*res)
|
marshalledBody, err := json.Marshal(*res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -299,7 +311,7 @@ func checkIDServerSignatures(body *MembershipRequest, res *idServerLookupRespons
|
||||||
}
|
}
|
||||||
|
|
||||||
for keyID := range signatures {
|
for keyID := range signatures {
|
||||||
pubKey, err := queryIDServerPubKey(body.IDServer, keyID)
|
pubKey, err := queryIDServerPubKey(ctx, body.IDServer, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package threepid
|
package threepid
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -51,7 +52,9 @@ type Credentials struct {
|
||||||
// Returns the session's ID.
|
// Returns the session's ID.
|
||||||
// Returns an error if there was a problem sending the request or decoding the
|
// Returns an error if there was a problem sending the request or decoding the
|
||||||
// response, or if the identity server responded with a non-OK status.
|
// response, or if the identity server responded with a non-OK status.
|
||||||
func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, error) {
|
func CreateSession(
|
||||||
|
ctx context.Context, req EmailAssociationRequest, cfg config.Dendrite,
|
||||||
|
) (string, error) {
|
||||||
if err := isTrusted(req.IDServer, cfg); err != nil {
|
if err := isTrusted(req.IDServer, cfg); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -71,7 +74,7 @@ func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, er
|
||||||
request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
resp, err := client.Do(request)
|
resp, err := client.Do(request.WithContext(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -97,13 +100,19 @@ func CreateSession(req EmailAssociationRequest, cfg config.Dendrite) (string, er
|
||||||
// identifier and its medium.
|
// identifier and its medium.
|
||||||
// Returns an error if there was a problem sending the request or decoding the
|
// Returns an error if there was a problem sending the request or decoding the
|
||||||
// response, or if the identity server responded with a non-OK status.
|
// response, or if the identity server responded with a non-OK status.
|
||||||
func CheckAssociation(creds Credentials, cfg config.Dendrite) (bool, string, string, error) {
|
func CheckAssociation(
|
||||||
|
ctx context.Context, creds Credentials, cfg config.Dendrite,
|
||||||
|
) (bool, string, string, error) {
|
||||||
if err := isTrusted(creds.IDServer, cfg); err != nil {
|
if err := isTrusted(creds.IDServer, cfg); err != nil {
|
||||||
return false, "", "", err
|
return false, "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret)
|
url := fmt.Sprintf("https://%s/_matrix/identity/api/v1/3pid/getValidated3pid?sid=%s&client_secret=%s", creds.IDServer, creds.SID, creds.Secret)
|
||||||
resp, err := http.Get(url)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", "", err
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", "", err
|
return false, "", "", err
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,7 +127,7 @@ func createRoom(req *http.Request, device *authtypes.Device,
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := accountDB.GetProfileByLocalpart(localpart)
|
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,7 @@ func JoinRoomByIDOrAlias(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := accountDB.GetProfileByLocalpart(localpart)
|
profile, err := accountDB.GetProfileByLocalpart(req.Context(), localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,7 +121,7 @@ func buildMembershipEvent(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := loadProfile(stateKey, cfg, accountDB)
|
profile, err := loadProfile(ctx, stateKey, cfg, accountDB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -156,7 +156,9 @@ func buildMembershipEvent(
|
||||||
// it if the user is local to this server, or returns an empty profile if not.
|
// it if the user is local to this server, or returns an empty profile if not.
|
||||||
// Returns an error if the retrieval failed or if the first parameter isn't a
|
// Returns an error if the retrieval failed or if the first parameter isn't a
|
||||||
// valid Matrix ID.
|
// valid Matrix ID.
|
||||||
func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Database) (*authtypes.Profile, error) {
|
func loadProfile(
|
||||||
|
ctx context.Context, userID string, cfg config.Dendrite, accountDB *accounts.Database,
|
||||||
|
) (*authtypes.Profile, error) {
|
||||||
localpart, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
localpart, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -164,7 +166,7 @@ func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Databas
|
||||||
|
|
||||||
var profile *authtypes.Profile
|
var profile *authtypes.Profile
|
||||||
if serverName == cfg.Matrix.ServerName {
|
if serverName == cfg.Matrix.ServerName {
|
||||||
profile, err = accountDB.GetProfileByLocalpart(localpart)
|
profile, err = accountDB.GetProfileByLocalpart(ctx, localpart)
|
||||||
} else {
|
} else {
|
||||||
profile = &authtypes.Profile{}
|
profile = &authtypes.Profile{}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package writers
|
package writers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -134,7 +135,9 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices
|
||||||
switch r.Auth.Type {
|
switch r.Auth.Type {
|
||||||
case authtypes.LoginTypeDummy:
|
case authtypes.LoginTypeDummy:
|
||||||
// there is nothing to do
|
// there is nothing to do
|
||||||
return completeRegistration(accountDB, deviceDB, r.Username, r.Password)
|
return completeRegistration(
|
||||||
|
req.Context(), accountDB, deviceDB, r.Username, r.Password,
|
||||||
|
)
|
||||||
default:
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 501,
|
Code: 501,
|
||||||
|
@ -143,7 +146,12 @@ func Register(req *http.Request, accountDB *accounts.Database, deviceDB *devices
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func completeRegistration(accountDB *accounts.Database, deviceDB *devices.Database, username, password string) util.JSONResponse {
|
func completeRegistration(
|
||||||
|
ctx context.Context,
|
||||||
|
accountDB *accounts.Database,
|
||||||
|
deviceDB *devices.Database,
|
||||||
|
username, password string,
|
||||||
|
) util.JSONResponse {
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 400,
|
Code: 400,
|
||||||
|
@ -157,7 +165,7 @@ func completeRegistration(accountDB *accounts.Database, deviceDB *devices.Databa
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
acc, err := accountDB.CreateAccount(username, password)
|
acc, err := accountDB.CreateAccount(ctx, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 500,
|
Code: 500,
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -68,7 +69,7 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = accountDB.CreateAccount(*username, *password)
|
_, err = accountDB.CreateAccount(context.Background(), *username, *password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err.Error())
|
fmt.Println(err.Error())
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|
|
@ -191,7 +191,7 @@ func createInviteFrom3PIDInvite(
|
||||||
StateKey: &inv.MXID,
|
StateKey: &inv.MXID,
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := accountDB.GetProfileByLocalpart(localpart)
|
profile, err := accountDB.GetProfileByLocalpart(ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package sync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
@ -29,6 +30,7 @@ const defaultTimelineLimit = 20
|
||||||
|
|
||||||
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
|
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
|
||||||
type syncRequest struct {
|
type syncRequest struct {
|
||||||
|
ctx context.Context
|
||||||
userID string
|
userID string
|
||||||
limit int
|
limit int
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
|
@ -47,6 +49,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
|
||||||
}
|
}
|
||||||
// TODO: Additional query params: set_presence, filter
|
// TODO: Additional query params: set_presence, filter
|
||||||
return &syncRequest{
|
return &syncRequest{
|
||||||
|
ctx: req.Context(),
|
||||||
userID: userID,
|
userID: userID,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
since: since,
|
since: since,
|
||||||
|
|
|
@ -128,7 +128,7 @@ func (rp *RequestPool) appendAccountData(
|
||||||
// already been sent. Instead, we send the whole batch.
|
// already been sent. Instead, we send the whole batch.
|
||||||
var global []gomatrixserverlib.ClientEvent
|
var global []gomatrixserverlib.ClientEvent
|
||||||
var rooms map[string][]gomatrixserverlib.ClientEvent
|
var rooms map[string][]gomatrixserverlib.ClientEvent
|
||||||
global, rooms, err = rp.accountDB.GetAccountData(localpart)
|
global, rooms, err = rp.accountDB.GetAccountData(req.ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -159,7 +159,9 @@ func (rp *RequestPool) appendAccountData(
|
||||||
events := []gomatrixserverlib.ClientEvent{}
|
events := []gomatrixserverlib.ClientEvent{}
|
||||||
// Request the missing data from the database
|
// Request the missing data from the database
|
||||||
for _, dataType := range dataTypes {
|
for _, dataType := range dataTypes {
|
||||||
evs, err := rp.accountDB.GetAccountDataByType(localpart, roomID, dataType)
|
evs, err := rp.accountDB.GetAccountDataByType(
|
||||||
|
req.ctx, localpart, roomID, dataType,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue