Add contexts to the roomserver storage layer (#229)

* Add contexts to the roomserver storage layer

* Fix rooms_table
This commit is contained in:
Mark Haines 2017-09-13 16:30:19 +01:00 committed by GitHub
parent 3133bef797
commit bfcce5bd21
21 changed files with 744 additions and 379 deletions

View file

@ -32,16 +32,16 @@ import (
type RoomserverAliasAPIDatabase interface {
// Save a given room alias with the room ID it refers to.
// Returns an error if there was a problem talking to the database.
SetRoomAlias(alias string, roomID string) error
SetRoomAlias(ctx context.Context, alias string, roomID string) error
// Look up the room ID a given alias refers to.
// Returns an error if there was a problem talking to the database.
GetRoomIDFromAlias(alias string) (string, error)
GetRoomIDFromAlias(ctx context.Context, alias string) (string, error)
// Look up all aliases referring to a given room ID.
// Returns an error if there was a problem talking to the database.
GetAliasesFromRoomID(roomID string) ([]string, error)
GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
// Remove a given room alias.
// Returns an error if there was a problem talking to the database.
RemoveRoomAlias(alias string) error
RemoveRoomAlias(ctx context.Context, alias string) error
}
// RoomserverAliasAPI is an implementation of api.RoomserverAliasAPI
@ -59,7 +59,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
response *api.SetRoomAliasResponse,
) error {
// Check if the alias isn't already referring to a room
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
if err != nil {
return err
}
@ -71,7 +71,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
response.AliasExists = false
// Save the new alias
if err := r.DB.SetRoomAlias(request.Alias, request.RoomID); err != nil {
if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil {
return err
}
@ -93,7 +93,7 @@ func (r *RoomserverAliasAPI) GetAliasRoomID(
response *api.GetAliasRoomIDResponse,
) error {
// Look up the room ID in the database
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
if err != nil {
return err
}
@ -109,18 +109,21 @@ func (r *RoomserverAliasAPI) RemoveRoomAlias(
response *api.RemoveRoomAliasResponse,
) error {
// Look up the room ID in the database
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
if err != nil {
return err
}
// Remove the dalias from the database
if err := r.DB.RemoveRoomAlias(request.Alias); err != nil {
if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil {
return err
}
// Send an updated m.room.aliases event
if err := r.sendUpdatedAliasesEvent(ctx, request.UserID, roomID); err != nil {
// At this point we've already committed the alias to the database so we
// shouldn't cancel this request.
// TODO: Ensure that we send unsent events when if server restarts.
if err := r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, roomID); err != nil {
return err
}
@ -147,7 +150,7 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
// Retrieve the updated list of aliases, marhal it and set it as the
// event's content
aliases, err := r.DB.GetAliasesFromRoomID(roomID)
aliases, err := r.DB.GetAliasesFromRoomID(ctx, roomID)
if err != nil {
return err
}

View file

@ -15,16 +15,23 @@
package input
import (
"context"
"sort"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"sort"
)
// checkAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events.
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) {
func checkAuthEvents(
ctx context.Context,
db RoomEventDatabase,
event gomatrixserverlib.Event,
authEventIDs []string,
) ([]types.EventNID, error) {
// Grab the numeric IDs for the supplied auth state events from the database.
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs)
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
if err != nil {
return nil, err
}
@ -34,7 +41,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries)
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil {
return nil, err
}
@ -84,7 +91,10 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even
}
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID})
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
if !ok {
return nil
}
@ -100,7 +110,10 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
if !ok {
return nil
}
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID})
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID,
EventStateKeyNID: stateKeyNID,
})
if !ok {
return nil
}
@ -113,6 +126,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
// loadAuthEvents loads the events needed for authentication from the supplied room state.
func loadAuthEvents(
ctx context.Context,
db RoomEventDatabase,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
@ -121,7 +135,7 @@ func loadAuthEvents(
var neededStateKeys []string
neededStateKeys = append(neededStateKeys, needed.Member...)
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil {
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
return
}
@ -135,34 +149,52 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID)
}
}
if result.events, err = db.Events(eventNIDs); err != nil {
if result.events, err = db.Events(ctx, eventNIDs); err != nil {
return
}
return
}
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
func stateKeyTuplesNeeded(
stateKeyNIDMap map[string]types.EventStateKeyNID,
stateNeeded gomatrixserverlib.StateNeeded,
) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple
if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomCreateNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
if stateNeeded.PowerLevels {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomPowerLevelsNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
if stateNeeded.JoinRules {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomJoinRulesNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
for _, member := range stateNeeded.Member {
stateKeyNID, ok := stateKeyNIDMap[member]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
})
}
}
for _, token := range stateNeeded.ThirdPartyInvite {
stateKeyNID, ok := stateKeyNIDMap[token]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomThirdPartyInviteNID,
EventStateKeyNID: stateKeyNID,
})
}
}
return keyTuples

View file

