Add QueryEventsAfter

This commit is contained in:
Till Faelligen 2022-02-17 17:10:21 +01:00
parent 5a39512f5f
commit 440a771d10
12 changed files with 121 additions and 2 deletions

View file

@ -166,6 +166,11 @@ type RoomserverInternalAPI interface {
// PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
QueryEventsAfter(
ctx context.Context,
req *QueryEventsAfterEventIDRequest,
res *QueryEventsAfterEventIDesponse,
) error
// Asks for the default room version as preferred by the server.
QueryRoomVersionCapabilities(

View file

@ -17,6 +17,11 @@ type RoomserverInternalAPITrace struct {
Impl RoomserverInternalAPI
}
func (t *RoomserverInternalAPITrace) QueryEventsAfter(ctx context.Context, req *QueryEventsAfterEventIDRequest, res *QueryEventsAfterEventIDesponse) error {
util.GetLogger(ctx).Infof("QueryEventsAfter req=%+v res=%+v", js(req), js(res))
return t.Impl.QueryEventsAfter(ctx, req, res)
}
func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) {
t.Impl.SetFederationAPI(fsAPI, keyRing)
}

View file

@ -101,6 +101,17 @@ type QueryEventsByIDResponse struct {
Events []*gomatrixserverlib.HeaderedEvent `json:"events"`
}
// QueryEventsByIDRequest is a request to QueryEventsByID
type QueryEventsAfterEventIDRequest struct {
// The event IDs to look up.
EventIDs string `json:"event_id"`
}
// QueryEventsByIDResponse is a response to QueryEventsByID
type QueryEventsAfterEventIDesponse struct {
Events []*gomatrixserverlib.ClientEvent `json:"events"`
}
// QueryMembershipForUserRequest is a request to QueryMembership
type QueryMembershipForUserRequest struct {
// ID of the room to fetch membership from

View file

@ -198,3 +198,11 @@ func (r *RoomserverInternalAPI) PerformForget(
) error {
return r.Forgetter.PerformForget(ctx, req, resp)
}
func (r *RoomserverInternalAPI) QueryEventsAfter(
ctx context.Context,
req *api.QueryEventsAfterEventIDRequest,
res *api.QueryEventsAfterEventIDesponse,
) error {
return r.Queryer.QueryEventsAfter(ctx, req, res)
}

View file

@ -724,3 +724,23 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
res.AuthChain = hchain
return nil
}
func (r *Queryer) QueryEventsAfter(
ctx context.Context,
req *api.QueryEventsAfterEventIDRequest,
res *api.QueryEventsAfterEventIDesponse,
) error {
eventNIDs, err := r.DB.SelectPreviousEventNIDs(ctx, req.EventIDs)
if err != nil {
return err
}
events, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return err
}
for _, event := range events {
ev := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll)
res.Events = append(res.Events, &ev)
}
return nil
}

View file

@ -57,6 +57,7 @@ const (
RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers"
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryEventsAfterPath = "/roomserver/queryEventsAfter"
)
type httpRoomserverInternalAPI struct {
@ -534,5 +535,12 @@ func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.
apiURL := h.roomserverURL + RoomserverPerformForgetPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpRoomserverInternalAPI) QueryEventsAfter(ctx context.Context, req *api.QueryEventsAfterEventIDRequest, res *api.QueryEventsAfterEventIDesponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsAfter")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryEventsAfterPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -464,4 +464,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(RoomserverQueryEventsAfterPath,
httputil.MakeInternalAPI("queryEventsAfterPath", func(req *http.Request) util.JSONResponse {
request := api.QueryEventsAfterEventIDRequest{}
response := api.QueryEventsAfterEventIDesponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryEventsAfter(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -64,6 +64,7 @@ type Database interface {
// Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events.
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
SelectPreviousEventNIDs(ctx context.Context, eventID string) ([]types.EventNID, error)
// Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.

View file

@ -59,9 +59,14 @@ const selectPreviousEventExistsSQL = "" +
"SELECT 1 FROM roomserver_previous_events" +
" WHERE previous_event_id = $1 AND previous_reference_sha256 = $2"
const selectPreviousEventNIDsSQL = "" +
"SELECT event_nids FROM roomserver_previous_events" +
" WHERE previous_event_id = $1"
type previousEventStatements struct {
insertPreviousEventStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt
selectPreviousEventNIDsStmt *sql.Stmt
}
func createPrevEventsTable(db *sql.DB) error {
@ -75,6 +80,7 @@ func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
return s, sqlutil.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
}.Prepare(db)
}
@ -101,3 +107,18 @@ func (s *previousEventStatements) SelectPreviousEventExists(
stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
}
// SelectPreviousEventNIDs returns all eventNIDs for a given eventID
func (s *previousEventStatements) SelectPreviousEventNIDs(ctx context.Context, txn *sql.Tx, eventID string) ([]types.EventNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectPreviousEventNIDsStmt)
row := stmt.QueryRowContext(ctx, eventID)
var eventNIDs []uint8
if err := row.Scan(&eventNIDs); err != nil {
return nil, err
}
result := []types.EventNID{}
for _, nid := range eventNIDs {
result = append(result, types.EventNID(nid))
}
return result, nil
}

View file

@ -1173,6 +1173,10 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget
})
}
func (d *Database) SelectPreviousEventNIDs(ctx context.Context, eventID string) ([]types.EventNID, error) {
return d.PrevEventsTable.SelectPreviousEventNIDs(ctx, nil, eventID)
}
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
// it should live in this package!

View file

@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -53,7 +54,7 @@ const insertPreviousEventSQL = `
const selectPreviousEventNIDsSQL = `
SELECT event_nids FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
WHERE previous_event_id = $1
`
// Check if the event is referenced by another event in the table.
@ -129,3 +130,24 @@ func (s *previousEventStatements) SelectPreviousEventExists(
stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
}
// SelectPreviousEventNIDs returns all eventNIDs for a given eventID
func (s *previousEventStatements) SelectPreviousEventNIDs(ctx context.Context, txn *sql.Tx, eventID string) ([]types.EventNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectPreviousEventNIDsStmt)
row := stmt.QueryRowContext(ctx, eventID)
var eventNIDs string
if err := row.Scan(&eventNIDs); err != nil {
return nil, err
}
result := []types.EventNID{}
nids := strings.Split(eventNIDs, ",")
for _, nid := range nids {
i, err := strconv.Atoi(nid)
if err != nil {
return nil, err
}
result = append(result, types.EventNID(i))
}
return result, nil
}

View file

@ -100,6 +100,7 @@ type PreviousEvents interface {
// Check if the event reference exists
// Returns sql.ErrNoRows if the event reference doesn't exist.
SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error
SelectPreviousEventNIDs(ctx context.Context, txn *sql.Tx, eventID string) ([]types.EventNID, error)
}
type Invites interface {