diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index fe4295213..a6e39821a 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -183,7 +183,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res // Process the events. for _, e := range pdus { - if err := t.processEvent(ctx, e.Unwrap(), true); err != nil { + if err := t.processEvent(ctx, e.Unwrap()); err != nil { // If the error is due to the event itself being bad then we skip // it and move onto the next event. We report an error so that the // sender knows that we have skipped processing it. @@ -338,11 +338,15 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli } } -func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, isInboundTxn bool) error { +func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) error { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) // Work out if the roomserver knows everything it needs to know to auth - // the event. + // the event. This includes the prev_events and auth_events. + // NOTE! This is going to include prev_events that have an empty state + // snapshot. This is because we will need to re-request the event, and + // it's /state_ids, in order for it to exist in the roomserver correctly + // before the roomserver tries to work out stateReq := api.QueryMissingAuthPrevEventsRequest{ RoomID: e.RoomID(), AuthEventIDs: e.AuthEventIDs(), @@ -410,7 +414,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is if len(stateResp.MissingPrevEventIDs) > 0 { logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs)) - return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn) + return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion) } // pass the event to the roomserver which will do auth checks @@ -438,7 +442,7 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver return gomatrixserverlib.Allowed(e, &authUsingState) } -func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error { +func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error { // Do this with a fresh context, so that we keep working even if the // original request times out. With any luck, by the time the remote // side retries, we'll have fetched the missing state. @@ -464,39 +468,82 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser // - fill in the gap completely then process event `e` returning no backwards extremity // - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - backwardsExtremity, err := t.getMissingEvents(gmectx, e, roomVersion, isInboundTxn) + newEvents, err := t.getMissingEvents(gmectx, e, roomVersion) if err != nil { return err } - if backwardsExtremity == nil { - // we filled in the gap! + if len(newEvents) == 0 { return nil } + backwardsExtremity := &newEvents[0] + newEvents = newEvents[1:] + // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. - // security: we have to do state resolution on the new backwards extremity (TODO: WHY) // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. var states []*gomatrixserverlib.RespState needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples() for _, prevEventID := range backwardsExtremity.PrevEventIDs() { + // Look up what the state is after the backward extremity. This will either + // come from the roomserver, if we know all the required events, or it will + // come from a remote server via /state_ids if not. var prevState *gomatrixserverlib.RespState prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) return err } + // Append the state onto the collected state. We'll run this through the + // state resolution next. states = append(states, prevState) } + + // Now that we have collected all of the state from the prev_events, we'll + // run the state through the appropriate state resolution algorithm for the + // room. This does a couple of things: + // 1. Ensures that the state is deduplicated fully for each state-key tuple + // 2. Ensures that we pick the latest events from both sets, in the case that + // one of the prev_events is quite a bit older than the others resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) return err } - // pass the event along with the state to the roomserver using a background context so we don't - // needlessly expire - return api.SendEventWithState(context.Background(), t.rsAPI, resolvedState, e.Headered(roomVersion), t.haveEventIDs()) + // First of all, send the backward extremity into the roomserver with the + // newly resolved state. This marks the "oldest" point in the backfill and + // sets the baseline state for any new events after this. + err = api.SendEventWithState( + context.Background(), + t.rsAPI, + resolvedState, + backwardsExtremity.Headered(roomVersion), + t.haveEventIDs(), + ) + if err != nil { + return fmt.Errorf("api.SendEventWithState: %w", err) + } + + // Then send all of the newer backfilled events, of which will all be newer + // than the backward extremity, into the roomserver without state. This way + // they will automatically fast-forward based on the room state at the + // extremity in the last step. + headeredNewEvents := make([]gomatrixserverlib.HeaderedEvent, len(newEvents)) + for i, newEvent := range newEvents { + headeredNewEvents[i] = newEvent.Headered(roomVersion) + } + if err = api.SendEvents( + context.Background(), + t.rsAPI, + append(headeredNewEvents, e.Headered(roomVersion)), + api.DoNotSendToOtherServers, + nil, + ); err != nil { + return fmt.Errorf("api.SendEvents: %w", err) + } + + return nil } // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) @@ -652,11 +699,7 @@ retryAllowedState: // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. // This means that we may recursively call this function, as we spider back up prev_events. // nolint:gocyclo -func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { - if !isInboundTxn { - // we've recursed here, so just take a state snapshot please! - return &e, nil - } +func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []gomatrixserverlib.Event, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) // query latest events (our trusted forward extremities) @@ -667,7 +710,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event var res api.QueryLatestEventsAndStateResponse if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { logger.WithError(err).Warn("Failed to query latest events") - return &e, nil + return nil, err } latestEvents := make([]string, len(res.LatestEvents)) for i := range res.LatestEvents { @@ -726,7 +769,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) // topologically sort and sanity check that we are making forward progress - newEvents := gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) + newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) shouldHaveSomeEventIDs := e.PrevEventIDs() hasPrevEvent := false Event: @@ -749,16 +792,9 @@ Event: err: err, } } - // process the missing events then the event which started this whole thing - for _, ev := range append(newEvents, e) { - err := t.processEvent(ctx, ev, false) - if err != nil { - return nil, err - } - } // we processed everything! - return nil, nil + return newEvents, nil } func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index ba653c1e8..d7e422479 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -516,6 +516,23 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions rsAPI = &testRoomserverAPI{ + queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { + res := api.QueryEventsByIDResponse{} + for _, ev := range testEvents { + for _, id := range req.EventIDs { + if ev.EventID() == id { + res.Events = append(res.Events, ev) + } + } + } + return res + }, + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: true, + StateEvents: testEvents[:5], + } + }, queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { missingPrevEvent := []string{"missing_prev_event"} if len(req.PrevEventIDs) == 1 { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 736604217..810511505 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -122,7 +122,7 @@ func (r *Queryer) QueryMissingAuthPrevEvents( } for _, prevEventID := range request.PrevEventIDs { - if nids, err := r.DB.EventNIDs(ctx, []string{prevEventID}); err != nil || len(nids) == 0 { + if state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}); err != nil || len(state) == 0 { response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) } }