dendrite/syncapi/storage/sqlite3/stream_id_table.go
kegsay 6d25bd6ca5
syncapi: add more tests; fix more bugs (#2338)
* syncapi: add more tests; fix more bugs

bugfixes:
 - The postgres impl of TopologyTable.SelectEventIDsInRange did not use the provided txn
 - The postgres impl of EventsTable.SelectEvents did not preserve the ordering of the input event IDs in the output events slice
 - The sqlite impl of EventsTable.SelectEvents did not use a bulk `IN ($1)` query.

Added tests:
 - `TestGetEventsInRangeWithTopologyToken`
 - `TestOutputRoomEventsTable`
 - `TestTopologyTable`

* -p 1 for now
2022-04-08 17:53:24 +01:00

80 lines
2.6 KiB
Go

package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types"
)
// streamIDTableSchema creates the syncapi_stream_id table, which stores one
// integer counter per named stream, and seeds a zero-valued row for each of
// the global, receipt, accountdata, invite and presence streams.
//
// Each seed INSERT uses ON CONFLICT DO NOTHING, so applying the schema again
// leaves rows that already exist (and their current counters) untouched.
//
// String values are written with single quotes. SQLite parses double-quoted
// tokens as identifiers and only falls back to treating them as string
// literals as a legacy quirk that can be disabled at build time, so the
// standard single-quoted form is the portable one.
const streamIDTableSchema = `
-- Global stream ID counter, used by other tables.
CREATE TABLE IF NOT EXISTS syncapi_stream_id (
stream_name TEXT NOT NULL PRIMARY KEY,
stream_id INT DEFAULT 0,
UNIQUE(stream_name)
);
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ('global', 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ('receipt', 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ('accountdata', 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ('invite', 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ('presence', 0)
ON CONFLICT DO NOTHING;
`
// increaseStreamIDStmt adds one to the counter of the stream whose name is
// bound to $1 and returns the updated counter value.
const increaseStreamIDStmt = "UPDATE syncapi_stream_id SET stream_id = stream_id + 1" +
	" WHERE stream_name = $1" +
	" RETURNING stream_id"
// StreamIDStatements bundles the database handle and the prepared statement
// used to advance the counters stored in the syncapi_stream_id table.
// Both fields are nil until Prepare has been called.
type StreamIDStatements struct {
	// db is the handle that was handed to Prepare.
	db *sql.DB
	// increaseStreamIDStmt is the prepared counter-increment statement.
	increaseStreamIDStmt *sql.Stmt
}
// Prepare records db on s, applies the stream ID schema to it, and then
// prepares the counter-increment statement.
//
// It returns the first error encountered. If applying the schema fails, the
// increment statement is not prepared.
func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
	s.db = db
	if _, err = db.Exec(streamIDTableSchema); err != nil {
		return err
	}
	s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt)
	return err
}
// nextStreamID runs the counter-increment statement for streamName and returns
// the value it scans from the single result row.
//
// The statement is passed through sqlutil.TxStmt together with txn before it
// is executed (presumably so that it runs inside txn when one is supplied —
// confirm against sqlutil). If no row matches streamName, Scan reports
// sql.ErrNoRows and pos is left at its zero value.
func (s *StreamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx, streamName string) (pos types.StreamPosition, err error) {
	stmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
	err = stmt.QueryRowContext(ctx, streamName).Scan(&pos)
	return pos, err
}

// nextPDUID increments the "global" stream counter and returns its new value.
func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
	return s.nextStreamID(ctx, txn, "global")
}

// nextReceiptID increments the "receipt" stream counter and returns its new value.
func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
	return s.nextStreamID(ctx, txn, "receipt")
}

// nextInviteID increments the "invite" stream counter and returns its new value.
func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
	return s.nextStreamID(ctx, txn, "invite")
}

// nextAccountDataID increments the "accountdata" stream counter and returns its new value.
func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
	return s.nextStreamID(ctx, txn, "accountdata")
}

// nextPresenceID increments the "presence" stream counter and returns its new value.
func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
	return s.nextStreamID(ctx, txn, "presence")
}