mirror of
https://github.com/matrix-org/dendrite
synced 2024-12-12 22:33:00 +00:00
bugfix: E2EE device keys could sometimes not be sent to remote servers (#2466)
* Fix flakey sytest 'Local device key changes get to remote servers' * Debug logs * Remove internal/test and use /test only Remove a lot of ancient code too. * Use FederationRoomserverAPI in more places * Use more interfaces in federationapi; begin adding regression test * Linting * Add regression test * Unbreak tests * ALL THE LOGS * Fix a race condition which could cause events to not be sent to servers If a new room event which rewrites state arrives, we remove all joined hosts then re-calculate them. This wasn't done in a transaction so for a brief period we would have no joined hosts. During this interim, key change events which arrive would not be sent to destination servers. This would sporadically fail on sytest. * Unbreak new tests * Linting
This commit is contained in:
parent
cd82460513
commit
6de29c1cd2
48 changed files with 566 additions and 618 deletions
|
@ -20,7 +20,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usage = `Usage: %s
|
const usage = `Usage: %s
|
||||||
|
|
|
@ -12,12 +12,16 @@ import (
|
||||||
|
|
||||||
// FederationInternalAPI is used to query information from the federation sender.
|
// FederationInternalAPI is used to query information from the federation sender.
|
||||||
type FederationInternalAPI interface {
|
type FederationInternalAPI interface {
|
||||||
FederationClient
|
gomatrixserverlib.FederatedStateClient
|
||||||
|
KeyserverFederationAPI
|
||||||
gomatrixserverlib.KeyDatabase
|
gomatrixserverlib.KeyDatabase
|
||||||
ClientFederationAPI
|
ClientFederationAPI
|
||||||
RoomserverFederationAPI
|
RoomserverFederationAPI
|
||||||
|
|
||||||
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
|
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
|
||||||
|
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||||
|
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
||||||
|
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||||
|
|
||||||
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
|
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
|
||||||
PerformBroadcastEDU(
|
PerformBroadcastEDU(
|
||||||
|
@ -60,17 +64,43 @@ type RoomserverFederationAPI interface {
|
||||||
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender
|
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
|
||||||
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
|
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
|
||||||
// this interface are of type FederationClientError
|
// this interface are of type FederationClientError
|
||||||
type FederationClient interface {
|
type KeyserverFederationAPI interface {
|
||||||
gomatrixserverlib.FederatedStateClient
|
|
||||||
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
|
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
|
||||||
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
|
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
|
||||||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// an interface for gmsl.FederationClient - contains functions called by federationapi only.
|
||||||
|
type FederationClient interface {
|
||||||
|
gomatrixserverlib.KeyClient
|
||||||
|
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
|
||||||
|
|
||||||
|
// Perform operations
|
||||||
|
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
|
||||||
|
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
|
||||||
|
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
|
||||||
|
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
|
||||||
|
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
|
||||||
|
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
|
||||||
|
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
|
||||||
|
|
||||||
|
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||||
|
|
||||||
|
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
|
||||||
|
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
|
||||||
|
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
|
||||||
|
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
|
||||||
|
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
|
||||||
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
||||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||||
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
|
||||||
|
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
|
||||||
|
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
|
||||||
|
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
|
||||||
|
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FederationClientError is returned from FederationClient methods in the event of a problem.
|
// FederationClientError is returned from FederationClient methods in the event of a problem.
|
||||||
|
|
|
@ -39,7 +39,7 @@ type KeyChangeConsumer struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
queues *queue.OutgoingQueues
|
queues *queue.OutgoingQueues
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
rsAPI roomserverAPI.FederationRoomserverAPI
|
||||||
topic string
|
topic string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ func NewKeyChangeConsumer(
|
||||||
js nats.JetStreamContext,
|
js nats.JetStreamContext,
|
||||||
queues *queue.OutgoingQueues,
|
queues *queue.OutgoingQueues,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||||
) *KeyChangeConsumer {
|
) *KeyChangeConsumer {
|
||||||
return &KeyChangeConsumer{
|
return &KeyChangeConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
|
@ -120,6 +120,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
|
||||||
logger.WithError(err).Error("failed to calculate joined rooms for user")
|
logger.WithError(err).Error("failed to calculate joined rooms for user")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
logrus.Infof("DEBUG: %v joined rooms for user %v", queryRes.RoomIDs, m.UserID)
|
||||||
// send this key change to all servers who share rooms with this user.
|
// send this key change to all servers who share rooms with this user.
|
||||||
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -128,6 +129,9 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(destinations) == 0 {
|
if len(destinations) == 0 {
|
||||||
|
logger.WithField("num_rooms", len(queryRes.RoomIDs)).Debug("user is in no federated rooms")
|
||||||
|
destinations, err = t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, false)
|
||||||
|
logrus.Infof("GetJoinedHostsForRooms exclude self=false -> %v %v", destinations, err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Pack the EDU and marshal it
|
// Pack the EDU and marshal it
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/queue"
|
"github.com/matrix-org/dendrite/federationapi/queue"
|
||||||
|
@ -36,7 +37,7 @@ import (
|
||||||
type OutputRoomEventConsumer struct {
|
type OutputRoomEventConsumer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cfg *config.FederationAPI
|
cfg *config.FederationAPI
|
||||||
rsAPI api.RoomserverInternalAPI
|
rsAPI api.FederationRoomserverAPI
|
||||||
jetstream nats.JetStreamContext
|
jetstream nats.JetStreamContext
|
||||||
durable string
|
durable string
|
||||||
db storage.Database
|
db storage.Database
|
||||||
|
@ -51,7 +52,7 @@ func NewOutputRoomEventConsumer(
|
||||||
js nats.JetStreamContext,
|
js nats.JetStreamContext,
|
||||||
queues *queue.OutgoingQueues,
|
queues *queue.OutgoingQueues,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
rsAPI api.RoomserverInternalAPI,
|
rsAPI api.FederationRoomserverAPI,
|
||||||
) *OutputRoomEventConsumer {
|
) *OutputRoomEventConsumer {
|
||||||
return &OutputRoomEventConsumer{
|
return &OutputRoomEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
|
@ -89,15 +90,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg)
|
||||||
switch output.Type {
|
switch output.Type {
|
||||||
case api.OutputTypeNewRoomEvent:
|
case api.OutputTypeNewRoomEvent:
|
||||||
ev := output.NewRoomEvent.Event
|
ev := output.NewRoomEvent.Event
|
||||||
|
if err := s.processMessage(*output.NewRoomEvent, output.NewRoomEvent.RewritesState); err != nil {
|
||||||
if output.NewRoomEvent.RewritesState {
|
|
||||||
if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil {
|
|
||||||
log.WithError(err).Errorf("roomserver output log: purge room state failure")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.processMessage(*output.NewRoomEvent); err != nil {
|
|
||||||
// panic rather than continue with an inconsistent database
|
// panic rather than continue with an inconsistent database
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"event_id": ev.EventID(),
|
"event_id": ev.EventID(),
|
||||||
|
@ -145,7 +138,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
|
||||||
|
|
||||||
// processMessage updates the list of currently joined hosts in the room
|
// processMessage updates the list of currently joined hosts in the room
|
||||||
// and then sends the event to the hosts that were joined before the event.
|
// and then sends the event to the hosts that were joined before the event.
|
||||||
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error {
|
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error {
|
||||||
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
|
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
|
||||||
|
|
||||||
// Ask the roomserver and add in the rest of the results into the set.
|
// Ask the roomserver and add in the rest of the results into the set.
|
||||||
|
@ -164,7 +157,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
|
||||||
addsStateEvents = append(addsStateEvents, eventsRes.Events...)
|
addsStateEvents = append(addsStateEvents, eventsRes.Events...)
|
||||||
}
|
}
|
||||||
|
|
||||||
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents))
|
addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -173,13 +166,13 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
|
||||||
// expressed as a delta against the current state.
|
// expressed as a delta against the current state.
|
||||||
// TODO(#290): handle EventIDMismatchError and recover the current state by
|
// TODO(#290): handle EventIDMismatchError and recover the current state by
|
||||||
// talking to the roomserver
|
// talking to the roomserver
|
||||||
|
logrus.Infof("room %s adds joined hosts: %v removes %v", ore.Event.RoomID(), addsJoinedHosts, ore.RemovesStateEventIDs)
|
||||||
oldJoinedHosts, err := s.db.UpdateRoom(
|
oldJoinedHosts, err := s.db.UpdateRoom(
|
||||||
s.ctx,
|
s.ctx,
|
||||||
ore.Event.RoomID(),
|
ore.Event.RoomID(),
|
||||||
ore.LastSentEventID,
|
|
||||||
ore.Event.EventID(),
|
|
||||||
addsJoinedHosts,
|
addsJoinedHosts,
|
||||||
ore.RemovesStateEventIDs,
|
ore.RemovesStateEventIDs,
|
||||||
|
rewritesState, // if we're re-writing state, nuke all joined hosts before adding
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -238,7 +231,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents)
|
combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -284,10 +277,10 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// joinedHostsFromEvents turns a list of state events into a list of joined hosts.
|
// JoinedHostsFromEvents turns a list of state events into a list of joined hosts.
|
||||||
// This errors if one of the events was invalid.
|
// This errors if one of the events was invalid.
|
||||||
// It should be impossible for an invalid event to get this far in the pipeline.
|
// It should be impossible for an invalid event to get this far in the pipeline.
|
||||||
func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
|
func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
|
||||||
var joinedHosts []types.JoinedHost
|
var joinedHosts []types.JoinedHost
|
||||||
for _, ev := range evs {
|
for _, ev := range evs {
|
||||||
if ev.Type() != "m.room.member" || ev.StateKey() == nil {
|
if ev.Type() != "m.room.member" || ev.StateKey() == nil {
|
||||||
|
|
|
@ -93,8 +93,8 @@ func AddPublicRoutes(
|
||||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||||
func NewInternalAPI(
|
func NewInternalAPI(
|
||||||
base *base.BaseDendrite,
|
base *base.BaseDendrite,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation api.FederationClient,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||||
caches *caching.Caches,
|
caches *caching.Caches,
|
||||||
keyRing *gomatrixserverlib.KeyRing,
|
keyRing *gomatrixserverlib.KeyRing,
|
||||||
resetBlacklist bool,
|
resetBlacklist bool,
|
||||||
|
|
|
@ -3,18 +3,250 @@ package federationapi_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi"
|
"github.com/matrix-org/dendrite/federationapi"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/federationapi/internal"
|
"github.com/matrix-org/dendrite/federationapi/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/test"
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type fedRoomserverAPI struct {
|
||||||
|
rsapi.FederationRoomserverAPI
|
||||||
|
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
|
||||||
|
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformJoin will call this function
|
||||||
|
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
|
||||||
|
if f.inputRoomEvents == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.inputRoomEvents(ctx, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// keychange consumer calls this
|
||||||
|
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
|
||||||
|
if f.queryRoomsForUser == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return f.queryRoomsForUser(ctx, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
|
||||||
|
type fedClient struct {
|
||||||
|
api.FederationClient
|
||||||
|
allowJoins []*test.Room
|
||||||
|
keys map[gomatrixserverlib.ServerName]struct {
|
||||||
|
key ed25519.PrivateKey
|
||||||
|
keyID gomatrixserverlib.KeyID
|
||||||
|
}
|
||||||
|
t *testing.T
|
||||||
|
sentTxn bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) {
|
||||||
|
fmt.Println("GetServerKeys:", matrixServer)
|
||||||
|
var keys gomatrixserverlib.ServerKeys
|
||||||
|
var keyID gomatrixserverlib.KeyID
|
||||||
|
var pkey ed25519.PrivateKey
|
||||||
|
for srv, data := range f.keys {
|
||||||
|
if srv == matrixServer {
|
||||||
|
pkey = data.key
|
||||||
|
keyID = data.keyID
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pkey == nil {
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys.ServerName = matrixServer
|
||||||
|
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour))
|
||||||
|
publicKey := pkey.Public().(ed25519.PublicKey)
|
||||||
|
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
|
||||||
|
keyID: {
|
||||||
|
Key: gomatrixserverlib.Base64Bytes(publicKey),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
toSign, err := json.Marshal(keys.ServerKeyFields)
|
||||||
|
if err != nil {
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
|
||||||
|
keys.Raw, err = gomatrixserverlib.SignJSON(
|
||||||
|
string(matrixServer), keyID, pkey, toSign,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return keys, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
|
||||||
|
for _, r := range f.allowJoins {
|
||||||
|
if r.ID == roomID {
|
||||||
|
res.RoomVersion = r.Version
|
||||||
|
res.JoinEvent = gomatrixserverlib.EventBuilder{
|
||||||
|
Sender: userID,
|
||||||
|
RoomID: roomID,
|
||||||
|
Type: "m.room.member",
|
||||||
|
StateKey: &userID,
|
||||||
|
Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)),
|
||||||
|
PrevEvents: r.ForwardExtremities(),
|
||||||
|
}
|
||||||
|
var needed gomatrixserverlib.StateNeeded
|
||||||
|
needed, err = gomatrixserverlib.StateNeededForEventBuilder(&res.JoinEvent)
|
||||||
|
if err != nil {
|
||||||
|
f.t.Errorf("StateNeededForEventBuilder: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.JoinEvent.AuthEvents = r.MustGetAuthEventRefsForEvent(f.t, needed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
|
||||||
|
for _, r := range f.allowJoins {
|
||||||
|
if r.ID == event.RoomID() {
|
||||||
|
r.InsertEvent(f.t, event.Headered(r.Version))
|
||||||
|
f.t.Logf("Join event: %v", event.EventID())
|
||||||
|
res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState())
|
||||||
|
res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
|
||||||
|
for _, edu := range t.EDUs {
|
||||||
|
if edu.Type == gomatrixserverlib.MDeviceListUpdate {
|
||||||
|
f.sentTxn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.t.Logf("got /send")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regression test to make sure that /send_join is updating the destination hosts synchronously and
|
||||||
|
// isn't relying on the roomserver.
|
||||||
|
func TestFederationAPIJoinThenKeyUpdate(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
testFederationAPIJoinThenKeyUpdate(t, dbType)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
|
||||||
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
base.Cfg.FederationAPI.PreferDirectFetch = true
|
||||||
|
defer close()
|
||||||
|
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||||
|
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
||||||
|
|
||||||
|
serverA := gomatrixserverlib.ServerName("server.a")
|
||||||
|
serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera")
|
||||||
|
serverAPrivKey := test.PrivateKeyA
|
||||||
|
creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey))
|
||||||
|
|
||||||
|
myServer := base.Cfg.Global.ServerName
|
||||||
|
myServerKeyID := base.Cfg.Global.KeyID
|
||||||
|
myServerPrivKey := base.Cfg.Global.PrivateKey
|
||||||
|
joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey))
|
||||||
|
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
|
||||||
|
room := test.NewRoom(t, creator)
|
||||||
|
|
||||||
|
rsapi := &fedRoomserverAPI{
|
||||||
|
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
|
||||||
|
if req.Asynchronous {
|
||||||
|
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
|
||||||
|
if req.UserID == joiningUser.ID && req.WantMembership == "join" {
|
||||||
|
res.RoomIDs = []string{room.ID}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fc := &fedClient{
|
||||||
|
allowJoins: []*test.Room{room},
|
||||||
|
t: t,
|
||||||
|
keys: map[gomatrixserverlib.ServerName]struct {
|
||||||
|
key ed25519.PrivateKey
|
||||||
|
keyID gomatrixserverlib.KeyID
|
||||||
|
}{
|
||||||
|
serverA: {
|
||||||
|
key: serverAPrivKey,
|
||||||
|
keyID: serverAKeyID,
|
||||||
|
},
|
||||||
|
myServer: {
|
||||||
|
key: myServerPrivKey,
|
||||||
|
keyID: myServerKeyID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false)
|
||||||
|
|
||||||
|
var resp api.PerformJoinResponse
|
||||||
|
fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{
|
||||||
|
RoomID: room.ID,
|
||||||
|
UserID: joiningUser.ID,
|
||||||
|
ServerNames: []gomatrixserverlib.ServerName{serverA},
|
||||||
|
}, &resp)
|
||||||
|
if resp.JoinedVia != serverA {
|
||||||
|
t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA)
|
||||||
|
}
|
||||||
|
if resp.LastError != nil {
|
||||||
|
t.Fatalf("PerformJoin: returned error: %+v", *resp.LastError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject a keyserver key change event and ensure we try to send it out. If we don't, then the
|
||||||
|
// federationapi is incorrectly waiting for an output room event to arrive to update the joined
|
||||||
|
// hosts table.
|
||||||
|
key := keyapi.DeviceMessage{
|
||||||
|
Type: keyapi.TypeDeviceKeyUpdate,
|
||||||
|
DeviceKeys: &keyapi.DeviceKeys{
|
||||||
|
UserID: joiningUser.ID,
|
||||||
|
DeviceID: "MY_DEVICE",
|
||||||
|
DisplayName: "BLARGLE",
|
||||||
|
KeyJSON: []byte(`{}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal device message: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := &nats.Msg{
|
||||||
|
Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
|
||||||
|
Header: nats.Header{},
|
||||||
|
Data: b,
|
||||||
|
}
|
||||||
|
msg.Header.Set(jetstream.UserID, key.UserID)
|
||||||
|
|
||||||
|
testrig.MustPublishMsgs(t, jsctx, msg)
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
if !fc.sentTxn {
|
||||||
|
t.Fatalf("did not send device list update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404.
|
// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404.
|
||||||
// Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated.
|
// Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated.
|
||||||
func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
||||||
|
@ -86,7 +318,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
||||||
}
|
}
|
||||||
gerr, ok := err.(gomatrix.HTTPError)
|
gerr, ok := err.(gomatrix.HTTPError)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("failed to cast response error as gomatrix.HTTPError")
|
t.Errorf("failed to cast response error as gomatrix.HTTPError: %s", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t.Logf("Error: %+v", gerr)
|
t.Logf("Error: %+v", gerr)
|
||||||
|
|
|
@ -25,8 +25,8 @@ type FederationInternalAPI struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
cfg *config.FederationAPI
|
cfg *config.FederationAPI
|
||||||
statistics *statistics.Statistics
|
statistics *statistics.Statistics
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
rsAPI roomserverAPI.FederationRoomserverAPI
|
||||||
federation *gomatrixserverlib.FederationClient
|
federation api.FederationClient
|
||||||
keyRing *gomatrixserverlib.KeyRing
|
keyRing *gomatrixserverlib.KeyRing
|
||||||
queues *queue.OutgoingQueues
|
queues *queue.OutgoingQueues
|
||||||
joins sync.Map // joins currently in progress
|
joins sync.Map // joins currently in progress
|
||||||
|
@ -34,8 +34,8 @@ type FederationInternalAPI struct {
|
||||||
|
|
||||||
func NewFederationInternalAPI(
|
func NewFederationInternalAPI(
|
||||||
db storage.Database, cfg *config.FederationAPI,
|
db storage.Database, cfg *config.FederationAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation api.FederationClient,
|
||||||
statistics *statistics.Statistics,
|
statistics *statistics.Statistics,
|
||||||
caches *caching.Caches,
|
caches *caching.Caches,
|
||||||
queues *queue.OutgoingQueues,
|
queues *queue.OutgoingQueues,
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/api"
|
"github.com/matrix-org/dendrite/federationapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/consumers"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/version"
|
"github.com/matrix-org/dendrite/roomserver/version"
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
|
@ -235,6 +236,21 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
return fmt.Errorf("respSendJoin.Check: %w", err)
|
return fmt.Errorf("respSendJoin.Check: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We need to immediately update our list of joined hosts for this room now as we are technically
|
||||||
|
// joined. We must do this synchronously: we cannot rely on the roomserver output events as they
|
||||||
|
// will happen asyncly. If we don't update this table, you can end up with bad failure modes like
|
||||||
|
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
|
||||||
|
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
|
||||||
|
// The events are trusted now as we performed auth checks above.
|
||||||
|
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
|
||||||
|
}
|
||||||
|
logrus.WithField("hosts", joinedHosts).WithField("room", roomID).Info("Joined federated room with hosts")
|
||||||
|
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
|
||||||
|
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
// If we successfully performed a send_join above then the other
|
// If we successfully performed a send_join above then the other
|
||||||
// server now thinks we're a part of the room. Send the newly
|
// server now thinks we're a part of the room. Send the newly
|
||||||
// returned state to the roomserver to update our local view.
|
// returned state to the roomserver to update our local view.
|
||||||
|
@ -650,7 +666,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
|
||||||
|
|
||||||
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided
|
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided
|
||||||
func federatedAuthProvider(
|
func federatedAuthProvider(
|
||||||
ctx context.Context, federation *gomatrixserverlib.FederationClient,
|
ctx context.Context, federation api.FederationClient,
|
||||||
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
|
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
|
||||||
) gomatrixserverlib.AuthChainProvider {
|
) gomatrixserverlib.AuthChainProvider {
|
||||||
// A list of events that we have retried, if they were not included in
|
// A list of events that we have retried, if they were not included in
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||||
|
@ -49,8 +50,8 @@ type destinationQueue struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
process *process.ProcessContext
|
process *process.ProcessContext
|
||||||
signing *SigningInfo
|
signing *SigningInfo
|
||||||
rsAPI api.RoomserverInternalAPI
|
rsAPI api.FederationRoomserverAPI
|
||||||
client *gomatrixserverlib.FederationClient // federation client
|
client fedapi.FederationClient // federation client
|
||||||
origin gomatrixserverlib.ServerName // origin of requests
|
origin gomatrixserverlib.ServerName // origin of requests
|
||||||
destination gomatrixserverlib.ServerName // destination of requests
|
destination gomatrixserverlib.ServerName // destination of requests
|
||||||
running atomic.Bool // is the queue worker running?
|
running atomic.Bool // is the queue worker running?
|
||||||
|
|
|
@ -26,6 +26,7 @@ import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||||
|
@ -39,9 +40,9 @@ type OutgoingQueues struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
process *process.ProcessContext
|
process *process.ProcessContext
|
||||||
disabled bool
|
disabled bool
|
||||||
rsAPI api.RoomserverInternalAPI
|
rsAPI api.FederationRoomserverAPI
|
||||||
origin gomatrixserverlib.ServerName
|
origin gomatrixserverlib.ServerName
|
||||||
client *gomatrixserverlib.FederationClient
|
client fedapi.FederationClient
|
||||||
statistics *statistics.Statistics
|
statistics *statistics.Statistics
|
||||||
signing *SigningInfo
|
signing *SigningInfo
|
||||||
queuesMutex sync.Mutex // protects the below
|
queuesMutex sync.Mutex // protects the below
|
||||||
|
@ -85,8 +86,8 @@ func NewOutgoingQueues(
|
||||||
process *process.ProcessContext,
|
process *process.ProcessContext,
|
||||||
disabled bool,
|
disabled bool,
|
||||||
origin gomatrixserverlib.ServerName,
|
origin gomatrixserverlib.ServerName,
|
||||||
client *gomatrixserverlib.FederationClient,
|
client fedapi.FederationClient,
|
||||||
rsAPI api.RoomserverInternalAPI,
|
rsAPI api.FederationRoomserverAPI,
|
||||||
statistics *statistics.Statistics,
|
statistics *statistics.Statistics,
|
||||||
signing *SigningInfo,
|
signing *SigningInfo,
|
||||||
) *OutgoingQueues {
|
) *OutgoingQueues {
|
||||||
|
|
|
@ -30,7 +30,7 @@ import (
|
||||||
// RoomAliasToID converts the queried alias into a room ID and returns it
|
// RoomAliasToID converts the queried alias into a room ID and returns it
|
||||||
func RoomAliasToID(
|
func RoomAliasToID(
|
||||||
httpReq *http.Request,
|
httpReq *http.Request,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation federationAPI.FederationClient,
|
||||||
cfg *config.FederationAPI,
|
cfg *config.FederationAPI,
|
||||||
rsAPI roomserverAPI.FederationRoomserverAPI,
|
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||||
senderAPI federationAPI.FederationInternalAPI,
|
senderAPI federationAPI.FederationInternalAPI,
|
||||||
|
|
|
@ -54,7 +54,7 @@ func Setup(
|
||||||
rsAPI roomserverAPI.FederationRoomserverAPI,
|
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||||
fsAPI *fedInternal.FederationInternalAPI,
|
fsAPI *fedInternal.FederationInternalAPI,
|
||||||
keys gomatrixserverlib.JSONVerifier,
|
keys gomatrixserverlib.JSONVerifier,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation federationAPI.FederationClient,
|
||||||
userAPI userapi.FederationUserAPI,
|
userAPI userapi.FederationUserAPI,
|
||||||
keyAPI keyserverAPI.FederationKeyAPI,
|
keyAPI keyserverAPI.FederationKeyAPI,
|
||||||
mscCfg *config.MSCs,
|
mscCfg *config.MSCs,
|
||||||
|
|
|
@ -85,7 +85,7 @@ func Send(
|
||||||
rsAPI api.FederationRoomserverAPI,
|
rsAPI api.FederationRoomserverAPI,
|
||||||
keyAPI keyapi.FederationKeyAPI,
|
keyAPI keyapi.FederationKeyAPI,
|
||||||
keys gomatrixserverlib.JSONVerifier,
|
keys gomatrixserverlib.JSONVerifier,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation federationAPI.FederationClient,
|
||||||
mu *internal.MutexByRoom,
|
mu *internal.MutexByRoom,
|
||||||
servers federationAPI.ServersInRoomProvider,
|
servers federationAPI.ServersInRoomProvider,
|
||||||
producer *producers.SyncAPIProducer,
|
producer *producers.SyncAPIProducer,
|
||||||
|
|
|
@ -8,8 +8,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/test"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
@ -57,7 +58,7 @@ var (
|
||||||
func CreateInvitesFrom3PIDInvites(
|
func CreateInvitesFrom3PIDInvites(
|
||||||
req *http.Request, rsAPI api.FederationRoomserverAPI,
|
req *http.Request, rsAPI api.FederationRoomserverAPI,
|
||||||
cfg *config.FederationAPI,
|
cfg *config.FederationAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation federationAPI.FederationClient,
|
||||||
userAPI userapi.FederationUserAPI,
|
userAPI userapi.FederationUserAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var body invites
|
var body invites
|
||||||
|
@ -107,7 +108,7 @@ func ExchangeThirdPartyInvite(
|
||||||
roomID string,
|
roomID string,
|
||||||
rsAPI api.FederationRoomserverAPI,
|
rsAPI api.FederationRoomserverAPI,
|
||||||
cfg *config.FederationAPI,
|
cfg *config.FederationAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation federationAPI.FederationClient,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var builder gomatrixserverlib.EventBuilder
|
var builder gomatrixserverlib.EventBuilder
|
||||||
if err := json.Unmarshal(request.Content(), &builder); err != nil {
|
if err := json.Unmarshal(request.Content(), &builder); err != nil {
|
||||||
|
@ -165,7 +166,12 @@ func ExchangeThirdPartyInvite(
|
||||||
|
|
||||||
// Ask the requesting server to sign the newly created event so we know it
|
// Ask the requesting server to sign the newly created event so we know it
|
||||||
// acknowledged it
|
// acknowledged it
|
||||||
signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), event)
|
inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
|
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -205,7 +211,7 @@ func ExchangeThirdPartyInvite(
|
||||||
func createInviteFrom3PIDInvite(
|
func createInviteFrom3PIDInvite(
|
||||||
ctx context.Context, rsAPI api.FederationRoomserverAPI,
|
ctx context.Context, rsAPI api.FederationRoomserverAPI,
|
||||||
cfg *config.FederationAPI,
|
cfg *config.FederationAPI,
|
||||||
inv invite, federation *gomatrixserverlib.FederationClient,
|
inv invite, federation federationAPI.FederationClient,
|
||||||
userAPI userapi.FederationUserAPI,
|
userAPI userapi.FederationUserAPI,
|
||||||
) (*gomatrixserverlib.Event, error) {
|
) (*gomatrixserverlib.Event, error) {
|
||||||
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID}
|
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID}
|
||||||
|
@ -335,7 +341,7 @@ func buildMembershipEvent(
|
||||||
// them responded with an error.
|
// them responded with an error.
|
||||||
func sendToRemoteServer(
|
func sendToRemoteServer(
|
||||||
ctx context.Context, inv invite,
|
ctx context.Context, inv invite,
|
||||||
federation *gomatrixserverlib.FederationClient, _ *config.FederationAPI,
|
federation federationAPI.FederationClient, _ *config.FederationAPI,
|
||||||
builder gomatrixserverlib.EventBuilder,
|
builder gomatrixserverlib.EventBuilder,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
remoteServers := make([]gomatrixserverlib.ServerName, 2)
|
remoteServers := make([]gomatrixserverlib.ServerName, 2)
|
||||||
|
|
|
@ -25,13 +25,12 @@ import (
|
||||||
type Database interface {
|
type Database interface {
|
||||||
gomatrixserverlib.KeyDatabase
|
gomatrixserverlib.KeyDatabase
|
||||||
|
|
||||||
UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error)
|
UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error)
|
||||||
|
|
||||||
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||||
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
||||||
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
|
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
|
||||||
PurgeRoomState(ctx context.Context, roomID string) error
|
|
||||||
|
|
||||||
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const joinedHostsSchema = `
|
const joinedHostsSchema = `
|
||||||
|
@ -111,6 +112,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
|
||||||
roomID, eventID string,
|
roomID, eventID string,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
|
logrus.Debugf("FederationJoinedHosts: INSERT %v %v %v", roomID, eventID, serverName)
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
|
||||||
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
||||||
return err
|
return err
|
||||||
|
@ -119,6 +121,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
|
||||||
func (s *joinedHostsStatements) DeleteJoinedHosts(
|
func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) error {
|
) error {
|
||||||
|
logrus.Debugf("FederationJoinedHosts: DELETE WITH EVENTS %v", eventIDs)
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
|
||||||
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
|
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
|
||||||
return err
|
return err
|
||||||
|
@ -127,6 +130,7 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||||
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
|
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) error {
|
) error {
|
||||||
|
logrus.Debugf("FederationJoinedHosts: DELETE ALL IN ROOM %v", roomID)
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
|
||||||
_, err := stmt.ExecContext(ctx, roomID)
|
_, err := stmt.ExecContext(ctx, roomID)
|
||||||
return err
|
return err
|
||||||
|
@ -207,6 +211,7 @@ func joinedHostsFromStmt(
|
||||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
logrus.Debugf("FederationJoinedHosts: SELECT %v => %+v", roomID, result)
|
||||||
|
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,11 +63,21 @@ func (r *Receipt) String() string {
|
||||||
// this isn't a duplicate message.
|
// this isn't a duplicate message.
|
||||||
func (d *Database) UpdateRoom(
|
func (d *Database) UpdateRoom(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomID, oldEventID, newEventID string,
|
roomID string,
|
||||||
addHosts []types.JoinedHost,
|
addHosts []types.JoinedHost,
|
||||||
removeHosts []string,
|
removeHosts []string,
|
||||||
|
purgeRoomFirst bool,
|
||||||
) (joinedHosts []types.JoinedHost, err error) {
|
) (joinedHosts []types.JoinedHost, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
if purgeRoomFirst {
|
||||||
|
// If the event is a create event then we'll delete all of the existing
|
||||||
|
// data for the room. The only reason that a create event would be replayed
|
||||||
|
// to us in this way is if we're about to receive the entire room state.
|
||||||
|
if err = d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
|
||||||
|
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
|
joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -138,20 +148,6 @@ func (d *Database) StoreJSON(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) PurgeRoomState(
|
|
||||||
ctx context.Context, roomID string,
|
|
||||||
) error {
|
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
|
||||||
// If the event is a create event then we'll delete all of the existing
|
|
||||||
// data for the room. The only reason that a create event would be replayed
|
|
||||||
// to us in this way is if we're about to receive the entire room state.
|
|
||||||
if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
|
|
||||||
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
|
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
|
||||||
|
|
|
@ -20,7 +20,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEDUCache(t *testing.T) {
|
func TestEDUCache(t *testing.T) {
|
||||||
|
|
|
@ -1,158 +0,0 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Request contains the information necessary to issue a request and test its result
|
|
||||||
type Request struct {
|
|
||||||
Req *http.Request
|
|
||||||
WantedBody string
|
|
||||||
WantedStatusCode int
|
|
||||||
LastErr *LastRequestErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// LastRequestErr is a synchronised error wrapper
|
|
||||||
// Useful for obtaining the last error from a set of requests
|
|
||||||
type LastRequestErr struct {
|
|
||||||
sync.Mutex
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set sets the error
|
|
||||||
func (r *LastRequestErr) Set(err error) {
|
|
||||||
r.Lock()
|
|
||||||
defer r.Unlock()
|
|
||||||
r.Err = err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get gets the error
|
|
||||||
func (r *LastRequestErr) Get() error {
|
|
||||||
r.Lock()
|
|
||||||
defer r.Unlock()
|
|
||||||
return r.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CanonicalJSONInput canonicalises a slice of JSON strings
|
|
||||||
// Useful for test input
|
|
||||||
func CanonicalJSONInput(jsonData []string) []string {
|
|
||||||
for i := range jsonData {
|
|
||||||
jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i]))
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
jsonData[i] = string(jsonBytes)
|
|
||||||
}
|
|
||||||
return jsonData
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do issues a request and checks the status code and body of the response
|
|
||||||
func (r *Request) Do() (err error) {
|
|
||||||
client := &http.Client{
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
Transport: &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
res, err := client.Do(r.Req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer (func() { err = res.Body.Close() })()
|
|
||||||
|
|
||||||
if res.StatusCode != r.WantedStatusCode {
|
|
||||||
return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.WantedBody != "" {
|
|
||||||
resBytes, err := ioutil.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if string(jsonBytes) != r.WantedBody {
|
|
||||||
return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body.
|
|
||||||
// It then closes the given channel and returns.
|
|
||||||
func (r *Request) DoUntilSuccess(done chan error) {
|
|
||||||
r.LastErr = &LastRequestErr{}
|
|
||||||
for {
|
|
||||||
if err := r.Do(); err != nil {
|
|
||||||
r.LastErr.Set(err)
|
|
||||||
time.Sleep(1 * time.Second) // don't tightloop
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
close(done)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run repeatedly issues a request until success, error or a timeout is reached
|
|
||||||
func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) {
|
|
||||||
fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout)
|
|
||||||
done := make(chan error, 1)
|
|
||||||
|
|
||||||
// We need to wait for the server to:
|
|
||||||
// - have connected to the database
|
|
||||||
// - have created the tables
|
|
||||||
// - be listening on the given port
|
|
||||||
go r.DoUntilSuccess(done)
|
|
||||||
|
|
||||||
// wait for one of:
|
|
||||||
// - the test to pass (done channel is closed)
|
|
||||||
// - the server to exit with an error (error sent on serverCmdChan)
|
|
||||||
// - our test timeout to expire
|
|
||||||
// We don't need to clean up since the main() function handles that in the event we panic
|
|
||||||
select {
|
|
||||||
case <-time.After(timeout):
|
|
||||||
fmt.Printf("==TESTING== %v TIMEOUT\n", label)
|
|
||||||
if reqErr := r.LastErr.Get(); reqErr != nil {
|
|
||||||
fmt.Println("Last /sync request error:")
|
|
||||||
fmt.Println(reqErr)
|
|
||||||
}
|
|
||||||
panic(fmt.Sprintf("%v server timed out", label))
|
|
||||||
case err := <-serverCmdChan:
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("=============================================================================================")
|
|
||||||
fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label)
|
|
||||||
fmt.Println(" export PGHOST=/var/run/postgresql")
|
|
||||||
fmt.Println("=============================================================================================")
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
case <-done:
|
|
||||||
fmt.Printf("==TESTING== %v PASSED\n", label)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,76 +0,0 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KafkaExecutor executes kafka scripts.
|
|
||||||
type KafkaExecutor struct {
|
|
||||||
// The location of Zookeeper. Typically this is `localhost:2181`.
|
|
||||||
ZookeeperURI string
|
|
||||||
// The directory where Kafka is installed to. Used to locate kafka scripts.
|
|
||||||
KafkaDirectory string
|
|
||||||
// The location of the Kafka logs. Typically this is `localhost:9092`.
|
|
||||||
KafkaURI string
|
|
||||||
// Where stdout and stderr should be written to. Typically this is `os.Stderr`.
|
|
||||||
OutputWriter io.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTopic creates a new kafka topic. This is created with a single partition.
|
|
||||||
func (e *KafkaExecutor) CreateTopic(topic string) error {
|
|
||||||
cmd := exec.Command(
|
|
||||||
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
|
|
||||||
"--create",
|
|
||||||
"--zookeeper", e.ZookeeperURI,
|
|
||||||
"--replication-factor", "1",
|
|
||||||
"--partitions", "1",
|
|
||||||
"--topic", topic,
|
|
||||||
)
|
|
||||||
cmd.Stdout = e.OutputWriter
|
|
||||||
cmd.Stderr = e.OutputWriter
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteToTopic writes data to a kafka topic.
|
|
||||||
func (e *KafkaExecutor) WriteToTopic(topic string, data []string) error {
|
|
||||||
cmd := exec.Command(
|
|
||||||
filepath.Join(e.KafkaDirectory, "bin", "kafka-console-producer.sh"),
|
|
||||||
"--broker-list", e.KafkaURI,
|
|
||||||
"--topic", topic,
|
|
||||||
)
|
|
||||||
cmd.Stdout = e.OutputWriter
|
|
||||||
cmd.Stderr = e.OutputWriter
|
|
||||||
cmd.Stdin = strings.NewReader(strings.Join(data, "\n"))
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteTopic deletes a given kafka topic if it exists.
|
|
||||||
func (e *KafkaExecutor) DeleteTopic(topic string) error {
|
|
||||||
cmd := exec.Command(
|
|
||||||
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
|
|
||||||
"--delete",
|
|
||||||
"--if-exists",
|
|
||||||
"--zookeeper", e.ZookeeperURI,
|
|
||||||
"--topic", topic,
|
|
||||||
)
|
|
||||||
cmd.Stderr = e.OutputWriter
|
|
||||||
cmd.Stdout = e.OutputWriter
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
|
@ -1,152 +0,0 @@
|
||||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Defaulting allows assignment of string variables with a fallback default value
|
|
||||||
// Useful for use with os.Getenv() for example
|
|
||||||
func Defaulting(value, defaultValue string) string {
|
|
||||||
if value == "" {
|
|
||||||
value = defaultValue
|
|
||||||
}
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDatabase creates a new database, dropping it first if it exists
|
|
||||||
func CreateDatabase(command string, args []string, database string) error {
|
|
||||||
cmd := exec.Command(command, args...)
|
|
||||||
cmd.Stdin = strings.NewReader(
|
|
||||||
fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database),
|
|
||||||
)
|
|
||||||
// Send stdout and stderr to our stderr so that we see error messages from
|
|
||||||
// the psql process
|
|
||||||
cmd.Stdout = os.Stderr
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateBackgroundCommand creates an executable command
|
|
||||||
// The Cmd being executed is returned. A channel is also returned,
|
|
||||||
// which will have any termination errors sent down it, followed immediately by the channel being closed.
|
|
||||||
func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan error) {
|
|
||||||
cmd := exec.Command(command, args...)
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
cmd.Stdout = os.Stderr
|
|
||||||
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
panic("failed to start server: " + err.Error())
|
|
||||||
}
|
|
||||||
cmdChan := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
cmdChan <- cmd.Wait()
|
|
||||||
close(cmdChan)
|
|
||||||
}()
|
|
||||||
return cmd, cmdChan
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitDatabase creates the database and config file needed for the server to run
|
|
||||||
func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) {
|
|
||||||
if len(databases) > 0 {
|
|
||||||
var dbCmd string
|
|
||||||
var dbArgs []string
|
|
||||||
if postgresContainerName == "" {
|
|
||||||
dbCmd = "psql"
|
|
||||||
dbArgs = []string{postgresDatabase}
|
|
||||||
} else {
|
|
||||||
dbCmd = "docker"
|
|
||||||
dbArgs = []string{
|
|
||||||
"exec", "-i", postgresContainerName, "psql", "-U", "postgres", postgresDatabase,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, database := range databases {
|
|
||||||
if err := CreateDatabase(dbCmd, dbArgs, database); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartProxy creates a reverse proxy
|
|
||||||
func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) {
|
|
||||||
proxyArgs := []string{
|
|
||||||
"--bind-address", bindAddr,
|
|
||||||
"--sync-api-server-url", "http://" + string(cfg.SyncAPI.InternalAPI.Connect),
|
|
||||||
"--client-api-server-url", "http://" + string(cfg.ClientAPI.InternalAPI.Connect),
|
|
||||||
"--media-api-server-url", "http://" + string(cfg.MediaAPI.InternalAPI.Connect),
|
|
||||||
"--tls-cert", "server.crt",
|
|
||||||
"--tls-key", "server.key",
|
|
||||||
}
|
|
||||||
return CreateBackgroundCommand(
|
|
||||||
filepath.Join(filepath.Dir(os.Args[0]), "client-api-proxy"),
|
|
||||||
proxyArgs,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenAndServe will listen on a random high-numbered port and attach the given router.
|
|
||||||
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
|
|
||||||
func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) {
|
|
||||||
listener, err := net.Listen("tcp", ":0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to listen: %s", err)
|
|
||||||
}
|
|
||||||
port := listener.Addr().(*net.TCPAddr).Port
|
|
||||||
srv := http.Server{}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
srv.Handler = router
|
|
||||||
var err error
|
|
||||||
if useTLS {
|
|
||||||
certFile := filepath.Join(os.TempDir(), "dendrite.cert")
|
|
||||||
keyFile := filepath.Join(os.TempDir(), "dendrite.key")
|
|
||||||
err = NewTLSKey(keyFile, certFile)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("failed to generate tls key/cert: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = srv.ServeTLS(listener, certFile, keyFile)
|
|
||||||
} else {
|
|
||||||
err = srv.Serve(listener)
|
|
||||||
}
|
|
||||||
if err != nil && err != http.ErrServerClosed {
|
|
||||||
t.Logf("Listen failed: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
secure := ""
|
|
||||||
if useTLS {
|
|
||||||
secure = "s"
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("http%s://localhost:%d", secure, port), func() {
|
|
||||||
_ = srv.Shutdown(context.Background())
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -84,7 +84,7 @@ type DeviceListUpdater struct {
|
||||||
db DeviceListUpdaterDatabase
|
db DeviceListUpdaterDatabase
|
||||||
api DeviceListUpdaterAPI
|
api DeviceListUpdaterAPI
|
||||||
producer KeyChangeProducer
|
producer KeyChangeProducer
|
||||||
fedClient fedsenderapi.FederationClient
|
fedClient fedsenderapi.KeyserverFederationAPI
|
||||||
workerChans []chan gomatrixserverlib.ServerName
|
workerChans []chan gomatrixserverlib.ServerName
|
||||||
|
|
||||||
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
|
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
|
||||||
|
@ -127,7 +127,7 @@ type KeyChangeProducer interface {
|
||||||
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
||||||
func NewDeviceListUpdater(
|
func NewDeviceListUpdater(
|
||||||
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
||||||
fedClient fedsenderapi.FederationClient, numWorkers int,
|
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
|
||||||
) *DeviceListUpdater {
|
) *DeviceListUpdater {
|
||||||
return &DeviceListUpdater{
|
return &DeviceListUpdater{
|
||||||
userIDToMutex: make(map[string]*sync.Mutex),
|
userIDToMutex: make(map[string]*sync.Mutex),
|
||||||
|
|
|
@ -37,7 +37,7 @@ import (
|
||||||
type KeyInternalAPI struct {
|
type KeyInternalAPI struct {
|
||||||
DB storage.Database
|
DB storage.Database
|
||||||
ThisServer gomatrixserverlib.ServerName
|
ThisServer gomatrixserverlib.ServerName
|
||||||
FedClient fedsenderapi.FederationClient
|
FedClient fedsenderapi.KeyserverFederationAPI
|
||||||
UserAPI userapi.KeyserverUserAPI
|
UserAPI userapi.KeyserverUserAPI
|
||||||
Producer *producers.KeyChange
|
Producer *producers.KeyChange
|
||||||
Updater *DeviceListUpdater
|
Updater *DeviceListUpdater
|
||||||
|
|
|
@ -37,7 +37,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
|
||||||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||||
func NewInternalAPI(
|
func NewInternalAPI(
|
||||||
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
|
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
|
||||||
) api.KeyInternalAPI {
|
) api.KeyInternalAPI {
|
||||||
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||||
|
|
||||||
|
|
|
@ -183,6 +183,7 @@ type FederationRoomserverAPI interface {
|
||||||
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
|
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
|
||||||
// Query whether a server is allowed to see an event
|
// Query whether a server is allowed to see an event
|
||||||
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
|
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
|
||||||
|
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
||||||
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
||||||
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
|
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
|
||||||
// Query a given amount (or less) of events prior to a given set of events.
|
// Query a given amount (or less) of events prior to a given set of events.
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
)
|
)
|
||||||
|
@ -22,7 +22,7 @@ var jc *nats.Conn
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
var b *base.BaseDendrite
|
var b *base.BaseDendrite
|
||||||
b, js, jc = test.Base(nil)
|
b, js, jc = testrig.Base(nil)
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
b.ShutdownDendrite()
|
b.ShutdownDendrite()
|
||||||
b.WaitForComponentsToFinish()
|
b.WaitForComponentsToFinish()
|
||||||
|
|
|
@ -19,8 +19,8 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/test"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, fun
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_EventsTable(t *testing.T) {
|
func Test_EventsTable(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
|
|
@ -38,7 +38,7 @@ func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables
|
||||||
|
|
||||||
func TestPreviousEventsTable(t *testing.T) {
|
func TestPreviousEventsTable(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
tab, close := mustCreatePreviousEventsTable(t, dbType)
|
tab, close := mustCreatePreviousEventsTable(t, dbType)
|
||||||
|
|
|
@ -38,7 +38,7 @@ func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Publ
|
||||||
|
|
||||||
func TestPublishedTable(t *testing.T) {
|
func TestPublishedTable(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
tab, close := mustCreatePublishedTable(t, dbType)
|
tab, close := mustCreatePublishedTable(t, dbType)
|
||||||
|
|
|
@ -36,7 +36,7 @@ func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.Ro
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoomAliasesTable(t *testing.T) {
|
func TestRoomAliasesTable(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
room2 := test.NewRoom(t, alice)
|
room2 := test.NewRoom(t, alice)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
|
@ -38,7 +38,7 @@ func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, c
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoomsTable(t *testing.T) {
|
func TestRoomsTable(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
|
|
@ -47,7 +47,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
|
||||||
|
|
||||||
func TestWriteEvents(t *testing.T) {
|
func TestWriteEvents(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
r := test.NewRoom(t, alice)
|
r := test.NewRoom(t, alice)
|
||||||
db, close := MustCreateDatabase(t, dbType)
|
db, close := MustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
@ -60,7 +60,7 @@ func TestRecentEventsPDU(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := MustCreateDatabase(t, dbType)
|
db, close := MustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
// dummy room to make sure SQL queries are filtering on room ID
|
// dummy room to make sure SQL queries are filtering on room ID
|
||||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||||
|
|
||||||
|
@ -163,7 +163,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := MustCreateDatabase(t, dbType)
|
db, close := MustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
r := test.NewRoom(t, alice)
|
r := test.NewRoom(t, alice)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
|
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
|
||||||
|
|
|
@ -45,7 +45,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events,
|
||||||
|
|
||||||
func TestOutputRoomEventsTable(t *testing.T) {
|
func TestOutputRoomEventsTable(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
tab, db, close := newOutputRoomEventsTable(t, dbType)
|
tab, db, close := newOutputRoomEventsTable(t, dbType)
|
||||||
|
|
|
@ -40,7 +40,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D
|
||||||
|
|
||||||
func TestTopologyTable(t *testing.T) {
|
func TestTopologyTable(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
tab, db, close := newTopologyTable(t, dbType)
|
tab, db, close := newTopologyTable(t, dbType)
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
@ -86,7 +87,7 @@ func TestSyncAPIAccessTokens(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
||||||
user := test.NewUser()
|
user := test.NewUser(t)
|
||||||
room := test.NewRoom(t, user)
|
room := test.NewRoom(t, user)
|
||||||
alice := userapi.Device{
|
alice := userapi.Device{
|
||||||
ID: "ALICEID",
|
ID: "ALICEID",
|
||||||
|
@ -96,14 +97,14 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
||||||
AccountType: userapi.AccountTypeUser,
|
AccountType: userapi.AccountTypeUser,
|
||||||
}
|
}
|
||||||
|
|
||||||
base, close := test.CreateBaseDendrite(t, dbType)
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||||
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
||||||
msgs := toNATSMsgs(t, base, room.Events())
|
msgs := toNATSMsgs(t, base, room.Events())
|
||||||
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
|
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
|
||||||
test.MustPublishMsgs(t, jsctx, msgs...)
|
testrig.MustPublishMsgs(t, jsctx, msgs...)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -173,7 +174,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
||||||
user := test.NewUser()
|
user := test.NewUser(t)
|
||||||
room := test.NewRoom(t, user)
|
room := test.NewRoom(t, user)
|
||||||
alice := userapi.Device{
|
alice := userapi.Device{
|
||||||
ID: "ALICEID",
|
ID: "ALICEID",
|
||||||
|
@ -183,7 +184,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
||||||
AccountType: userapi.AccountTypeUser,
|
AccountType: userapi.AccountTypeUser,
|
||||||
}
|
}
|
||||||
|
|
||||||
base, close := test.CreateBaseDendrite(t, dbType)
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||||
|
@ -198,7 +199,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
||||||
sinceTokens := make([]string, len(msgs))
|
sinceTokens := make([]string, len(msgs))
|
||||||
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
|
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
|
||||||
for i, msg := range msgs {
|
for i, msg := range msgs {
|
||||||
test.MustPublishMsgs(t, jsctx, msg)
|
testrig.MustPublishMsgs(t, jsctx, msg)
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
||||||
|
@ -262,7 +263,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
|
||||||
if ev.StateKey() != nil {
|
if ev.StateKey() != nil {
|
||||||
addsStateIDs = append(addsStateIDs, ev.EventID())
|
addsStateIDs = append(addsStateIDs, ev.EventID())
|
||||||
}
|
}
|
||||||
result[i] = test.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{
|
result[i] = testrig.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{
|
||||||
Type: rsapi.OutputTypeNewRoomEvent,
|
Type: rsapi.OutputTypeNewRoomEvent,
|
||||||
NewRoomEvent: &rsapi.OutputNewRoomEvent{
|
NewRoomEvent: &rsapi.OutputNewRoomEvent{
|
||||||
Event: ev,
|
Event: ev,
|
||||||
|
|
|
@ -52,6 +52,24 @@ func WithUnsigned(unsigned interface{}) eventModifier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithKeyID(keyID gomatrixserverlib.KeyID) eventModifier {
|
||||||
|
return func(e *eventMods) {
|
||||||
|
e.keyID = keyID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier {
|
||||||
|
return func(e *eventMods) {
|
||||||
|
e.privKey = pkey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier {
|
||||||
|
return func(e *eventMods) {
|
||||||
|
e.origin = origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Reverse a list of events
|
// Reverse a list of events
|
||||||
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||||
|
|
47
test/http.go
47
test/http.go
|
@ -2,10 +2,15 @@ package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,3 +48,45 @@ func NewRequest(t *testing.T, method, path string, opts ...HTTPRequestOpt) *http
|
||||||
}
|
}
|
||||||
return req
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListenAndServe will listen on a random high-numbered port and attach the given router.
|
||||||
|
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
|
||||||
|
func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL string, cancel func()) {
|
||||||
|
listener, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to listen: %s", err)
|
||||||
|
}
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
srv := http.Server{}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
srv.Handler = router
|
||||||
|
var err error
|
||||||
|
if withTLS {
|
||||||
|
certFile := filepath.Join(t.TempDir(), "dendrite.cert")
|
||||||
|
keyFile := filepath.Join(t.TempDir(), "dendrite.key")
|
||||||
|
err = NewTLSKey(keyFile, certFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to make TLS key: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = srv.ServeTLS(listener, certFile, keyFile)
|
||||||
|
} else {
|
||||||
|
err = srv.Serve(listener)
|
||||||
|
}
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
t.Logf("Listen failed: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s := ""
|
||||||
|
if withTLS {
|
||||||
|
s = "s"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("http%s://localhost:%d", s, port), func() {
|
||||||
|
_ = srv.Shutdown(context.Background())
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -25,103 +25,19 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/big"
|
"math/big"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"gopkg.in/yaml.v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ConfigFile is the name of the config file for a server.
|
|
||||||
ConfigFile = "dendrite.yaml"
|
|
||||||
// ServerKeyFile is the name of the file holding the matrix server private key.
|
// ServerKeyFile is the name of the file holding the matrix server private key.
|
||||||
ServerKeyFile = "server_key.pem"
|
ServerKeyFile = "server_key.pem"
|
||||||
// TLSCertFile is the name of the file holding the TLS certificate used for federation.
|
// TLSCertFile is the name of the file holding the TLS certificate used for federation.
|
||||||
TLSCertFile = "tls_cert.pem"
|
TLSCertFile = "tls_cert.pem"
|
||||||
// TLSKeyFile is the name of the file holding the TLS key used for federation.
|
// TLSKeyFile is the name of the file holding the TLS key used for federation.
|
||||||
TLSKeyFile = "tls_key.pem"
|
TLSKeyFile = "tls_key.pem"
|
||||||
// MediaDir is the name of the directory used to store media.
|
|
||||||
MediaDir = "media"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MakeConfig makes a config suitable for running integration tests.
|
|
||||||
// Generates new matrix and TLS keys for the server.
|
|
||||||
func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*config.Dendrite, int, error) {
|
|
||||||
var cfg config.Dendrite
|
|
||||||
cfg.Defaults(true)
|
|
||||||
|
|
||||||
port := startPort
|
|
||||||
assignAddress := func() config.HTTPAddress {
|
|
||||||
result := config.HTTPAddress(fmt.Sprintf("http://%s:%d", host, port))
|
|
||||||
port++
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
serverKeyPath := filepath.Join(configDir, ServerKeyFile)
|
|
||||||
tlsCertPath := filepath.Join(configDir, TLSKeyFile)
|
|
||||||
tlsKeyPath := filepath.Join(configDir, TLSCertFile)
|
|
||||||
mediaBasePath := filepath.Join(configDir, MediaDir)
|
|
||||||
|
|
||||||
if err := NewMatrixKey(serverKeyPath); err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := NewTLSKey(tlsKeyPath, tlsCertPath); err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.Version = config.Version
|
|
||||||
|
|
||||||
cfg.Global.ServerName = gomatrixserverlib.ServerName(assignAddress())
|
|
||||||
cfg.Global.PrivateKeyPath = config.Path(serverKeyPath)
|
|
||||||
|
|
||||||
cfg.MediaAPI.BasePath = config.Path(mediaBasePath)
|
|
||||||
|
|
||||||
cfg.Global.JetStream.Addresses = []string{kafkaURI}
|
|
||||||
|
|
||||||
// TODO: Use different databases for the different schemas.
|
|
||||||
// Using the same database for every schema currently works because
|
|
||||||
// the table names are globally unique. But we might not want to
|
|
||||||
// rely on that in the future.
|
|
||||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(database)
|
|
||||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource(database)
|
|
||||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(database)
|
|
||||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(database)
|
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
|
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
|
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
|
|
||||||
|
|
||||||
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
|
|
||||||
cfg.FederationAPI.InternalAPI.Listen = assignAddress()
|
|
||||||
cfg.KeyServer.InternalAPI.Listen = assignAddress()
|
|
||||||
cfg.MediaAPI.InternalAPI.Listen = assignAddress()
|
|
||||||
cfg.RoomServer.InternalAPI.Listen = assignAddress()
|
|
||||||
cfg.SyncAPI.InternalAPI.Listen = assignAddress()
|
|
||||||
cfg.UserAPI.InternalAPI.Listen = assignAddress()
|
|
||||||
|
|
||||||
cfg.AppServiceAPI.InternalAPI.Connect = cfg.AppServiceAPI.InternalAPI.Listen
|
|
||||||
cfg.FederationAPI.InternalAPI.Connect = cfg.FederationAPI.InternalAPI.Listen
|
|
||||||
cfg.KeyServer.InternalAPI.Connect = cfg.KeyServer.InternalAPI.Listen
|
|
||||||
cfg.MediaAPI.InternalAPI.Connect = cfg.MediaAPI.InternalAPI.Listen
|
|
||||||
cfg.RoomServer.InternalAPI.Connect = cfg.RoomServer.InternalAPI.Listen
|
|
||||||
cfg.SyncAPI.InternalAPI.Connect = cfg.SyncAPI.InternalAPI.Listen
|
|
||||||
cfg.UserAPI.InternalAPI.Connect = cfg.UserAPI.InternalAPI.Listen
|
|
||||||
|
|
||||||
return &cfg, port, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteConfig writes the config file to the directory.
|
|
||||||
func WriteConfig(cfg *config.Dendrite, configDir string) error {
|
|
||||||
data, err := yaml.Marshal(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return ioutil.WriteFile(filepath.Join(configDir, ConfigFile), data, 0666)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMatrixKey generates a new ed25519 matrix server key and writes it to a file.
|
// NewMatrixKey generates a new ed25519 matrix server key and writes it to a file.
|
||||||
func NewMatrixKey(matrixKeyPath string) (err error) {
|
func NewMatrixKey(matrixKeyPath string) (err error) {
|
||||||
var data [35]byte
|
var data [35]byte
|
54
test/room.go
54
test/room.go
|
@ -15,7 +15,6 @@
|
||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ed25519"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -35,12 +34,6 @@ var (
|
||||||
PresetTrustedPrivateChat Preset = 3
|
PresetTrustedPrivateChat Preset = 3
|
||||||
|
|
||||||
roomIDCounter = int64(0)
|
roomIDCounter = int64(0)
|
||||||
|
|
||||||
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
|
|
||||||
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
|
|
||||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
|
||||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
|
|
||||||
})
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Room struct {
|
type Room struct {
|
||||||
|
@ -50,6 +43,7 @@ type Room struct {
|
||||||
creator *User
|
creator *User
|
||||||
|
|
||||||
authEvents gomatrixserverlib.AuthEvents
|
authEvents gomatrixserverlib.AuthEvents
|
||||||
|
currentState map[string]*gomatrixserverlib.HeaderedEvent
|
||||||
events []*gomatrixserverlib.HeaderedEvent
|
events []*gomatrixserverlib.HeaderedEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,14 +51,16 @@ type Room struct {
|
||||||
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
|
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
counter := atomic.AddInt64(&roomIDCounter, 1)
|
counter := atomic.AddInt64(&roomIDCounter, 1)
|
||||||
|
if creator.srvName == "" {
|
||||||
// set defaults then let roomModifiers override
|
t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator)
|
||||||
|
}
|
||||||
r := &Room{
|
r := &Room{
|
||||||
ID: fmt.Sprintf("!%d:localhost", counter),
|
ID: fmt.Sprintf("!%d:%s", counter, creator.srvName),
|
||||||
creator: creator,
|
creator: creator,
|
||||||
authEvents: gomatrixserverlib.NewAuthEvents(nil),
|
authEvents: gomatrixserverlib.NewAuthEvents(nil),
|
||||||
preset: PresetPublicChat,
|
preset: PresetPublicChat,
|
||||||
Version: gomatrixserverlib.RoomVersionV9,
|
Version: gomatrixserverlib.RoomVersionV9,
|
||||||
|
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
|
||||||
}
|
}
|
||||||
for _, m := range modifiers {
|
for _, m := range modifiers {
|
||||||
m(t, r)
|
m(t, r)
|
||||||
|
@ -73,6 +69,24 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference {
|
||||||
|
t.Helper()
|
||||||
|
a, err := needed.AuthEventReferences(&r.authEvents)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MustGetAuthEvents: %v", err)
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Room) ForwardExtremities() []string {
|
||||||
|
if len(r.events) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{
|
||||||
|
r.events[len(r.events)-1].EventID(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Room) insertCreateEvents(t *testing.T) {
|
func (r *Room) insertCreateEvents(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
var joinRule gomatrixserverlib.JoinRuleContent
|
var joinRule gomatrixserverlib.JoinRuleContent
|
||||||
|
@ -88,6 +102,7 @@ func (r *Room) insertCreateEvents(t *testing.T) {
|
||||||
joinRule.JoinRule = "public"
|
joinRule.JoinRule = "public"
|
||||||
hisVis.HistoryVisibility = "shared"
|
hisVis.HistoryVisibility = "shared"
|
||||||
}
|
}
|
||||||
|
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
|
||||||
"creator": r.creator.ID,
|
"creator": r.creator.ID,
|
||||||
"room_version": r.Version,
|
"room_version": r.Version,
|
||||||
|
@ -112,16 +127,16 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
|
||||||
}
|
}
|
||||||
|
|
||||||
if mod.privKey == nil {
|
if mod.privKey == nil {
|
||||||
mod.privKey = testPrivateKey
|
mod.privKey = creator.privKey
|
||||||
}
|
}
|
||||||
if mod.keyID == "" {
|
if mod.keyID == "" {
|
||||||
mod.keyID = testKeyID
|
mod.keyID = creator.keyID
|
||||||
}
|
}
|
||||||
if mod.originServerTS.IsZero() {
|
if mod.originServerTS.IsZero() {
|
||||||
mod.originServerTS = time.Now()
|
mod.originServerTS = time.Now()
|
||||||
}
|
}
|
||||||
if mod.origin == "" {
|
if mod.origin == "" {
|
||||||
mod.origin = gomatrixserverlib.ServerName("localhost")
|
mod.origin = creator.srvName
|
||||||
}
|
}
|
||||||
|
|
||||||
var unsigned gomatrixserverlib.RawJSON
|
var unsigned gomatrixserverlib.RawJSON
|
||||||
|
@ -174,13 +189,14 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
|
||||||
// Add a new event to this room DAG. Not thread-safe.
|
// Add a new event to this room DAG. Not thread-safe.
|
||||||
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
|
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
// Add the event to the list of auth events
|
// Add the event to the list of auth/state events
|
||||||
r.events = append(r.events, he)
|
r.events = append(r.events, he)
|
||||||
if he.StateKey() != nil {
|
if he.StateKey() != nil {
|
||||||
err := r.authEvents.AddEvent(he.Unwrap())
|
err := r.authEvents.AddEvent(he.Unwrap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
|
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
|
||||||
}
|
}
|
||||||
|
r.currentState[he.Type()+" "+*he.StateKey()] = he
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,6 +204,16 @@ func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
|
||||||
return r.events
|
return r.events
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent {
|
||||||
|
events := make([]*gomatrixserverlib.HeaderedEvent, len(r.currentState))
|
||||||
|
i := 0
|
||||||
|
for _, e := range r.currentState {
|
||||||
|
events[i] = e
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
|
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
he := r.CreateEvent(t, creator, eventType, content, mods...)
|
he := r.CreateEvent(t, creator, eventType, content, mods...)
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package test
|
package testrig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -24,22 +24,23 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()) {
|
func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) {
|
||||||
var cfg config.Dendrite
|
var cfg config.Dendrite
|
||||||
cfg.Defaults(false)
|
cfg.Defaults(false)
|
||||||
cfg.Global.JetStream.InMemory = true
|
cfg.Global.JetStream.InMemory = true
|
||||||
|
|
||||||
switch dbType {
|
switch dbType {
|
||||||
case DBTypePostgres:
|
case test.DBTypePostgres:
|
||||||
cfg.Global.Defaults(true) // autogen a signing key
|
cfg.Global.Defaults(true) // autogen a signing key
|
||||||
cfg.MediaAPI.Defaults(true) // autogen a media path
|
cfg.MediaAPI.Defaults(true) // autogen a media path
|
||||||
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
|
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
|
||||||
// the file system event with InMemory=true :(
|
// the file system event with InMemory=true :(
|
||||||
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
|
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
|
||||||
connStr, close := PrepareDBConnectionString(t, dbType)
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
cfg.Global.DatabaseOptions = config.DatabaseOptions{
|
cfg.Global.DatabaseOptions = config.DatabaseOptions{
|
||||||
ConnectionString: config.DataSource(connStr),
|
ConnectionString: config.DataSource(connStr),
|
||||||
MaxOpenConnections: 10,
|
MaxOpenConnections: 10,
|
||||||
|
@ -47,7 +48,7 @@ func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()
|
||||||
ConnMaxLifetimeSeconds: 60,
|
ConnMaxLifetimeSeconds: 60,
|
||||||
}
|
}
|
||||||
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close
|
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close
|
||||||
case DBTypeSQLite:
|
case test.DBTypeSQLite:
|
||||||
cfg.Defaults(true) // sets a sqlite db per component
|
cfg.Defaults(true) // sets a sqlite db per component
|
||||||
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
|
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
|
||||||
// the file system event with InMemory=true :(
|
// the file system event with InMemory=true :(
|
|
@ -1,4 +1,4 @@
|
||||||
package test
|
package testrig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
50
test/user.go
50
test/user.go
|
@ -15,22 +15,64 @@
|
||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
userIDCounter = int64(0)
|
userIDCounter = int64(0)
|
||||||
|
|
||||||
|
serverName = gomatrixserverlib.ServerName("test")
|
||||||
|
keyID = gomatrixserverlib.KeyID("ed25519:test")
|
||||||
|
privateKey = ed25519.NewKeyFromSeed([]byte{
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
|
||||||
|
})
|
||||||
|
|
||||||
|
// private keys that tests can use
|
||||||
|
PrivateKeyA = ed25519.NewKeyFromSeed([]byte{
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 77,
|
||||||
|
})
|
||||||
|
PrivateKeyB = ed25519.NewKeyFromSeed([]byte{
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 66,
|
||||||
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID string
|
ID string
|
||||||
|
// key ID and private key of the server who has this user, if known.
|
||||||
|
keyID gomatrixserverlib.KeyID
|
||||||
|
privKey ed25519.PrivateKey
|
||||||
|
srvName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUser() *User {
|
type UserOpt func(*User)
|
||||||
|
|
||||||
|
func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt {
|
||||||
|
return func(u *User) {
|
||||||
|
u.keyID = keyID
|
||||||
|
u.privKey = privKey
|
||||||
|
u.srvName = srvName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUser(t *testing.T, opts ...UserOpt) *User {
|
||||||
counter := atomic.AddInt64(&userIDCounter, 1)
|
counter := atomic.AddInt64(&userIDCounter, 1)
|
||||||
u := &User{
|
var u User
|
||||||
ID: fmt.Sprintf("@%d:localhost", counter),
|
for _, opt := range opts {
|
||||||
|
opt(&u)
|
||||||
}
|
}
|
||||||
return u
|
if u.keyID == "" || u.srvName == "" || u.privKey == nil {
|
||||||
|
t.Logf("NewUser: missing signing server credentials; using default.")
|
||||||
|
WithSigningServer(serverName, keyID, privateKey)(&u)
|
||||||
|
}
|
||||||
|
u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName)
|
||||||
|
t.Logf("NewUser: created user %s", u.ID)
|
||||||
|
return &u
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ func Test_AccountData(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ func Test_Accounts(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ func Test_Accounts(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Devices(t *testing.T) {
|
func Test_Devices(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
deviceID := util.RandomString(8)
|
deviceID := util.RandomString(8)
|
||||||
|
@ -212,7 +212,7 @@ func Test_Devices(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_KeyBackup(t *testing.T) {
|
func Test_KeyBackup(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
@ -291,7 +291,7 @@ func Test_KeyBackup(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_LoginToken(t *testing.T) {
|
func Test_LoginToken(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
@ -321,7 +321,7 @@ func Test_LoginToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_OpenID(t *testing.T) {
|
func Test_OpenID(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
token := util.RandomString(24)
|
token := util.RandomString(24)
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
@ -341,7 +341,7 @@ func Test_OpenID(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Profile(t *testing.T) {
|
func Test_Profile(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -379,7 +379,7 @@ func Test_Profile(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Pusher(t *testing.T) {
|
func Test_Pusher(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -430,7 +430,7 @@ func Test_Pusher(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_ThreePID(t *testing.T) {
|
func Test_ThreePID(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -467,7 +467,7 @@ func Test_ThreePID(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Notification(t *testing.T) {
|
func Test_Notification(t *testing.T) {
|
||||||
alice := test.NewUser()
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
|
|
|
@ -24,7 +24,6 @@ import (
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
internalTest "github.com/matrix-org/dendrite/internal/test"
|
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/userapi"
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||||
|
@ -135,7 +134,7 @@ func TestQueryProfile(t *testing.T) {
|
||||||
t.Run("HTTP API", func(t *testing.T) {
|
t.Run("HTTP API", func(t *testing.T) {
|
||||||
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||||
userapi.AddInternalRoutes(router, userAPI)
|
userapi.AddInternalRoutes(router, userAPI)
|
||||||
apiURL, cancel := internalTest.ListenAndServe(t, router, false)
|
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{})
|
httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in a new issue