@ -15,6 +15,7 @@
package input
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/common"
@ -28,22 +29,38 @@ import (
type RoomEventDatabase interface {
state.RoomStateDatabase
// Stores a matrix room event in the database
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error)
StoreEvent(
ctx context.Context,
event gomatrixserverlib.Event,
authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database.
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error)
// Set the state at an event.
SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error
SetState(
ctx context.Context,
eventNID types.EventNID,
stateNID types.StateSnapshotNID,
) error
// Look up the latest events in a room in preparation for an update.
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
// If this returns an error then no further action is required.
GetLatestEventsForUpdate(roomNID types.RoomNID) (updater types.RoomRecentEventsUpdater, err error)
GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID,
) (updater types.RoomRecentEventsUpdater, err error)
// Look up the string event IDs for a list of numeric event IDs
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error)
EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error)
// Build a membership updater for the target user in a room.
MembershipUpdater(roomID, targerUserID string) (types.MembershipUpdater, error)
MembershipUpdater(
ctx context.Context, roomID, targerUserID string,
) (types.MembershipUpdater, error)
}
// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
@ -52,18 +69,23 @@ type OutputRoomEventWriter interface {
WriteOutputEvents(roomID string, updates []api.OutputEvent) error
}
func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputRoomEvent) error {
func processRoomEvent(
ctx context.Context,
db RoomEventDatabase,
ow OutputRoomEventWriter,
input api.InputRoomEvent,
) error {
// Parse and validate the event JSON
event := input.Event
// Check that the event passes authentication checks and work out the numeric IDs for the auth events.
authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs)
authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs)
if err != nil {
return err
}
// Store the event
roomNID, stateAtEvent, err := db.StoreEvent(event, authEventNIDs)
roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, authEventNIDs)
if err != nil {
return err
}
@ -82,20 +104,20 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
var entries []types.StateEntry
if entries, err = db.StateEntriesForEventIDs(input.StateEventIDs); err != nil {
if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return err
}
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(roomNID, nil, entries); err != nil {
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil {
return nil
}
} else {
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(db, event, roomNID); err != nil {
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
return err
}
}
db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
}
if input.Kind == api.KindBackfill {
@ -104,14 +126,19 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
}
// Update the extremities of the event graph for the room
if err := updateLatestEvents(db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil {
if err := updateLatestEvents(ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil {
return err
}
return nil
}
func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputInviteEvent) (err error) {
func processInviteEvent(
ctx context.Context,
db RoomEventDatabase,
ow OutputRoomEventWriter,
input api.InputInviteEvent,
) (err error) {
if input.Event.StateKey() == nil {
return fmt.Errorf("invite must be a state event")
}
@ -119,7 +146,7 @@ func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input ap
roomID := input.Event.RoomID()
targetUserID := *input.Event.StateKey()
updater, err := db.MembershipUpdater(roomID, targetUserID)
updater, err := db.MembershipUpdater(ctx, roomID, targetUserID)
if err != nil {
return err
}

View file

@ -59,12 +59,12 @@ func (r *RoomserverInputAPI) InputRoomEvents(
response *api.InputRoomEventsResponse,
) error {
for i := range request.InputRoomEvents {
if err := processRoomEvent(r.DB, r, request.InputRoomEvents[i]); err != nil {
if err := processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {
return err
}
}
for i := range request.InputInviteEvents {
if err := processInviteEvent(r.DB, r, request.InputInviteEvents[i]); err != nil {
if err := processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil {
return err
}
}

View file

@ -16,6 +16,7 @@ package input
import (
"bytes"
"context"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
@ -42,6 +43,7 @@ import (
// 7 <----- latest
//
func updateLatestEvents(
ctx context.Context,
db RoomEventDatabase,
ow OutputRoomEventWriter,
roomNID types.RoomNID,
@ -49,7 +51,7 @@ func updateLatestEvents(
event gomatrixserverlib.Event,
sendAsServer string,
) (err error) {
updater, err := db.GetLatestEventsForUpdate(roomNID)
updater, err := db.GetLatestEventsForUpdate(ctx, roomNID)
if err != nil {
return
}
@ -57,7 +59,7 @@ func updateLatestEvents(
defer common.EndTransaction(updater, &succeeded)
u := latestEventsUpdater{
db: db, updater: updater, ow: ow, roomNID: roomNID,
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer,
}
if err = u.doUpdateLatestEvents(); err != nil {
@ -73,6 +75,7 @@ func updateLatestEvents(
// The state could be passed using function arguments, but it becomes impractical
// when there are so many variables to pass around.
type latestEventsUpdater struct {
ctx context.Context
db RoomEventDatabase
updater types.RoomRecentEventsUpdater
ow OutputRoomEventWriter
@ -133,7 +136,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
return err
}
updates, err := updateMemberships(u.db, u.updater, u.removed, u.added)
updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added)
if err != nil {
return err
}
@ -174,18 +177,22 @@ func (u *latestEventsUpdater) latestState() error {
for i := range u.latest {
latestStateAtEvents[i] = u.latest[i].StateAtEvent
}
u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(u.db, u.roomNID, latestStateAtEvents)
u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(
u.ctx, u.db, u.roomNID, latestStateAtEvents,
)
if err != nil {
return err
}
u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(u.db, u.oldStateNID, u.newStateNID)
u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(
u.ctx, u.db, u.oldStateNID, u.newStateNID,
)
if err != nil {
return err
}
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots(
u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
)
if err != nil {
return err
@ -193,7 +200,12 @@ func (u *latestEventsUpdater) latestState() error {
return nil
}
func calculateLatest(oldLatest []types.StateAtEventAndReference, alreadyReferenced bool, prevEvents []gomatrixserverlib.EventReference, newEvent types.StateAtEventAndReference) []types.StateAtEventAndReference {
func calculateLatest(
oldLatest []types.StateAtEventAndReference,
alreadyReferenced bool,
prevEvents []gomatrixserverlib.EventReference,
newEvent types.StateAtEventAndReference,
) []types.StateAtEventAndReference {
var alreadyInLatest bool
var newLatest []types.StateAtEventAndReference
for _, l := range oldLatest {
@ -253,7 +265,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
eventIDMap, err := u.db.EventIDs(stateEventNIDs)
eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs)
if err != nil {
return nil, err
}

View file

@ -15,6 +15,7 @@
package input
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/roomserver/api"
@ -27,7 +28,10 @@ import (
// Returns a list of output events to write to the kafka log to inform the
// consumers about the invites added or retired by the change in current state.
func updateMemberships(
db RoomEventDatabase, updater types.RoomRecentEventsUpdater, removed, added []types.StateEntry,
ctx context.Context,
db RoomEventDatabase,
updater types.RoomRecentEventsUpdater,
removed, added []types.StateEntry,
) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added)
var eventNIDs []types.EventNID
@ -43,7 +47,7 @@ func updateMemberships(
// Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON?
events, err := db.Events(eventNIDs)
events, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}

View file

@ -33,35 +33,47 @@ type RoomserverQueryAPIDatabase interface {
// Look up the numeric ID for the room.
// Returns 0 if the room doesn't exists.
// Returns an error if there was a problem talking to the database.
RoomNID(roomID string) (types.RoomNID, error)
RoomNID(ctx context.Context, roomID string) (types.RoomNID, error)
// Look up event references for the latest events in the room and the current state snapshot.
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
// Returns an error if there was a problem talking to the database.
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,
) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
// Look up the numeric IDs for a list of events.
// Returns an error if there was a problem talking to the database.
EventNIDs(eventIDs []string) (map[string]types.EventNID, error)
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
// Lookup the event IDs for a batch of event numeric IDs.
// Returns an error if the retrieval went wrong.
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error)
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// Lookup the membership of a given user in a given room.
// Returns the numeric ID of the latest membership event sent from this user
// in this room, along a boolean set to true if the user is still in this room,
// false if not.
// Returns an error if there was a problem talking to the database.
GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
GetMembership(
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
) (membershipEventNID types.EventNID, stillInRoom bool, err error)
// Lookup the membership event numeric IDs for all user that are or have
// been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true.
// Returns an error if there was a problem talking to the database.
GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error)
GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
) ([]types.EventNID, error)
// Look up the active invites targeting a user in a room and return the
// numeric state key IDs for the user IDs who sent them.
// Returns an error if there was a problem talking to the database.
GetInvitesForUser(roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserNIDs []types.EventStateKeyNID, err error)
GetInvitesForUser(
ctx context.Context,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserNIDs []types.EventStateKeyNID, err error)
// Look up the string event state keys for a list of numeric event state keys
// Returns an error if there was a problem talking to the database.
EventStateKeys([]types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
EventStateKeys(
context.Context, []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error)
}
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
@ -76,7 +88,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
response *api.QueryLatestEventsAndStateResponse,
) error {
response.QueryLatestEventsAndStateRequest = *request
roomNID, err := r.DB.RoomNID(request.RoomID)
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil {
return err
}
@ -85,18 +97,21 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
}
response.RoomExists = true
var currentStateSnapshotNID types.StateSnapshotNID
response.LatestEvents, currentStateSnapshotNID, response.Depth, err = r.DB.LatestEventIDs(roomNID)
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
r.DB.LatestEventIDs(ctx, roomNID)
if err != nil {
return err
}
// Look up the currrent state for the requested tuples.
stateEntries, err := state.LoadStateAtSnapshotForStringTuples(r.DB, currentStateSnapshotNID, request.StateToFetch)
stateEntries, err := state.LoadStateAtSnapshotForStringTuples(
ctx, r.DB, currentStateSnapshotNID, request.StateToFetch,
)
if err != nil {
return err
}
stateEvents, err := r.loadStateEvents(stateEntries)
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
if err != nil {
return err
}
@ -112,7 +127,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
response *api.QueryStateAfterEventsResponse,
) error {
response.QueryStateAfterEventsRequest = *request
roomNID, err := r.DB.RoomNID(request.RoomID)
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil {
return err
}
@ -121,7 +136,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
}
response.RoomExists = true
prevStates, err := r.DB.StateAtEventIDs(request.PrevEventIDs)
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil {
switch err.(type) {
case types.MissingEventError:
@ -133,12 +148,14 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
response.PrevEventsExist = true
// Look up the currrent state for the requested tuples.
stateEntries, err := state.LoadStateAfterEventsForStringTuples(r.DB, prevStates, request.StateToFetch)
stateEntries, err := state.LoadStateAfterEventsForStringTuples(
ctx, r.DB, prevStates, request.StateToFetch,
)
if err != nil {
return err
}
stateEvents, err := r.loadStateEvents(stateEntries)
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
if err != nil {
return err
}
@ -155,7 +172,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
) error {
response.QueryEventsByIDRequest = *request
eventNIDMap, err := r.DB.EventNIDs(request.EventIDs)
eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
if err != nil {
return err
}
@ -165,7 +182,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
eventNIDs = append(eventNIDs, nid)
}
events, err := r.loadEvents(eventNIDs)
events, err := r.loadEvents(ctx, eventNIDs)
if err != nil {
return err
}
@ -174,16 +191,20 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
return nil
}
func (r *RoomserverQueryAPI) loadStateEvents(stateEntries []types.StateEntry) ([]gomatrixserverlib.Event, error) {
func (r *RoomserverQueryAPI) loadStateEvents(
ctx context.Context, stateEntries []types.StateEntry,
) ([]gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID
}
return r.loadEvents(eventNIDs)
return r.loadEvents(ctx, eventNIDs)
}
func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixserverlib.Event, error) {
stateEvents, err := r.DB.Events(eventNIDs)
func (r *RoomserverQueryAPI) loadEvents(
ctx context.Context, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.Event, error) {
stateEvents, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
@ -201,12 +222,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse,
) error {
roomNID, err := r.DB.RoomNID(request.RoomID)
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil {
return err
}
membershipEventNID, stillInRoom, err := r.DB.GetMembership(roomNID, request.Sender)
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.Sender)
if err != nil {
return nil
}
@ -223,14 +244,14 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
var events []types.Event
if stillInRoom {
var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(roomNID, request.JoinedOnly)
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
if err != nil {
return err
}
events, err = r.DB.Events(eventNIDs)
events, err = r.DB.Events(ctx, eventNIDs)
} else {
events, err = r.getMembershipsBeforeEventNID(membershipEventNID, request.JoinedOnly)
events, err = r.getMembershipsBeforeEventNID(ctx, membershipEventNID, request.JoinedOnly)
}
if err != nil {
@ -249,22 +270,24 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
// of the event's room as it was when this event was fired, then filters the state events to
// only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events.
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNID, joinedOnly bool) ([]types.Event, error) {
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
ctx context.Context, eventNID types.EventNID, joinedOnly bool,
) ([]types.Event, error) {
events := []types.Event{}
// Lookup the event NID
eIDs, err := r.DB.EventIDs([]types.EventNID{eventNID})
eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
return nil, err
}
eventIDs := []string{eIDs[eventNID]}
prevState, err := r.DB.StateAtEventIDs(eventIDs)
prevState, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
// Fetch the state as it was when this event was fired
stateEntries, err := state.LoadCombinedStateAfterEvents(r.DB, prevState)
stateEntries, err := state.LoadCombinedStateAfterEvents(ctx, r.DB, prevState)
if err != nil {
return nil, err
}
@ -278,7 +301,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI
}
// Get all of the events in this state
stateEvents, err := r.DB.Events(eventNIDs)
stateEvents, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
@ -304,27 +327,27 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI
// QueryInvitesForUser implements api.RoomserverQueryAPI
func (r *RoomserverQueryAPI) QueryInvitesForUser(
_ context.Context,
ctx context.Context,
request *api.QueryInvitesForUserRequest,
response *api.QueryInvitesForUserResponse,
) error {
roomNID, err := r.DB.RoomNID(request.RoomID)
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil {
return err
}
targetUserNIDs, err := r.DB.EventStateKeyNIDs([]string{request.TargetUserID})
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID})
if err != nil {
return err
}
targetUserNID := targetUserNIDs[request.TargetUserID]
senderUserNIDs, err := r.DB.GetInvitesForUser(roomNID, targetUserNID)
senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
if err != nil {
return err
}
senderUserIDs, err := r.DB.EventStateKeys(senderUserNIDs)
senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
if err != nil {
return err
}
@ -342,14 +365,14 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) error {
stateEntries, err := state.LoadStateAtEvent(r.DB, request.EventID)
stateEntries, err := state.LoadStateAtEvent(ctx, r.DB, request.EventID)
if err != nil {
return err
}
// TODO: We probably want to make it so that we don't have to pull
// out all the state if possible.
stateAtEvent, err := r.loadStateEvents(stateEntries)
stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
if err != nil {
return err
}

View file

@ -17,6 +17,7 @@
package state
import (
"context"
"fmt"
"sort"
"time"
@ -30,49 +31,58 @@ import (
// A RoomStateDatabase has the storage APIs needed to load state from the database
type RoomStateDatabase interface {
// Store the room state at an event in the database
AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (types.StateSnapshotNID, error)
// Look up the state of a room at each event for a list of string event IDs.
// Returns an error if there is an error talking to the database
// Returns a types.MissingEventError if the room state for the event IDs aren't in the database
StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
// Look up the numeric IDs for a list of string event types.
// Returns a map from string event type to numeric ID for the event type.
EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error)
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
// Look up the numeric IDs for a list of string event state keys.
// Returns a map from string state key to numeric ID for the state key.
EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
// Look up the numeric state data IDs for each numeric state snapshot ID
// The returned slice is sorted by numeric state snapshot ID.
StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
// Look up the state data for each numeric state data ID
// The returned slice is sorted by numeric state data ID.
StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
// Look up the state data for the state key tuples for each numeric state block ID
// This is used to fetch a subset of the room state at a snapshot.
// If a block doesn't contain any of the requested tuples then it can be discarded from the result.
// The returned slice is sorted by numeric state block ID.
StateEntriesForTuples(stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) (
[]types.StateEntryList, error,
)
StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events.
Events(eventNIDs []types.EventNID) ([]types.Event, error)
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
}
// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
// This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) {
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID})
func LoadStateAtSnapshot(
ctx context.Context, db RoomStateDatabase, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) {
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil {
return nil, err
}
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0]
stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs)
stateEntryLists, err := db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
if err != nil {
return nil, err
}
@ -100,13 +110,15 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID)
}
// LoadStateAtEvent loads the full state of a room at a particular event.
func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry, error) {
snapshotNID, err := db.SnapshotNIDFromEventID(eventID)
func LoadStateAtEvent(
ctx context.Context, db RoomStateDatabase, eventID string,
) ([]types.StateEntry, error) {
snapshotNID, err := db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil {
return nil, err
}
stateEntries, err := LoadStateAtSnapshot(db, snapshotNID)
stateEntries, err := LoadStateAtSnapshot(ctx, db, snapshotNID)
if err != nil {
return nil, err
}
@ -116,7 +128,9 @@ func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry,
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
// and combines those snapshots together into a single list.
func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) {
func LoadCombinedStateAfterEvents(
ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
) ([]types.StateEntry, error) {
stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
for i, state := range prevStates {
stateNIDs[i] = state.BeforeStateSnapshotNID
@ -125,7 +139,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
// Deduplicate the IDs before passing them to the database.
// There could be duplicates because the events could be state events where
// the snapshot of the room state before them was the same.
stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs))
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs))
if err != nil {
return nil, err
}
@ -138,7 +152,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
// Deduplicate the IDs before passing them to the database.
// There could be duplicates because a block of state entries could be reused by
// multiple snapshots.
stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs))
stateEntryLists, err := db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs))
if err != nil {
return nil, err
}
@ -186,9 +200,9 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
}
// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID) (
removed, added []types.StateEntry, err error,
) {
func DifferenceBetweeenStateSnapshots(
ctx context.Context, db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID,
) (removed, added []types.StateEntry, err error) {
if oldStateNID == newStateNID {
// If the snapshot NIDs are the same then nothing has changed
return nil, nil, nil
@ -197,13 +211,13 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
var oldEntries []types.StateEntry
var newEntries []types.StateEntry
if oldStateNID != 0 {
oldEntries, err = LoadStateAtSnapshot(db, oldStateNID)
oldEntries, err = LoadStateAtSnapshot(ctx, db, oldStateNID)
if err != nil {
return nil, nil, err
}
}
if newStateNID != 0 {
newEntries, err = LoadStateAtSnapshot(db, newStateNID)
newEntries, err = LoadStateAtSnapshot(ctx, db, newStateNID)
if err != nil {
return nil, nil, err
}
@ -246,19 +260,26 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
// This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func LoadStateAtSnapshotForStringTuples(
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
ctx context.Context,
db RoomStateDatabase,
stateNID types.StateSnapshotNID,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) {
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples)
numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
if err != nil {
return nil, err
}
return loadStateAtSnapshotForNumericTuples(db, stateNID, numericTuples)
return loadStateAtSnapshotForNumericTuples(ctx, db, stateNID, numericTuples)
}
// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
// Returns an error if there was a problem talking to the database.
func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateKeyTuple, error) {
func stringTuplesToNumericTuples(
ctx context.Context,
db RoomStateDatabase,
stringTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateKeyTuple, error) {
eventTypes := make([]string, len(stringTuples))
stateKeys := make([]string, len(stringTuples))
for i := range stringTuples {
@ -266,12 +287,12 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse
stateKeys[i] = stringTuples[i].StateKey
}
eventTypes = util.UniqueStrings(eventTypes)
eventTypeMap, err := db.EventTypeNIDs(eventTypes)
eventTypeMap, err := db.EventTypeNIDs(ctx, eventTypes)
if err != nil {
return nil, err
}
stateKeys = util.UniqueStrings(stateKeys)
stateKeyMap, err := db.EventStateKeyNIDs(stateKeys)
stateKeyMap, err := db.EventStateKeyNIDs(ctx, stateKeys)
if err != nil {
return nil, err
}
@ -297,16 +318,21 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse
// This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func loadStateAtSnapshotForNumericTuples(
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple,
ctx context.Context,
db RoomStateDatabase,
stateNID types.StateSnapshotNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID})
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil {
return nil, err
}
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0]
stateEntryLists, err := db.StateEntriesForTuples(stateBlockNIDList.StateBlockNIDs, stateKeyTuples)
stateEntryLists, err := db.StateEntriesForTuples(
ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples,
)
if err != nil {
return nil, err
}
@ -341,23 +367,29 @@ func loadStateAtSnapshotForNumericTuples(
// This is typically the state before an event.
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
func LoadStateAfterEventsForStringTuples(
db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
ctx context.Context,
db RoomStateDatabase,
prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) {
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples)
numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
if err != nil {
return nil, err
}
return loadStateAfterEventsForNumericTuples(db, prevStates, numericTuples)
return loadStateAfterEventsForNumericTuples(ctx, db, prevStates, numericTuples)
}
func loadStateAfterEventsForNumericTuples(
db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple,
ctx context.Context,
db RoomStateDatabase,
prevStates []types.StateAtEvent,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
if len(prevStates) == 1 {
// Fast path for a single event.
prevState := prevStates[0]
result, err := loadStateAtSnapshotForNumericTuples(
db, prevState.BeforeStateSnapshotNID, stateKeyTuples,
ctx, db, prevState.BeforeStateSnapshotNID, stateKeyTuples,
)
if err != nil {
return nil, err
@ -390,7 +422,7 @@ func loadStateAfterEventsForNumericTuples(
// TODO: Add metrics for this as it could take a long time for big rooms
// with large conflicts.
fullState, _, _, err := calculateStateAfterManyEvents(db, prevStates)
fullState, _, _, err := calculateStateAfterManyEvents(ctx, db, prevStates)
if err != nil {
return nil, err
}
@ -403,7 +435,10 @@ func loadStateAfterEventsForNumericTuples(
for _, tuple := range stateKeyTuples {
eventNID, ok := stateEntryMap(fullState).lookup(tuple)
if ok {
result = append(result, types.StateEntry{tuple, eventNID})
result = append(result, types.StateEntry{
StateKeyTuple: tuple,
EventNID: eventNID,
})
}
}
sort.Sort(stateEntrySorter(result))
@ -509,7 +544,10 @@ func init() {
// Stores the snapshot of the state in the database.
// Returns a numeric ID for the snapshot of the state before the event.
func CalculateAndStoreStateBeforeEvent(
db RoomStateDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID,
ctx context.Context,
db RoomStateDatabase,
event gomatrixserverlib.Event,
roomNID types.RoomNID,
) (types.StateSnapshotNID, error) {
// Load the state at the prev events.
prevEventRefs := event.PrevEvents()
@ -518,25 +556,30 @@ func CalculateAndStoreStateBeforeEvent(
prevEventIDs[i] = prevEventRefs[i].EventID
}
prevStates, err := db.StateAtEventIDs(prevEventIDs)
prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
if err != nil {
return 0, err
}
// The state before this event will be the state after the events that came before it.
return CalculateAndStoreStateAfterEvents(db, roomNID, prevStates)
return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates)
}
// CalculateAndStoreStateAfterEvents finds the room state after the given events.
// Stores the resulting state in the database and returns a numeric ID for that snapshot.
func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) {
func CalculateAndStoreStateAfterEvents(
ctx context.Context,
db RoomStateDatabase,
roomNID types.RoomNID,
prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error) {
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
if len(prevStates) == 0 {
// 2) There weren't any prev_events for this event so the state is
// empty.
metrics.algorithm = "empty_state"
return metrics.stop(db.AddState(roomNID, nil, nil))
return metrics.stop(db.AddState(ctx, roomNID, nil, nil))
}
if len(prevStates) == 1 {
@ -551,7 +594,9 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN
}
// The previous event was a state event so we need to store a copy
// of the previous state updated with that event.
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{prevState.BeforeStateSnapshotNID})
stateBlockNIDLists, err := db.StateBlockNIDs(
ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID},
)
if err != nil {
metrics.algorithm = "_load_state_blocks"
return metrics.stop(0, err)
@ -562,14 +607,14 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN
// add the state event as a block of size one to the end of the blocks.
metrics.algorithm = "single_delta"
return metrics.stop(db.AddState(
roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
))
}
// If there are too many deltas then we need to calculate the full state
// So fall through to calculateAndStoreStateAfterManyEvents
}
return calculateAndStoreStateAfterManyEvents(db, roomNID, prevStates, metrics)
return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics)
}
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
@ -583,10 +628,15 @@ const maxStateBlockNIDs = 64
// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
// Stores the resulting state and returns a numeric ID for the snapshot.
func calculateAndStoreStateAfterManyEvents(
db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent, metrics calculateStateMetrics,
ctx context.Context,
db RoomStateDatabase,
roomNID types.RoomNID,
prevStates []types.StateAtEvent,
metrics calculateStateMetrics,
) (types.StateSnapshotNID, error) {
state, algorithm, conflictLength, err := calculateStateAfterManyEvents(db, prevStates)
state, algorithm, conflictLength, err :=
calculateStateAfterManyEvents(ctx, db, prevStates)
metrics.algorithm = algorithm
if err != nil {
return metrics.stop(0, err)
@ -596,16 +646,16 @@ func calculateAndStoreStateAfterManyEvents(
// previous state.
metrics.conflictLength = conflictLength
metrics.fullStateLength = len(state)
return metrics.stop(db.AddState(roomNID, nil, state))
return metrics.stop(db.AddState(ctx, roomNID, nil, state))
}
func calculateStateAfterManyEvents(
db RoomStateDatabase, prevStates []types.StateAtEvent,
ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
var combined []types.StateEntry
// Conflict resolution.
// First stage: load the state after each of the prev events.
combined, err = LoadCombinedStateAfterEvents(db, prevStates)
combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates)
if err != nil {
algorithm = "_load_combined_state"
return
@ -635,7 +685,7 @@ func calculateStateAfterManyEvents(
}
var resolved []types.StateEntry
resolved, err = resolveConflicts(db, notConflicted, conflicts)
resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts)
if err != nil {
algorithm = "_resolve_conflicts"
return
@ -657,10 +707,14 @@ func calculateStateAfterManyEvents(
// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
// The returned list is sorted by state key tuple.
// Returns an error if there was a problem talking to the database.
func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.StateEntry) ([]types.StateEntry, error) {
func resolveConflicts(
ctx context.Context,
db RoomStateDatabase,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
// Load the conflicted events
conflictedEvents, eventIDMap, err := loadStateEvents(db, conflicted)
conflictedEvents, eventIDMap, err := loadStateEvents(ctx, db, conflicted)
if err != nil {
return nil, err
}
@ -672,7 +726,7 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
var neededStateKeys []string
neededStateKeys = append(neededStateKeys, needed.Member...)
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
stateKeyNIDMap, err := db.EventStateKeyNIDs(neededStateKeys)
stateKeyNIDMap, err := db.EventStateKeyNIDs(ctx, neededStateKeys)
if err != nil {
return nil, err
}
@ -682,10 +736,13 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
var authEntries []types.StateEntry
for _, tuple := range tuplesNeeded {
if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
authEntries = append(authEntries, types.StateEntry{tuple, eventNID})
authEntries = append(authEntries, types.StateEntry{
StateKeyTuple: tuple,
EventNID: eventNID,
})
}
}
authEvents, _, err := loadStateEvents(db, authEntries)
authEvents, _, err := loadStateEvents(ctx, db, authEntries)
if err != nil {
return nil, err
}
@ -711,24 +768,39 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple
if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomCreateNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
if stateNeeded.PowerLevels {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomPowerLevelsNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
if stateNeeded.JoinRules {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomJoinRulesNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
for _, member := range stateNeeded.Member {
stateKeyNID, ok := stateKeyNIDMap[member]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
})
}
}
for _, token := range stateNeeded.ThirdPartyInvite {
stateKeyNID, ok := stateKeyNIDMap[token]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomThirdPartyInviteNID,
EventStateKeyNID: stateKeyNID,
})
}
}
return keyTuples
@ -738,12 +810,14 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stat
// Returns a list of state events in no particular order and a map from string event ID back to state entry.
// The map can be used to recover which numeric state entry a given event is for.
// Returns an error if there was a problem talking to the database.
func loadStateEvents(db RoomStateDatabase, entries []types.StateEntry) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) {
func loadStateEvents(
ctx context.Context, db RoomStateDatabase, entries []types.StateEntry,
) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) {
eventNIDs := make([]types.EventNID, len(entries))
for i := range entries {
eventNIDs[i] = entries[i].EventNID
}
events, err := db.Events(eventNIDs)
events, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, nil, err
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/roomserver/types"
@ -65,8 +66,10 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
func (s *eventJSONStatements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error {
_, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON)
func (s *eventJSONStatements) insertEventJSON(
ctx context.Context, eventNID types.EventNID, eventJSON []byte,
) error {
_, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON)
return err
}
@ -75,8 +78,10 @@ type eventJSONPair struct {
EventJSON []byte
}
func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.Query(eventNIDsAsArray(eventNIDs))
func (s *eventJSONStatements) bulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID,
) ([]eventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
@ -91,20 +92,30 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
func (s *eventStateKeyStatements) insertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt)
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err
}
func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
func (s *eventStateKeyStatements) selectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err
}
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys))
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
ctx, pq.StringArray(eventStateKeys),
)
if err != nil {
return nil, err
}
@ -122,18 +133,23 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st
return result, nil
}
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) {
func (s *eventStateKeyStatements) selectEventStateKey(
ctx context.Context, txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID,
) (string, error) {
var eventStateKey string
err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey)
stmt := common.TxStmt(txn, s.selectEventStateKeyStmt)
err := stmt.QueryRowContext(ctx, eventStateKeyNID).Scan(&eventStateKey)
return eventStateKey, err
}
func (s *eventStateKeyStatements) bulkSelectEventStateKey(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) {
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
var nIDs pq.Int64Array
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}
rows, err := s.bulkSelectEventStateKeyStmt.Query(nIDs)
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
if err != nil {
return nil, err
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
@ -107,20 +108,26 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
func (s *eventTypeStatements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) {
func (s *eventTypeStatements) insertEventTypeNID(
ctx context.Context, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err
}
func (s *eventTypeStatements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) {
func (s *eventTypeStatements) selectEventTypeNID(
ctx context.Context, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err
}
func (s *eventTypeStatements) bulkSelectEventTypeNID(eventTypes []string) (map[string]types.EventTypeNID, error) {
rows, err := s.bulkSelectEventTypeNIDStmt.Query(pq.StringArray(eventTypes))
func (s *eventTypeStatements) bulkSelectEventTypeNID(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
if err != nil {
return nil, err
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"fmt"
@ -154,7 +155,10 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
}
func (s *eventStatements) insertEvent(
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
ctx context.Context,
roomNID types.RoomNID,
eventTypeNID types.EventTypeNID,
eventStateKeyNID types.EventStateKeyNID,
eventID string,
referenceSHA256 []byte,
authEventNIDs []types.EventNID,
@ -162,24 +166,28 @@ func (s *eventStatements) insertEvent(
) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64
var stateNID int64
err := s.insertEventStmt.QueryRow(
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
eventNIDsAsArray(authEventNIDs), depth,
err := s.insertEventStmt.QueryRowContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
}
func (s *eventStatements) selectEvent(eventID string) (types.EventNID, types.StateSnapshotNID, error) {
func (s *eventStatements) selectEvent(
ctx context.Context, eventID string,
) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64
var stateNID int64
err := s.selectEventStmt.QueryRow(eventID).Scan(&eventNID, &stateNID)
err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
}
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs))
func (s *eventStatements) bulkSelectStateEventByID(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@ -216,8 +224,10 @@ func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.S
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.Query(pq.StringArray(eventIDs))
func (s *eventStatements) bulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@ -248,28 +258,40 @@ func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types
return results, err
}
func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error {
_, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID))
func (s *eventStatements) updateEventState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
return err
}
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) {
err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
func (s *eventStatements) selectEventSentToOutput(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (sentToOutput bool, err error) {
stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt)
stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
return
}
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error {
_, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID))
func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt)
_, err := stmt.ExecContext(ctx, int64(eventNID))
return err
}
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) {
err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
func (s *eventStatements) selectEventID(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (eventID string, err error) {
stmt := common.TxStmt(txn, s.selectEventIDStmt)
err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID)
return
}
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
func (s *eventStatements) bulkSelectStateAtEventAndReference(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.StateAtEventAndReference, error) {
stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
@ -304,8 +326,10 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventN
return results, nil
}
func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) {
rows, err := s.bulkSelectEventReferenceStmt.Query(eventNIDsAsArray(eventNIDs))
func (s *eventStatements) bulkSelectEventReference(
ctx context.Context, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.EventReference, error) {
rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
@ -325,8 +349,8 @@ func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) (
}
// bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
rows, err := s.bulkSelectEventIDStmt.Query(eventNIDsAsArray(eventNIDs))
func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
@ -349,8 +373,8 @@ func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.Query(pq.StringArray(eventIDs))
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@ -367,9 +391,10 @@ func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]type
return results, nil
}
func (s *eventStatements) selectMaxEventDepth(eventNIDs []types.EventNID) (int64, error) {
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {
var result int64
err := s.selectMaxEventDepthStmt.QueryRow(eventNIDsAsArray(eventNIDs)).Scan(&result)
stmt := s.selectMaxEventDepthStmt
err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result)
if err != nil {
return 0, err
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
@ -91,12 +92,13 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) {
}
func (s *inviteStatements) insertInviteEvent(
ctx context.Context,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec(
inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext(
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
)
if err != nil {
return false, err
@ -109,9 +111,11 @@ func (s *inviteStatements) insertInviteEvent(
}
func (s *inviteStatements) updateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) {
rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID)
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil {
return nil, err
}
@ -129,10 +133,11 @@ func (s *inviteStatements) updateInviteRetired(
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) selectInviteActiveForUserInRoom(
ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.Query(
targetUserNID, roomNID,
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
ctx, targetUserNID, roomNID,
)
if err != nil {
return nil, err

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
@ -114,34 +115,38 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
}
func (s *membershipStatements) insertMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error {
_, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID)
stmt := common.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
return err
}
func (s *membershipStatements) selectMembershipForUpdate(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership membershipState, err error) {
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow(
roomNID, targetUserNID,
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership)
return
}
func (s *membershipStatements) selectMembershipFromRoomAndTarget(
ctx context.Context,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership membershipState, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRow(
roomNID, targetUserNID,
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID)
return
}
func (s *membershipStatements) selectMembershipsFromRoom(
roomNID types.RoomNID,
ctx context.Context, roomNID types.RoomNID,
) (eventNIDs []types.EventNID, err error) {
rows, err := s.selectMembershipsFromRoomStmt.Query(roomNID)
rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID)
if err != nil {
return
}
@ -156,9 +161,11 @@ func (s *membershipStatements) selectMembershipsFromRoom(
return
}
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context,
roomNID types.RoomNID, membership membershipState,
) (eventNIDs []types.EventNID, err error) {
rows, err := s.selectMembershipsFromRoomAndMembershipStmt.Query(roomNID, membership)
stmt := s.selectMembershipsFromRoomAndMembershipStmt
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
}
@ -174,12 +181,13 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
}
func (s *membershipStatements) updateMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
senderUserNID types.EventStateKeyNID, membership membershipState,
eventNID types.EventNID,
) error {
_, err := common.TxStmt(txn, s.updateMembershipStmt).Exec(
roomNID, targetUserNID, senderUserNID, membership, eventNID,
_, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext(
ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID,
)
return err
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
@ -73,14 +74,26 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
_, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID))
func (s *previousEventStatements) insertPreviousEvent(
ctx context.Context,
txn *sql.Tx,
previousEventID string,
previousEventReferenceSHA256 []byte,
eventNID types.EventNID,
) error {
stmt := common.TxStmt(txn, s.insertPreviousEventStmt)
_, err := stmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
)
return err
}
// Check if the event reference exists
// Returns sql.ErrNoRows if the event reference doesn't exist.
func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error {
func (s *previousEventStatements) selectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error {
var ok int64
return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok)
stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt)
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
)
@ -62,22 +63,28 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
func (s *roomAliasesStatements) insertRoomAlias(alias string, roomID string) (err error) {
_, err = s.insertRoomAliasStmt.Exec(alias, roomID)
func (s *roomAliasesStatements) insertRoomAlias(
ctx context.Context, alias string, roomID string,
) (err error) {
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID)
return
}
func (s *roomAliasesStatements) selectRoomIDFromAlias(alias string) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRow(alias).Scan(&roomID)
func (s *roomAliasesStatements) selectRoomIDFromAlias(
ctx context.Context, alias string,
) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows {
return "", nil
}
return
}
func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases []string, err error) {
func (s *roomAliasesStatements) selectAliasesFromRoomID(
ctx context.Context, roomID string,
) (aliases []string, err error) {
aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.Query(roomID)
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
if err != nil {
return
}
@ -94,7 +101,9 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases
return
}
func (s *roomAliasesStatements) deleteRoomAlias(alias string) (err error) {
_, err = s.deleteRoomAliasStmt.Exec(alias)
func (s *roomAliasesStatements) deleteRoomAlias(
ctx context.Context, alias string,
) (err error) {
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
return
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
@ -81,22 +82,31 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
func (s *roomStatements) insertRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64
err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
stmt := common.TxStmt(txn, s.insertRoomNIDStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
}
func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
func (s *roomStatements) selectRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64
err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
}
func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) {
func (s *roomStatements) selectLatestEventNIDs(
ctx context.Context, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array
var stateSnapshotNID int64
err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids, &stateSnapshotNID)
stmt := s.selectLatestEventNIDsStmt
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil {
return nil, 0, err
}
@ -107,13 +117,14 @@ func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.E
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
}
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) (
[]types.EventNID, types.EventNID, types.StateSnapshotNID, error,
) {
func (s *roomStatements) selectLatestEventsNIDsForUpdate(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array
var lastEventSentNID int64
var stateSnapshotNID int64
err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
if err != nil {
return nil, 0, 0, err
}
@ -125,11 +136,20 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
}
func (s *roomStatements) updateLatestEventNIDs(
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID,
ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
eventNIDs []types.EventNID,
lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID,
) error {
_, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec(
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID),
stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt)
_, err := stmt.ExecContext(
ctx,
roomNID,
eventNIDsAsArray(eventNIDs),
int64(lastEventSentNID),
int64(stateSnapshotNID),
)
return err
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"fmt"
"sort"
@ -97,9 +98,14 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error {
func (s *stateBlockStatements) bulkInsertStateData(
ctx context.Context,
stateBlockNID types.StateBlockNID,
entries []types.StateEntry,
) error {
for _, entry := range entries {
_, err := s.insertStateDataStmt.Exec(
_, err := s.insertStateDataStmt.ExecContext(
ctx,
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
@ -112,18 +118,22 @@ func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBloc
return nil
}
func (s *stateBlockStatements) selectNextStateBlockNID() (types.StateBlockNID, error) {
func (s *stateBlockStatements) selectNextStateBlockNID(
ctx context.Context,
) (types.StateBlockNID, error) {
var stateBlockNID int64
err := s.selectNextStateBlockNIDStmt.QueryRow().Scan(&stateBlockNID)
err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
return types.StateBlockNID(stateBlockNID), err
}
func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
func (s *stateBlockStatements) bulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
rows, err := s.bulkSelectStateBlockEntriesStmt.Query(pq.Int64Array(nids))
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil {
return nil, err
}
@ -165,15 +175,20 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []type
}
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple,
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.Query(
stateBlockNIDsAsArray(stateBlockNIDs), eventTypeNIDArray, eventStateKeyNIDArray,
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext(
ctx,
stateBlockNIDsAsArray(stateBlockNIDs),
eventTypeNIDArray,
eventStateKeyNIDArray,
)
if err != nil {
return nil, err

View file

@ -15,29 +15,30 @@
package storage
import (
"github.com/matrix-org/dendrite/roomserver/types"
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{1, 2},
{1, 4},
{2, 2},
{1, 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{1, 1},
{1, 2},
{1, 4},
{2, 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{0, 0},
{1, 3},
{2, 1},
{3, 1},
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
"fmt"
@ -74,21 +75,25 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
func (s *stateSnapshotStatements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) {
func (s *stateSnapshotStatements) insertState(
ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
return
}
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
rows, err := s.bulkSelectStateBlockNIDsStmt.Query(pq.Int64Array(nids))
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil {
return nil, err
}

View file

@ -15,6 +15,7 @@
package storage
import (
"context"
"database/sql"
// Import the postgres database driver.
@ -43,7 +44,9 @@ func Open(dataSourceName string) (*Database, error) {
}
// StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) {
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
@ -53,11 +56,11 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
err error
)
if roomNID, err = d.assignRoomNID(nil, event.RoomID()); err != nil {
if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil {
return 0, types.StateAtEvent{}, err
}
if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil {
if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil {
return 0, types.StateAtEvent{}, err
}
@ -65,12 +68,13 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
// Assigned a numeric ID for the state_key if there is one present.
// Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(nil, *eventStateKey); err != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil {
return 0, types.StateAtEvent{}, err
}
}
if eventNID, stateNID, err = d.statements.insertEvent(
ctx,
roomNID,
eventTypeNID,
eventStateKeyNID,
@ -81,14 +85,14 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
); err != nil {
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.statements.selectEvent(event.EventID())
eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID())
}
if err != nil {
return 0, types.StateAtEvent{}, err
}
}
if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil {
if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil {
return 0, types.StateAtEvent{}, err
}
@ -104,76 +108,94 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
}, nil
}
func (d *Database) assignRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
// Check if we already have a numeric ID in the database.
roomNID, err := d.statements.selectRoomNID(txn, roomID)
roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
roomNID, err = d.statements.insertRoomNID(txn, roomID)
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
roomNID, err = d.statements.selectRoomNID(txn, roomID)
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
}
}
return roomNID, err
}
func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) {
func (d *Database) assignEventTypeNID(
ctx context.Context, eventType string,
) (types.EventTypeNID, error) {
// Check if we already have a numeric ID in the database.
eventTypeNID, err := d.statements.selectEventTypeNID(eventType)
eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
eventTypeNID, err = d.statements.insertEventTypeNID(eventType)
eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventTypeNID, err = d.statements.selectEventTypeNID(eventType)
eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType)
}
}
return eventTypeNID, err
}
func (d *Database) assignStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
// Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(txn, eventStateKey)
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(txn, eventStateKey)
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(txn, eventStateKey)
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
}
}
return eventStateKeyNID, err
}
// StateEntriesForEventIDs implements input.EventDatabase
func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) {
return d.statements.bulkSelectStateEventByID(eventIDs)
func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return d.statements.bulkSelectStateEventByID(ctx, eventIDs)
}
// EventTypeNIDs implements state.RoomStateDatabase
func (d *Database) EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) {
return d.statements.bulkSelectEventTypeNID(eventTypes)
func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return d.statements.bulkSelectEventTypeNID(ctx, eventTypes)
}
// EventStateKeyNIDs implements state.RoomStateDatabase
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
func (d *Database) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys)
}
// EventStateKeys implements query.RoomserverQueryAPIDatabase
func (d *Database) EventStateKeys(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) {
return d.statements.bulkSelectEventStateKey(eventStateKeyNIDs)
func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs)
}
// EventNIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) EventNIDs(eventIDs []string) (map[string]types.EventNID, error) {
return d.statements.bulkSelectEventNID(eventIDs)
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.statements.bulkSelectEventNID(ctx, eventIDs)
}
// Events implements input.EventDatabase
func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs)
if err != nil {
return nil, err
}
@ -191,78 +213,98 @@ func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
}
// AddState implements input.EventDatabase
func (d *Database) AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) {
func (d *Database) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (types.StateSnapshotNID, error) {
if len(state) > 0 {
stateBlockNID, err := d.statements.selectNextStateBlockNID()
stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx)
if err != nil {
return 0, err
}
if err = d.statements.bulkInsertStateData(stateBlockNID, state); err != nil {
if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil {
return 0, err
}
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
}
return d.statements.insertState(roomNID, stateBlockNIDs)
return d.statements.insertState(ctx, roomNID, stateBlockNIDs)
}
// SetState implements input.EventDatabase
func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error {
return d.statements.updateEventState(eventNID, stateNID)
func (d *Database) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return d.statements.updateEventState(ctx, eventNID, stateNID)
}
// StateAtEventIDs implements input.EventDatabase
func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) {
return d.statements.bulkSelectStateAtEventByID(eventIDs)
func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs)
}
// StateBlockNIDs implements state.RoomStateDatabase
func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
return d.statements.bulkSelectStateBlockNIDs(stateNIDs)
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs)
}
// StateEntries implements state.RoomStateDatabase
func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
return d.statements.bulkSelectStateBlockEntries(stateBlockNIDs)
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs)
}
// SnapshotNIDFromEventID implements state.RoomStateDatabase
func (d *Database) SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error) {
_, stateNID, err := d.statements.selectEvent(eventID)
func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.statements.selectEvent(ctx, eventID)
return stateNID, err
}
// EventIDs implements input.RoomEventDatabase
func (d *Database) EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
return d.statements.bulkSelectEventID(eventNIDs)
func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return d.statements.bulkSelectEventID(ctx, eventNIDs)
}
// GetLatestEventsForUpdate implements input.EventDatabase
func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) {
func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID,
) (types.RoomRecentEventsUpdater, error) {
txn, err := d.db.Begin()
if err != nil {
return nil, err
}
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID)
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
if err != nil {
txn.Rollback()
return nil, err
}
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(txn, eventNIDs)
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
txn.Rollback()
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.statements.selectEventID(txn, lastEventNIDSent)
lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
txn.Rollback()
return nil, err
}
}
return &roomRecentEventsUpdater{
transaction{txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}
@ -293,7 +335,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
// StorePreviousEvents implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences {
if err := u.d.statements.insertPreviousEvent(u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return err
}
}
@ -302,7 +344,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.statements.selectPreviousEventExists(u.txn, eventReference.EventID, eventReference.EventSHA256)
err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
@ -321,26 +363,26 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.statements.selectEventSentToOutput(u.txn, eventNID)
return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.statements.updateEventSentToOutput(u.txn, eventNID)
return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
}
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.txn, u.roomNID, targetUserNID)
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID)
}
// RoomNID implements query.RoomserverQueryAPIDB
func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
roomNID, err := d.statements.selectRoomNID(nil, roomID)
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID)
if err == sql.ErrNoRows {
return 0, nil
}
@ -348,16 +390,18 @@ func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID)
func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,
) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID)
if err != nil {
return nil, 0, 0, err
}
references, err := d.statements.bulkSelectEventReference(eventNIDs)
references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs)
if err != nil {
return nil, 0, 0, err
}
depth, err := d.statements.selectMaxEventDepth(eventNIDs)
depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs)
if err != nil {
return nil, 0, 0, err
}
@ -366,40 +410,48 @@ func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.Ev
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
func (d *Database) GetInvitesForUser(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, err error) {
return d.statements.selectInviteActiveForUserInRoom(targetUserNID, roomNID)
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
}
// SetRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) SetRoomAlias(alias string, roomID string) error {
return d.statements.insertRoomAlias(alias, roomID)
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error {
return d.statements.insertRoomAlias(ctx, alias, roomID)
}
// GetRoomIDFromAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetRoomIDFromAlias(alias string) (string, error) {
return d.statements.selectRoomIDFromAlias(alias)
func (d *Database) GetRoomIDFromAlias(ctx context.Context, alias string) (string, error) {
return d.statements.selectRoomIDFromAlias(ctx, alias)
}
// GetAliasesFromRoomID implements alias.RoomserverAliasAPIDB
func (d *Database) GetAliasesFromRoomID(roomID string) ([]string, error) {
return d.statements.selectAliasesFromRoomID(roomID)
func (d *Database) GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.statements.selectAliasesFromRoomID(ctx, roomID)
}
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) RemoveRoomAlias(alias string) error {
return d.statements.deleteRoomAlias(alias)
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
return d.statements.deleteRoomAlias(ctx, alias)
}
// StateEntriesForTuples implements state.RoomStateDatabase
func (d *Database) StateEntriesForTuples(
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple,
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples)
return d.statements.bulkSelectFilteredStateBlockEntries(
ctx, stateBlockNIDs, stateKeyTuples,
)
}
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.MembershipUpdater, error) {
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
) (types.MembershipUpdater, error) {
txn, err := d.db.Begin()
if err != nil {
return nil, err
@ -411,17 +463,17 @@ func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.Members
}
}()
roomNID, err := d.assignRoomNID(txn, roomID)
roomNID, err := d.assignRoomNID(ctx, txn, roomID)
if err != nil {
return nil, err
}
targetUserNID, err := d.assignStateKeyNID(txn, targetUserID)
targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
if err != nil {
return nil, err
}
updater, err := d.membershipUpdaterTxn(txn, roomNID, targetUserNID)
updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
if err != nil {
return nil, err
}
@ -439,20 +491,23 @@ type membershipUpdater struct {
}
func (d *Database) membershipUpdaterTxn(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (types.MembershipUpdater, error) {
if err := d.statements.insertMembership(txn, roomNID, targetUserNID); err != nil {
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
return nil, err
}
membership, err := d.statements.selectMembershipForUpdate(txn, roomNID, targetUserNID)
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
if err != nil {
return nil, err
}
return &membershipUpdater{
transaction{txn}, d, roomNID, targetUserNID, membership,
transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
}, nil
}
@ -473,19 +528,19 @@ func (u *membershipUpdater) IsLeave() bool {
// SetToInvite implements types.MembershipUpdater
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.txn, event.Sender())
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
if err != nil {
return false, err
}
inserted, err := u.d.statements.insertInviteEvent(
u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
return false, err
}
if u.membership != membershipStateInvite {
if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
); err != nil {
return false, err
}
@ -497,7 +552,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
var inviteEventIDs []string
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID)
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil {
return nil, err
}
@ -505,7 +560,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
// If this is a join event update, there is no invite to update
if !isUpdate {
inviteEventIDs, err = u.d.statements.updateInviteRetired(
u.txn, u.roomNID, u.targetUserNID,
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return nil, err
@ -513,14 +568,15 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
}
// Look up the NID of the new join event
nIDs, err := u.d.EventNIDs([]string{eventID})
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return nil, err
}
if u.membership != membershipStateJoin || isUpdate {
if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, nIDs[eventID],
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateJoin, nIDs[eventID],
); err != nil {
return nil, err
}
@ -531,26 +587,27 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
// SetToLeave implements types.MembershipUpdater
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID)
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil {
return nil, err
}
inviteEventIDs, err := u.d.statements.updateInviteRetired(
u.txn, u.roomNID, u.targetUserNID,
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return nil, err
}
// Look up the NID of the new leave event
nIDs, err := u.d.EventNIDs([]string{eventID})
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
return nil, err
}
if u.membership != membershipStateLeaveOrBan {
if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, nIDs[eventID],
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
return nil, err
}
@ -559,19 +616,18 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
}
// GetMembership implements query.RoomserverQueryAPIDB
func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
txn, err := d.db.Begin()
if err != nil {
return
}
defer txn.Commit()
requestSenderUserNID, err := d.assignStateKeyNID(txn, requestSenderUserID)
func (d *Database) GetMembership(
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
if err != nil {
return
}
senderMembershipEventNID, senderMembership, err := d.statements.selectMembershipFromRoomAndTarget(roomNID, requestSenderUserNID)
senderMembershipEventNID, senderMembership, err :=
d.statements.selectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID,
)
if err == sql.ErrNoRows {
// The user has never been a member of that room
return 0, false, nil
@ -583,15 +639,20 @@ func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID stri
}
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) {
func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
) ([]types.EventNID, error) {
if joinOnly {
return d.statements.selectMembershipsFromRoomAndMembership(roomNID, membershipStateJoin)
return d.statements.selectMembershipsFromRoomAndMembership(
ctx, roomNID, membershipStateJoin,
)
}
return d.statements.selectMembershipsFromRoom(roomNID)
return d.statements.selectMembershipsFromRoom(ctx, roomNID)
}
type transaction struct {
ctx context.Context
txn *sql.Tx
}