Refactor notifications (#2688)

This PR changes the handling of notifications
- removes the `StreamEvent` and `ReadUpdate` stream
- listens on the `OutputRoomEvent` stream in the UserAPI to inform the
SyncAPI about unread notifications
- listens on the `OutputReceiptEvent` stream in the UserAPI to set
receipts/update notifications
- sets the `read_markers` directly from within the internal UserAPI

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
Till 2022-09-27 15:01:34 +02:00 committed by GitHub
parent f18bce93cc
commit 249b32c4f3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 368 additions and 598 deletions

View file

@ -9,9 +9,10 @@ import (
"time" "time"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/sirupsen/logrus"
natsserver "github.com/nats-io/nats-server/v2/server" natsserver "github.com/nats-io/nats-server/v2/server"
natsclient "github.com/nats-io/nats.go" natsclient "github.com/nats-io/nats.go"
@ -184,6 +185,8 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"},
OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"},
OutputRoomEvent: {"AppserviceRoomserverConsumer"}, OutputRoomEvent: {"AppserviceRoomserverConsumer"},
OutputStreamEvent: {"UserAPISyncAPIStreamEventConsumer"},
OutputReadUpdate: {"UserAPISyncAPIReadUpdateConsumer"},
} { } {
streamName := cfg.Matrix.JetStream.Prefixed(stream) streamName := cfg.Matrix.JetStream.Prefixed(stream)
for _, consumer := range consumers { for _, consumer := range consumers {

View file

@ -94,16 +94,6 @@ var streams = []*nats.StreamConfig{
Retention: nats.InterestPolicy, Retention: nats.InterestPolicy,
Storage: nats.FileStorage, Storage: nats.FileStorage,
}, },
{
Name: OutputStreamEvent,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputReadUpdate,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{ {
Name: OutputPresenceEvent, Name: OutputPresenceEvent,
Retention: nats.InterestPolicy, Retention: nats.InterestPolicy,

View file

@ -16,9 +16,7 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -31,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
@ -46,7 +43,6 @@ type OutputClientDataConsumer struct {
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
@ -57,7 +53,6 @@ func NewOutputClientDataConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
return &OutputClientDataConsumer{ return &OutputClientDataConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -68,7 +63,6 @@ func NewOutputClientDataConsumer(
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -113,15 +107,6 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return false return false
} }
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": userID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
if output.IgnoredUsers != nil { if output.IgnoredUsers != nil {
if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil { if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil {
log.WithError(err).WithFields(logrus.Fields{ log.WithError(err).WithFields(logrus.Fields{
@ -136,34 +121,3 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return true return true
} }
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
if output.Type != "m.fully_read" || output.ReadMarker == nil {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
var fullyReadPos types.StreamPosition
if output.ReadMarker.Read != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if output.ReadMarker.FullyRead != "" {
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
}
}
if readPos > 0 || fullyReadPos > 0 {
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -16,22 +16,19 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"fmt"
"strconv" "strconv"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
) )
// OutputReceiptEventConsumer consumes events that originated in the EDU server. // OutputReceiptEventConsumer consumes events that originated in the EDU server.
@ -44,7 +41,6 @@ type OutputReceiptEventConsumer struct {
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@ -56,7 +52,6 @@ func NewOutputReceiptEventConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{ return &OutputReceiptEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -67,7 +62,6 @@ func NewOutputReceiptEventConsumer(
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -111,42 +105,8 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
return true return true
} }
if err = s.sendReadUpdate(ctx, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": output.UserID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return true return true
} }
func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output types.OutputReceiptEvent) error {
if output.Type != "m.read" {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
if output.EventID != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if readPos > 0 {
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -21,17 +21,17 @@ import (
"fmt" "fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
@ -46,7 +46,6 @@ type OutputRoomEventConsumer struct {
pduStream types.StreamProvider pduStream types.StreamProvider
inviteStream types.StreamProvider inviteStream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
producer *producers.UserAPIStreamEventProducer
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -59,7 +58,6 @@ func NewOutputRoomEventConsumer(
pduStream types.StreamProvider, pduStream types.StreamProvider,
inviteStream types.StreamProvider, inviteStream types.StreamProvider,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
producer *producers.UserAPIStreamEventProducer,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -72,7 +70,6 @@ func NewOutputRoomEventConsumer(
pduStream: pduStream, pduStream: pduStream,
inviteStream: inviteStream, inviteStream: inviteStream,
rsAPI: rsAPI, rsAPI: rsAPI,
producer: producer,
} }
} }
@ -255,12 +252,6 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return nil return nil
} }
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
sentry.CaptureException(err)
return err
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -19,6 +19,9 @@ import (
"encoding/json" "encoding/json"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
@ -26,8 +29,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
// OutputNotificationDataConsumer consumes events that originated in // OutputNotificationDataConsumer consumes events that originated in

View file

@ -1,62 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIReadProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
m.Header.Set(jetstream.RoomID, roomID)
data := types.ReadUpdate{
UserID: userID,
RoomID: roomID,
Read: readPos,
FullyRead: fullyReadPos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": roomID,
"read_pos": readPos,
"fully_read_pos": fullyReadPos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -1,60 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIStreamEventProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.RoomID, roomID)
data := types.StreamedEvent{
Event: event,
StreamPosition: pos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"room_id": roomID,
"event_id": event.EventID(),
"event_type": event.Type(),
"stream_pos": pos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -29,6 +29,7 @@ import (
type Database interface { type Database interface {
Presence Presence
SharedUsers SharedUsers
Notifications
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
@ -148,12 +149,6 @@ type Database interface {
// GetRoomReceipts gets all receipts for a given roomID // GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
@ -179,3 +174,11 @@ type SharedUsers interface {
// SharedUsers returns a subset of otherUserIDs that share a room with userID. // SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
} }
type Notifications interface {
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
}

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -33,14 +35,14 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro
r := &notificationDataStatements{} r := &notificationDataStatements{}
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
} }
@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4 DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id` RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data FROM syncapi_notification_data
WHERE WHERE user_id = $1 AND
user_id = $1 AND room_id = ANY($2)`
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string var roomID string
var notificationCount, highlightCount int var notificationCount, highlightCount int
for rows.Next() {
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil { if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -1036,8 +1036,15 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI
return return
} }
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to) roomIDs := make([]string, 0, len(rooms))
for roomID, membership := range rooms {
if membership != gomatrixserverlib.Join {
continue
}
roomIDs = append(roomIDs, roomID)
}
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, nil, userID, roomIDs)
} }
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -32,19 +33,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t
} }
r := &notificationDataStatements{ r := &notificationDataStatements{
streamIDStatements: streamID, streamIDStatements: streamID,
db: db,
} }
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
// {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
//selectUserUnreadCountsForRooms *sql.Stmt
} }
const notificationDataSchema = ` const notificationDataSchema = `
@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
ON CONFLICT (user_id, room_id) ON CONFLICT (user_id, room_id)
DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data FROM syncapi_notification_data
WHERE WHERE user_id = $1 AND
user_id = $1 AND room_id IN ($2)`
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -81,20 +82,26 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
params := make([]interface{}, len(roomIDs)+1)
params[0] = userID
for i := range roomIDs {
params[i+1] = roomIDs[i]
}
sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($1)", sqlutil.QueryVariadic(len(params)), 1)
rows, err := r.db.QueryContext(ctx, sql, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string var roomID string
var notificationCount, highlightCount int var notificationCount, highlightCount int
for rows.Next() {
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil { if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -190,7 +190,7 @@ type Memberships interface {
type NotificationData interface { type NotificationData interface {
UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error)
SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
} }

View file

@ -3,9 +3,10 @@ package streams
import ( import (
"context" "context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type AccountDataStreamProvider struct { type AccountDataStreamProvider struct {

View file

@ -30,26 +30,29 @@ func (p *NotificationDataStreamProvider) CompleteSync(
func (p *NotificationDataStreamProvider) IncrementalSync( func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, _ types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// We want counts for all possible rooms, so always start from zero. // Get the unread notifications for rooms in our join response.
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) // This is to ensure clients always have an unread notification section
// and can display the correct numbers.
countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil { if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from return from
} }
// We're merely decorating existing rooms. Note that the Join map // We're merely decorating existing rooms.
// values are not pointers.
for roomID, jr := range req.Response.Rooms.Join { for roomID, jr := range req.Response.Rooms.Join {
counts := countsByRoom[roomID] counts := countsByRoom[roomID]
if counts == nil { if counts == nil {
continue continue
} }
jr.UnreadNotifications = &types.UnreadNotifications{
jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount HighlightCount: counts.UnreadHighlightCount,
jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount NotificationCount: counts.UnreadNotificationCount,
}
req.Response.Rooms.Join[roomID] = jr req.Response.Rooms.Join[roomID] = jr
} }
return to
return p.LatestPosition(ctx)
} }

View file

@ -77,16 +77,6 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start presence consumer") logrus.WithError(err).Panicf("failed to start presence consumer")
} }
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent),
}
userAPIReadUpdateProducer := &producers.UserAPIReadProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
}
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
js, rsAPI, syncDB, notifier, js, rsAPI, syncDB, notifier,
@ -98,7 +88,7 @@ func AddPublicRoutes(
roomConsumer := consumers.NewOutputRoomEventConsumer( roomConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider,
streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, streams.InviteStreamProvider, rsAPI,
) )
if err = roomConsumer.Start(); err != nil { if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer") logrus.WithError(err).Panicf("failed to start room server consumer")
@ -106,7 +96,6 @@ func AddPublicRoutes(
clientConsumer := consumers.NewOutputClientDataConsumer( clientConsumer := consumers.NewOutputClientDataConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider,
userAPIReadUpdateProducer,
) )
if err = clientConsumer.Start(); err != nil { if err = clientConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start client data consumer") logrus.WithError(err).Panicf("failed to start client data consumer")
@ -135,7 +124,6 @@ func AddPublicRoutes(
receiptConsumer := consumers.NewOutputReceiptEventConsumer( receiptConsumer := consumers.NewOutputReceiptEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider,
userAPIReadUpdateProducer,
) )
if err = receiptConsumer.Start(); err != nil { if err = receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start receipts consumer") logrus.WithError(err).Panicf("failed to start receipts consumer")

View file

@ -398,6 +398,11 @@ func (r *Response) IsEmpty() bool {
len(r.ToDevice.Events) == 0 len(r.ToDevice.Events) == 0
} }
type UnreadNotifications struct {
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
}
// JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key.
type JoinResponse struct { type JoinResponse struct {
Summary struct { Summary struct {
@ -419,10 +424,7 @@ type JoinResponse struct {
AccountData struct { AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data"` } `json:"account_data"`
UnreadNotifications struct { *UnreadNotifications `json:"unread_notifications,omitempty"`
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
} `json:"unread_notifications"`
} }
// NewJoinResponse creates an empty response with initialised arrays. // NewJoinResponse creates an empty response with initialised arrays.
@ -503,19 +505,6 @@ type Peek struct {
Deleted bool Deleted bool
} }
type ReadUpdate struct {
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
Read StreamPosition `json:"read,omitempty"`
FullyRead StreamPosition `json:"fully_read,omitempty"`
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamedEvent struct {
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
StreamPosition StreamPosition `json:"stream_position"`
}
// OutputReceiptEvent is an entry in the receipt output kafka log // OutputReceiptEvent is an entry in the receipt output kafka log
type OutputReceiptEvent struct { type OutputReceiptEvent struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`

View file

@ -0,0 +1,127 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/util"
)
// OutputReceiptEventConsumer consumes events that originated in the clientAPI.
type OutputReceiptEventConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext
durable string
topic string
db storage.Database
serverName gomatrixserverlib.ServerName
syncProducer *producers.SyncAPI
pgClient pushgateway.Client
}
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
// Call Start() to begin consuming from the EDU server.
func NewOutputReceiptEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
syncProducer *producers.SyncAPI,
pgClient pushgateway.Client,
) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{
ctx: process.Context(),
jetstream: js,
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
durable: cfg.Matrix.JetStream.Durable("UserAPIReceiptConsumer"),
db: store,
serverName: cfg.Matrix.ServerName,
syncProducer: syncProducer,
pgClient: pgClient,
}
}
// Start consuming receipts events.
func (s *OutputReceiptEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID)
roomID := msg.Header.Get(jetstream.RoomID)
readPos := msg.Header.Get(jetstream.EventID)
evType := msg.Header.Get("type")
if readPos == "" || evType != "m.read" {
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.serverName {
return true
}
metadata, err := msg.Metadata()
if err != nil {
return false
}
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if !updated {
return true
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
return true
}

View file

@ -26,7 +26,7 @@ import (
"github.com/matrix-org/dendrite/userapi/util" "github.com/matrix-org/dendrite/userapi/util"
) )
type OutputStreamEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
cfg *config.UserAPI cfg *config.UserAPI
rsAPI rsapi.UserRoomserverAPI rsAPI rsapi.UserRoomserverAPI
@ -38,7 +38,7 @@ type OutputStreamEventConsumer struct {
syncProducer *producers.SyncAPI syncProducer *producers.SyncAPI
} }
func NewOutputStreamEventConsumer( func NewOutputRoomEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.UserAPI, cfg *config.UserAPI,
js nats.JetStreamContext, js nats.JetStreamContext,
@ -46,21 +46,21 @@ func NewOutputStreamEventConsumer(
pgClient pushgateway.Client, pgClient pushgateway.Client,
rsAPI rsapi.UserRoomserverAPI, rsAPI rsapi.UserRoomserverAPI,
syncProducer *producers.SyncAPI, syncProducer *producers.SyncAPI,
) *OutputStreamEventConsumer { ) *OutputRoomEventConsumer {
return &OutputStreamEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
cfg: cfg, cfg: cfg,
jetstream: js, jetstream: js,
db: store, db: store,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), durable: cfg.Matrix.JetStream.Durable("UserAPIRoomServerConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
pgClient: pgClient, pgClient: pgClient,
rsAPI: rsAPI, rsAPI: rsAPI,
syncProducer: syncProducer, syncProducer: syncProducer,
} }
} }
func (s *OutputStreamEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer( if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1, s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
@ -70,35 +70,43 @@ func (s *OutputStreamEventConsumer) Start() error {
return nil return nil
} }
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.StreamedEvent var output rsapi.OutputEvent
output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("userapi consumer: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return true return true
} }
if output.Event.Event == nil { if output.Type != rsapi.OutputTypeNewRoomEvent {
return true
}
event := output.NewRoomEvent.Event
if event == nil {
log.Errorf("userapi consumer: expected event") log.Errorf("userapi consumer: expected event")
return true return true
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": output.Event.EventID(), "event_id": event.EventID(),
"event_type": output.Event.Type(), "event_type": event.Type(),
"stream_pos": output.StreamPosition, }).Tracef("Received message from roomserver: %#v", output)
}).Tracef("Received message from sync API: %#v", output)
if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { metadata, err := msg.Metadata()
if err != nil {
return true
}
if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": output.Event.EventID(), "event_id": event.EventID(),
}).WithError(err).Errorf("userapi consumer: process room event failure") }).WithError(err).Errorf("userapi consumer: process room event failure")
} }
return true return true
} }
func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil { if err != nil {
return fmt.Errorf("s.localRoomMembers: %w", err) return fmt.Errorf("s.localRoomMembers: %w", err)
@ -138,10 +146,10 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g
// removing it means we can send all notifications to // removing it means we can send all notifications to
// e.g. Element's Push gateway in one go. // e.g. Element's Push gateway in one go.
for _, mem := range members { for _, mem := range members {
if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": mem.Localpart, "localpart": mem.Localpart,
}).WithError(err).Debugf("Unable to push to local user") }).WithError(err).Error("Unable to push to local user")
continue continue
} }
} }
@ -179,7 +187,7 @@ func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership,
// localRoomMembers fetches the current local members of a room, and // localRoomMembers fetches the current local members of a room, and
// the total number of members. // the total number of members.
func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
req := &rsapi.QueryMembershipsForRoomRequest{ req := &rsapi.QueryMembershipsForRoomRequest{
RoomID: roomID, RoomID: roomID,
JoinedOnly: true, JoinedOnly: true,
@ -219,7 +227,7 @@ func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID
// looks it up in roomserver. If there is no name, // looks it up in roomserver. If there is no name,
// m.room.canonical_alias is consulted. Returns an empty string if the // m.room.canonical_alias is consulted. Returns an empty string if the
// room has no name. // room has no name.
func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
if event.Type() == gomatrixserverlib.MRoomName { if event.Type() == gomatrixserverlib.MRoomName {
name, err := unmarshalRoomName(event) name, err := unmarshalRoomName(event)
if err != nil { if err != nil {
@ -287,7 +295,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
} }
// notifyLocal finds the right push actions for a local user, given an event. // notifyLocal finds the right push actions for a local user, given an event.
func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil { if err != nil {
return err return err
@ -302,7 +310,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"event_id": event.EventID(), "event_id": event.EventID(),
"room_id": event.RoomID(), "room_id": event.RoomID(),
"localpart": mem.Localpart, "localpart": mem.Localpart,
}).Debugf("Push rule evaluation rejected the event") }).Tracef("Push rule evaluation rejected the event")
return nil return nil
} }
@ -325,7 +333,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
RoomID: event.RoomID(), RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()), TS: gomatrixserverlib.AsTimestamp(time.Now()),
} }
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
return err return err
} }
@ -345,7 +353,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"localpart": mem.Localpart, "localpart": mem.Localpart,
"num_urls": len(devicesByURLAndFormat), "num_urls": len(devicesByURLAndFormat),
"num_unread": userNumUnreadNotifs, "num_unread": userNumUnreadNotifs,
}).Debugf("Notifying single member") }).Trace("Notifying single member")
// Push gateways are out of our control, and we cannot risk // Push gateways are out of our control, and we cannot risk
// looking up the server on a misbehaving push gateway. Each user // looking up the server on a misbehaving push gateway. Each user
@ -396,7 +404,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
// evaluatePushRules fetches and evaluates the push rules of a local // evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify). // user. Returns actions (including dont_notify).
func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID { if event.Sender() == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for // SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves. // events that the user has sent themselves.
@ -447,7 +455,7 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event
"room_id": event.RoomID(), "room_id": event.RoomID(),
"localpart": mem.Localpart, "localpart": mem.Localpart,
"rule_id": rule.RuleID, "rule_id": rule.RuleID,
}).Tracef("Matched a push rule") }).Trace("Matched a push rule")
return rule.Actions, nil return rule.Actions, nil
} }
@ -491,7 +499,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local // localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format]. // user. The map keys are [url][format].
func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@ -515,7 +523,7 @@ func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localp
} }
// notifyHTTP performs a notificatation to a Push Gateway. // notifyHTTP performs a notificatation to a Push Gateway.
func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
"url": url, "url": url,
@ -561,13 +569,13 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
} }
} }
logger.Debugf("Notifying push gateway %s", url) logger.Tracef("Notifying push gateway %s", url)
var res pushgateway.NotifyResponse var res pushgateway.NotifyResponse
if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
logger.WithError(err).Errorf("Failed to notify push gateway %s", url) logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
return nil, err return nil, err
} }
logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") logger.WithField("num_rejected", len(res.Rejected)).Trace("Push gateway result")
if len(res.Rejected) == 0 { if len(res.Rejected) == 0 {
return nil, nil return nil, nil
@ -589,7 +597,7 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
} }
// deleteRejectedPushers deletes the pushers associated with the given devices. // deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": localpart, "localpart": localpart,
"app_id0": devices[0].AppID, "app_id0": devices[0].AppID,

View file

@ -40,7 +40,7 @@ func Test_evaluatePushRules(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
consumer := OutputStreamEventConsumer{db: db} consumer := OutputRoomEventConsumer{db: db}
testCases := []struct { testCases := []struct {
name string name string

View file

@ -1,137 +0,0 @@
package consumers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type OutputReadUpdateConsumer struct {
ctx context.Context
cfg *config.UserAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
pgClient pushgateway.Client
ServerName gomatrixserverlib.ServerName
topic string
userAPI uapi.UserInternalAPI
syncProducer *producers.SyncAPI
}
func NewOutputReadUpdateConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
pgClient pushgateway.Client,
userAPI uapi.UserInternalAPI,
syncProducer *producers.SyncAPI,
) *OutputReadUpdateConsumer {
return &OutputReadUpdateConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
ServerName: cfg.Matrix.ServerName,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
pgClient: pgClient,
userAPI: userAPI,
syncProducer: syncProducer,
}
}
func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
return true
}
if read.FullyRead == 0 && read.Read == 0 {
return true
}
userID := string(msg.Header.Get(jetstream.UserID))
roomID := string(msg.Header.Get(jetstream.RoomID))
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.ServerName {
log.Error("userapi clientapi consumer: not a local user")
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
log.Tracef("Received read update from sync API: %#v", read)
if read.Read > 0 {
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if updated {
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
}
}
if read.FullyRead > 0 {
deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
if err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
return false
}
if deleted {
if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
return false
}
if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
return false
}
}
}
return true
}

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
@ -39,6 +40,7 @@ import (
"github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
userapiUtil "github.com/matrix-org/dendrite/userapi/util"
) )
type UserInternalAPI struct { type UserInternalAPI struct {
@ -51,6 +53,7 @@ type UserInternalAPI struct {
AppServices []config.ApplicationService AppServices []config.ApplicationService
KeyAPI keyapi.UserKeyAPI KeyAPI keyapi.UserKeyAPI
RSAPI rsapi.UserRoomserverAPI RSAPI rsapi.UserRoomserverAPI
PgClient pushgateway.Client
} }
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
@ -73,6 +76,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
ignoredUsers = &synctypes.IgnoredUsers{} ignoredUsers = &synctypes.IgnoredUsers{}
_ = json.Unmarshal(req.AccountData, ignoredUsers) _ = json.Unmarshal(req.AccountData, ignoredUsers)
} }
if req.DataType == "m.fully_read" {
if err := a.setFullyRead(ctx, req); err != nil {
return err
}
}
if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{
RoomID: req.RoomID, RoomID: req.RoomID,
Type: req.DataType, Type: req.DataType,
@ -84,6 +92,44 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
return nil return nil
} }
func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error {
var output eventutil.ReadMarkerJSON
if err := json.Unmarshal(req.AccountData, &output); err != nil {
return err
}
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure")
return nil
}
if domain != a.ServerName {
return nil
}
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err
}
if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed")
return err
}
// nothing changed, no need to notify the push gateway
if !deleted {
return nil
}
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err
}
return nil
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil { if err != nil {

View file

@ -4,12 +4,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
) )
type JetStreamPublisher interface { type JetStreamPublisher interface {

View file

@ -119,9 +119,9 @@ type ThreePID interface {
} }
type Notification interface { type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -20,12 +20,13 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
type notificationsStatements struct { type notificationsStatements struct {
@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err() return notifs, maxID, rows.Err()
} }
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
return
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
} }
return count, nil func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
} err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
return 0, rows.Err() return
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
} }

View file

@ -19,11 +19,12 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers(
pushers = append(pushers, pusher) pushers = append(pushers, pusher)
} }
logrus.Debugf("Database returned %d pushers", len(pushers)) logrus.Tracef("Database returned %d pushers", len(pushers))
return pushers, rows.Err() return pushers, rows.Err()
} }

View file

@ -700,13 +700,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token) return d.LoginTokens.SelectLoginToken(ctx, token)
} }
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
}) })
} }
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
return err return err
@ -714,7 +714,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomI
return return
} }
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
return err return err

View file

@ -20,12 +20,13 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
) )
type notificationsStatements struct { type notificationsStatements struct {
@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err() return notifs, maxID, rows.Err()
} }
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
return
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
} }
return count, nil func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
} err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
return 0, rows.Err() return
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
} }

View file

@ -19,11 +19,12 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers(
pushers = append(pushers, pusher) pushers = append(pushers, pusher)
} }
logrus.Debugf("Database returned %d pushers", len(pushers)) logrus.Tracef("Database returned %d pushers", len(pushers))
return pushers, rows.Err() return pushers, rows.Err()
} }

View file

@ -7,6 +7,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -14,10 +19,6 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
) )
const loginTokenLifetime = time.Minute const loginTokenLifetime = time.Minute
@ -513,7 +514,7 @@ func Test_Notification(t *testing.T) {
RoomID: roomID, RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts), TS: gomatrixserverlib.AsTimestamp(ts),
} }
err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification) err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
assert.NoError(t, err, "unable to insert notification") assert.NoError(t, err, "unable to insert notification")
} }

View file

@ -105,9 +105,9 @@ type PusherTable interface {
type NotificationTable interface { type NotificationTable interface {
Clean(ctx context.Context, txn *sql.Tx) error Clean(ctx context.Context, txn *sql.Tx) error
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -81,16 +81,17 @@ func NewInternalAPI(
KeyAPI: keyAPI, KeyAPI: keyAPI,
RSAPI: rsAPI, RSAPI: rsAPI,
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
PgClient: pgClient,
} }
readConsumer := consumers.NewOutputReadUpdateConsumer( receiptConsumer := consumers.NewOutputReceiptEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, base.ProcessContext, cfg, js, db, syncProducer, pgClient,
) )
if err := readConsumer.Start(); err != nil { if err := receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start user API read update consumer") logrus.WithError(err).Panic("failed to start user API receipt consumer")
} }
eventConsumer := consumers.NewOutputStreamEventConsumer( eventConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer, base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer,
) )
if err := eventConsumer.Start(); err != nil { if err := eventConsumer.Start(); err != nil {