Fix QuerySharedUsers for the SyncAPI keychange consumer (#2554)

* Make more use of base.BaseDendrite

* Fix QuerySharedUsers if no UserIDs are supplied
This commit is contained in:
Till 2022-07-05 14:50:56 +02:00 committed by GitHub
parent f29cdb26f6
commit 5087b36af0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 137 additions and 31 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/producers"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
@ -39,6 +40,7 @@ type RoomserverInternalAPI struct {
*perform.Upgrader *perform.Upgrader
*perform.Admin *perform.Admin
ProcessContext *process.ProcessContext ProcessContext *process.ProcessContext
Base *base.BaseDendrite
DB storage.Database DB storage.Database
Cfg *config.RoomServer Cfg *config.RoomServer
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
@ -56,33 +58,38 @@ type RoomserverInternalAPI struct {
} }
func NewRoomserverAPI( func NewRoomserverAPI(
processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database, base *base.BaseDendrite, roomserverDB storage.Database,
js nats.JetStreamContext, nc *nats.Conn, inputRoomEventTopic string, js nats.JetStreamContext, nc *nats.Conn,
caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName,
) *RoomserverInternalAPI { ) *RoomserverInternalAPI {
var perspectiveServerNames []gomatrixserverlib.ServerName
for _, kp := range base.Cfg.FederationAPI.KeyPerspectives {
perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
}
serverACLs := acls.NewServerACLs(roomserverDB) serverACLs := acls.NewServerACLs(roomserverDB)
producer := &producers.RoomEventProducer{ producer := &producers.RoomEventProducer{
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent)), Topic: string(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)),
JetStream: js, JetStream: js,
ACLs: serverACLs, ACLs: serverACLs,
} }
a := &RoomserverInternalAPI{ a := &RoomserverInternalAPI{
ProcessContext: processCtx, ProcessContext: base.ProcessContext,
DB: roomserverDB, DB: roomserverDB,
Cfg: cfg, Base: base,
Cache: caches, Cfg: &base.Cfg.RoomServer,
ServerName: cfg.Matrix.ServerName, Cache: base.Caches,
ServerName: base.Cfg.Global.ServerName,
PerspectiveServerNames: perspectiveServerNames, PerspectiveServerNames: perspectiveServerNames,
InputRoomEventTopic: inputRoomEventTopic, InputRoomEventTopic: base.Cfg.Global.JetStream.Prefixed(jetstream.InputRoomEvent),
OutputProducer: producer, OutputProducer: producer,
JetStream: js, JetStream: js,
NATSClient: nc, NATSClient: nc,
Durable: cfg.Matrix.JetStream.Durable("RoomserverInputConsumer"), Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"),
ServerACLs: serverACLs, ServerACLs: serverACLs,
Queryer: &query.Queryer{ Queryer: &query.Queryer{
DB: roomserverDB, DB: roomserverDB,
Cache: caches, Cache: base.Caches,
ServerName: cfg.Matrix.ServerName, ServerName: base.Cfg.Global.ServerName,
ServerACLs: serverACLs, ServerACLs: serverACLs,
}, },
// perform-er structs get initialised when we have a federation sender to use // perform-er structs get initialised when we have a federation sender to use
@ -98,8 +105,9 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.KeyRing = keyRing r.KeyRing = keyRing
r.Inputer = &input.Inputer{ r.Inputer = &input.Inputer{
Cfg: r.Cfg, Cfg: &r.Base.Cfg.RoomServer,
ProcessContext: r.ProcessContext, Base: r.Base,
ProcessContext: r.Base.ProcessContext,
DB: r.DB, DB: r.DB,
InputRoomEventTopic: r.InputRoomEventTopic, InputRoomEventTopic: r.InputRoomEventTopic,
OutputProducer: r.OutputProducer, OutputProducer: r.OutputProducer,

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/producers" "github.com/matrix-org/dendrite/roomserver/producers"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
@ -69,6 +70,7 @@ import (
// or C. // or C.
type Inputer struct { type Inputer struct {
Cfg *config.RoomServer Cfg *config.RoomServer
Base *base.BaseDendrite
ProcessContext *process.ProcessContext ProcessContext *process.ProcessContext
DB storage.Database DB storage.Database
NATSClient *nats.Conn NATSClient *nats.Conn
@ -160,7 +162,9 @@ func (r *Inputer) startWorkerForRoom(roomID string) {
// will look to see if we have a worker for that room which has its // will look to see if we have a worker for that room which has its
// own consumer. If we don't, we'll start one. // own consumer. If we don't, we'll start one.
func (r *Inputer) Start() error { func (r *Inputer) Start() error {
if r.Base.EnableMetrics {
prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration) prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
}
_, err := r.JetStream.Subscribe( _, err := r.JetStream.Subscribe(
"", // This is blank because we specified it in BindStream. "", // This is blank because we specified it in BindStream.
func(m *nats.Msg) { func(m *nats.Msg) {

View file

@ -17,13 +17,10 @@ package roomserver
import ( import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/roomserver/internal" "github.com/matrix-org/dendrite/roomserver/internal"
"github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -40,11 +37,6 @@ func NewInternalAPI(
) api.RoomserverInternalAPI { ) api.RoomserverInternalAPI {
cfg := &base.Cfg.RoomServer cfg := &base.Cfg.RoomServer
var perspectiveServerNames []gomatrixserverlib.ServerName
for _, kp := range base.Cfg.FederationAPI.KeyPerspectives {
perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
}
roomserverDB, err := storage.Open(base, &cfg.Database, base.Caches) roomserverDB, err := storage.Open(base, &cfg.Database, base.Caches)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to room server db") logrus.WithError(err).Panicf("failed to connect to room server db")
@ -53,8 +45,6 @@ func NewInternalAPI(
js, nc := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, nc := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
return internal.NewRoomserverAPI( return internal.NewRoomserverAPI(
base.ProcessContext, cfg, roomserverDB, js, nc, base, roomserverDB, js, nc,
cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEvent),
base.Caches, perspectiveServerNames,
) )
} }

View file

@ -0,0 +1,69 @@
package roomserver_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
if err != nil {
t.Fatalf("failed to create Database: %v", err)
}
return base, db, close
}
func Test_SharedUsers(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
// Invite and join Bob
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, _, close := mustCreateDatabase(t, dbType)
defer close()
rsAPI := roomserver.NewInternalAPI(base)
// SetFederationAPI starts the room event input consumer
rsAPI.SetFederationAPI(nil, nil)
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err)
}
// Query the shared users for Alice, there should only be Bob.
// This is used by the SyncAPI keychange consumer.
res := &api.QuerySharedUsersResponse{}
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
// Also verify that we get the expected result when specifying OtherUserIDs.
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
})
}

View file

@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
); );
` `
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" + " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
const insertMembershipSQL = "" + const insertMembershipSQL = "" +
@ -153,6 +159,7 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt
selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt
updateMembershipForgetRoomStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt
@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL}, {&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
{&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL},
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
roomNIDs []types.RoomNID, roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID, userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) { ) (map[types.EventStateKeyNID]int, error) {
var (
rows *sql.Rows
err error
)
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs)) if len(userNIDs) > 0 {
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
} else {
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1214,6 +1214,13 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
stateKeyNIDs[i] = nid stateKeyNIDs[i] = nid
i++ i++
} }
// If we didn't have any userIDs to look up, get the UserIDs for the returned userNIDToCount now
if len(userIDs) == 0 {
nidToUserID, err = d.EventStateKeys(ctx, stateKeyNIDs)
if err != nil {
return nil, err
}
}
result := make(map[string]int, len(userNIDToCount)) result := make(map[string]int, len(userNIDToCount))
for nid, count := range userNIDToCount { for nid, count := range userNIDToCount {
result[nidToUserID[nid]] = count result[nidToUserID[nid]] = count

View file

@ -41,12 +41,18 @@ const membershipSchema = `
); );
` `
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" + " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid" " GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND " +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
const insertMembershipSQL = "" + const insertMembershipSQL = "" +
@ -293,8 +299,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
for _, v := range userNIDs { for _, v := range userNIDs {
params = append(params, v) params = append(params, v)
} }
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
if len(userNIDs) > 0 {
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1) query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
}
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if txn != nil { if txn != nil {