Add Events Table tests

Move variable declaration outside loops
Switch to testify/assert for tests
This commit is contained in:
Till Faelligen 2022-05-09 09:04:51 +02:00
parent e2751781e7
commit a4a20945cc
9 changed files with 174 additions and 123 deletions

View file

@ -155,12 +155,12 @@ type eventStatements struct {
selectRoomNIDsForEventNIDsStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt
} }
func createEventsTable(db *sql.DB) error { func CreateEventsTable(db *sql.DB) error {
_, err := db.Exec(eventsSchema) _, err := db.Exec(eventsSchema)
return err return err
} }
func prepareEventsTable(db *sql.DB) (tables.Events, error) { func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
s := &eventStatements{} s := &eventStatements{}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
@ -380,15 +380,15 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed")
results := make([]types.StateAtEventAndReference, len(eventNIDs)) results := make([]types.StateAtEventAndReference, len(eventNIDs))
i := 0 i := 0
var (
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
stateSnapshotNID int64
eventID string
eventSHA256 []byte
)
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
var (
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
stateSnapshotNID int64
eventID string
eventSHA256 []byte
)
if err = rows.Scan( if err = rows.Scan(
&eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256,
); err != nil { ); err != nil {
@ -446,9 +446,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed")
results := make(map[types.EventNID]string, len(eventNIDs)) results := make(map[types.EventNID]string, len(eventNIDs))
i := 0 i := 0
var eventNID int64
var eventID string!!¹23456789!!"§$%"
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
var eventNID int64
var eventID string
if err = rows.Scan(&eventNID, &eventID); err != nil { if err = rows.Scan(&eventNID, &eventID); err != nil {
return nil, err return nil, err
} }
@ -491,9 +491,9 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed")
results := make(map[string]types.EventNID, len(eventIDs)) results := make(map[string]types.EventNID, len(eventIDs))
var eventID string
var eventNID int64
for rows.Next() { for rows.Next() {
var eventID string
var eventNID int64
if err = rows.Scan(&eventID, &eventNID); err != nil { if err = rows.Scan(&eventID, &eventNID); err != nil {
return nil, err return nil, err
} }
@ -522,9 +522,9 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs(
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
result := make(map[types.EventNID]types.RoomNID) result := make(map[types.EventNID]types.RoomNID)
var eventNID types.EventNID
var roomNID types.RoomNID
for rows.Next() { for rows.Next() {
var eventNID types.EventNID
var roomNID types.RoomNID
if err = rows.Scan(&eventNID, &roomNID); err != nil { if err = rows.Scan(&eventNID, &roomNID); err != nil {
return nil, err return nil, err
} }

View file

@ -77,7 +77,7 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventJSONTable(db); err != nil { if err := CreateEventJSONTable(db); err != nil {
return err return err
} }
if err := createEventsTable(db); err != nil { if err := CreateEventsTable(db); err != nil {
return err return err
} }
if err := createRoomsTable(db); err != nil { if err := createRoomsTable(db); err != nil {
@ -124,7 +124,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil { if err != nil {
return err return err
} }
events, err := prepareEventsTable(db) events, err := PrepareEventsTable(db)
if err != nil { if err != nil {
return err return err
} }

View file

@ -68,7 +68,8 @@ const bulkSelectStateEventByIDSQL = "" +
const bulkSelectStateEventByNIDSQL = "" + const bulkSelectStateEventByNIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
" WHERE event_nid IN ($1)" " WHERE event_nid IN ($1)"
// Rest of query is built by BulkSelectStateEventByNID
// Rest of query is built by BulkSelectStateEventByNID
const bulkSelectStateAtEventByIDSQL = "" + const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
@ -126,12 +127,12 @@ type eventStatements struct {
//selectRoomNIDsForEventNIDsStmt *sql.Stmt //selectRoomNIDsForEventNIDsStmt *sql.Stmt
} }
func createEventsTable(db *sql.DB) error { func CreateEventsTable(db *sql.DB) error {
_, err := db.Exec(eventsSchema) _, err := db.Exec(eventsSchema)
return err return err
} }
func prepareEventsTable(db *sql.DB) (tables.Events, error) { func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
s := &eventStatements{ s := &eventStatements{
db: db, db: db,
} }
@ -404,15 +405,15 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed")
results := make([]types.StateAtEventAndReference, len(eventNIDs)) results := make([]types.StateAtEventAndReference, len(eventNIDs))
i := 0 i := 0
var (
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
stateSnapshotNID int64
eventID string
eventSHA256 []byte
)
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
var (
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
stateSnapshotNID int64
eventID string
eventSHA256 []byte
)
if err = rows.Scan( if err = rows.Scan(
&eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256,
); err != nil { ); err != nil {
@ -491,9 +492,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed")
results := make(map[types.EventNID]string, len(eventNIDs)) results := make(map[types.EventNID]string, len(eventNIDs))
i := 0 i := 0
var eventNID int64
var eventID string
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
var eventNID int64
var eventID string
if err = rows.Scan(&eventNID, &eventID); err != nil { if err = rows.Scan(&eventNID, &eventID); err != nil {
return nil, err return nil, err
} }
@ -545,9 +546,9 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed")
results := make(map[string]types.EventNID, len(eventIDs)) results := make(map[string]types.EventNID, len(eventIDs))
var eventID string
var eventNID int64
for rows.Next() { for rows.Next() {
var eventID string
var eventNID int64
if err = rows.Scan(&eventID, &eventNID); err != nil { if err = rows.Scan(&eventID, &eventNID); err != nil {
return nil, err return nil, err
} }
@ -595,9 +596,9 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs(
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
result := make(map[types.EventNID]types.RoomNID) result := make(map[types.EventNID]types.RoomNID)
var eventNID types.EventNID
var roomNID types.RoomNID
for rows.Next() { for rows.Next() {
var eventNID types.EventNID
var roomNID types.RoomNID
if err = rows.Scan(&eventNID, &roomNID); err != nil { if err = rows.Scan(&eventNID, &roomNID); err != nil {
return nil, err return nil, err
} }

View file

@ -86,7 +86,7 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventJSONTable(db); err != nil { if err := CreateEventJSONTable(db); err != nil {
return err return err
} }
if err := createEventsTable(db); err != nil { if err := CreateEventsTable(db); err != nil {
return err return err
} }
if err := createRoomsTable(db); err != nil { if err := createRoomsTable(db); err != nil {
@ -133,7 +133,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil { if err != nil {
return err return err
} }
events, err := prepareEventsTable(db) events, err := PrepareEventsTable(db)
if err != nil { if err != nil {
return err return err
} }

View file

@ -3,7 +3,6 @@ package tables_test
import ( import (
"context" "context"
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -13,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
) )
func mustCreateEventJSONTable(t *testing.T, dbType test.DBType) (tables.EventJSON, func()) { func mustCreateEventJSONTable(t *testing.T, dbType test.DBType) (tables.EventJSON, func()) {
@ -21,27 +21,19 @@ func mustCreateEventJSONTable(t *testing.T, dbType test.DBType) (tables.EventJSO
db, err := sqlutil.Open(&config.DatabaseOptions{ db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter()) }, sqlutil.NewExclusiveWriter())
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to open db: %s", err)
}
var tab tables.EventJSON var tab tables.EventJSON
switch dbType { switch dbType {
case test.DBTypePostgres: case test.DBTypePostgres:
err = postgres.CreateEventJSONTable(db) err = postgres.CreateEventJSONTable(db)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create EventJSON table: %s", err)
}
tab, err = postgres.PrepareEventJSONTable(db) tab, err = postgres.PrepareEventJSONTable(db)
case test.DBTypeSQLite: case test.DBTypeSQLite:
err = sqlite3.CreateEventJSONTable(db) err = sqlite3.CreateEventJSONTable(db)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create EventJSON table: %s", err)
}
tab, err = sqlite3.PrepareEventJSONTable(db) tab, err = sqlite3.PrepareEventJSONTable(db)
} }
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create table: %s", err)
}
return tab, close return tab, close
} }
@ -52,29 +44,19 @@ func Test_EventJSONTable(t *testing.T) {
defer close() defer close()
// create some dummy data // create some dummy data
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
if err := tab.InsertEventJSON( err := tab.InsertEventJSON(
context.Background(), nil, types.EventNID(i), context.Background(), nil, types.EventNID(i),
[]byte(fmt.Sprintf(`{"value":%d"}`, i)), []byte(fmt.Sprintf(`{"value":%d"}`, i)),
); err != nil { )
t.Fatalf("unable to insert eventJSON: %s", err) assert.NoError(t, err)
}
} }
// select a subset of the data // select a subset of the data
values, err := tab.BulkSelectEventJSON(context.Background(), nil, []types.EventNID{1, 2, 3, 4, 5}) values, err := tab.BulkSelectEventJSON(context.Background(), nil, []types.EventNID{1, 2, 3, 4, 5})
if err != nil { assert.NoError(t, err)
t.Fatalf("unable to query eventJSON: %s", err) assert.Equal(t, 5, len(values))
}
if len(values) != 5 {
t.Fatalf("expected 5 events, got %d", len(values))
}
for i, v := range values { for i, v := range values {
if v.EventNID != types.EventNID(i+1) { assert.Equal(t, v.EventNID, types.EventNID(i+1))
t.Fatalf("expected eventNID %d, got %d", i+1, v.EventNID) assert.Equal(t, []byte(fmt.Sprintf(`{"value":%d"}`, i+1)), v.EventJSON)
}
wantValue := []byte(fmt.Sprintf(`{"value":%d"}`, i+1))
if !reflect.DeepEqual(wantValue, v.EventJSON) {
t.Fatalf("expected JSON to be %s, got %s", string(wantValue), string(v.EventJSON))
}
} }
}) })
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
) )
func mustCreateEventStateKeysTable(t *testing.T, dbType test.DBType) (tables.EventStateKeys, func()) { func mustCreateEventStateKeysTable(t *testing.T, dbType test.DBType) (tables.EventStateKeys, func()) {
@ -20,27 +21,19 @@ func mustCreateEventStateKeysTable(t *testing.T, dbType test.DBType) (tables.Eve
db, err := sqlutil.Open(&config.DatabaseOptions{ db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter()) }, sqlutil.NewExclusiveWriter())
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to open db: %s", err)
}
var tab tables.EventStateKeys var tab tables.EventStateKeys
switch dbType { switch dbType {
case test.DBTypePostgres: case test.DBTypePostgres:
err = postgres.CreateEventStateKeysTable(db) err = postgres.CreateEventStateKeysTable(db)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create EventJSON table: %s", err)
}
tab, err = postgres.PrepareEventStateKeysTable(db) tab, err = postgres.PrepareEventStateKeysTable(db)
case test.DBTypeSQLite: case test.DBTypeSQLite:
err = sqlite3.CreateEventStateKeysTable(db) err = sqlite3.CreateEventStateKeysTable(db)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create EventJSON table: %s", err)
}
tab, err = sqlite3.PrepareEventStateKeysTable(db) tab, err = sqlite3.PrepareEventStateKeysTable(db)
} }
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create table: %s", err)
}
return tab, close return tab, close
} }
@ -55,37 +48,26 @@ func Test_EventStateKeysTable(t *testing.T) {
// create some dummy data // create some dummy data
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
stateKey := fmt.Sprintf("@user%d:localhost", i) stateKey := fmt.Sprintf("@user%d:localhost", i)
if stateKeyNID, err = tab.InsertEventStateKeyNID( stateKeyNID, err = tab.InsertEventStateKeyNID(
ctx, nil, stateKey, ctx, nil, stateKey,
); err != nil { )
t.Fatalf("unable to insert eventJSON: %s", err) assert.NoError(t, err)
}
gotEventStateKey, err = tab.SelectEventStateKeyNID(ctx, nil, stateKey) gotEventStateKey, err = tab.SelectEventStateKeyNID(ctx, nil, stateKey)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to get eventStateKeyNID: %s", err) assert.Equal(t, stateKeyNID, gotEventStateKey)
}
if stateKeyNID != gotEventStateKey {
t.Fatalf("expected eventStateKey %d, but got %d", stateKeyNID, gotEventStateKey)
}
} }
stateKeyNIDsMap, err := tab.BulkSelectEventStateKeyNID(ctx, nil, []string{"@user0:localhost", "@user1:localhost"}) stateKeyNIDsMap, err := tab.BulkSelectEventStateKeyNID(ctx, nil, []string{"@user0:localhost", "@user1:localhost"})
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to get EventStateKeyNIDs: %s", err)
}
wantStateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDsMap)) wantStateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDsMap))
for _, nid := range stateKeyNIDsMap { for _, nid := range stateKeyNIDsMap {
wantStateKeyNIDs = append(wantStateKeyNIDs, nid) wantStateKeyNIDs = append(wantStateKeyNIDs, nid)
} }
stateKeyNIDs, err := tab.BulkSelectEventStateKey(ctx, nil, wantStateKeyNIDs) stateKeyNIDs, err := tab.BulkSelectEventStateKey(ctx, nil, wantStateKeyNIDs)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to get EventStateKeyNIDs: %s", err)
}
// verify that BulkSelectEventStateKeyNID and BulkSelectEventStateKey return the same values // verify that BulkSelectEventStateKeyNID and BulkSelectEventStateKey return the same values
for userID, nid := range stateKeyNIDsMap { for userID, nid := range stateKeyNIDsMap {
if v, ok := stateKeyNIDs[nid]; ok { if v, ok := stateKeyNIDs[nid]; ok {
if v != userID { assert.Equal(t, v, userID)
t.Fatalf("userID does not match: %s != %s", userID, v)
}
} else { } else {
t.Fatalf("unable to find %d in result set", nid) t.Fatalf("unable to find %d in result set", nid)
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
) )
func mustCreateEventTypesTable(t *testing.T, dbType test.DBType) (tables.EventTypes, func()) { func mustCreateEventTypesTable(t *testing.T, dbType test.DBType) (tables.EventTypes, func()) {
@ -20,27 +21,19 @@ func mustCreateEventTypesTable(t *testing.T, dbType test.DBType) (tables.EventTy
db, err := sqlutil.Open(&config.DatabaseOptions{ db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter()) }, sqlutil.NewExclusiveWriter())
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to open db: %s", err)
}
var tab tables.EventTypes var tab tables.EventTypes
switch dbType { switch dbType {
case test.DBTypePostgres: case test.DBTypePostgres:
err = postgres.CreateEventTypesTable(db) err = postgres.CreateEventTypesTable(db)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create EventJSON table: %s", err)
}
tab, err = postgres.PrepareEventTypesTable(db) tab, err = postgres.PrepareEventTypesTable(db)
case test.DBTypeSQLite: case test.DBTypeSQLite:
err = sqlite3.CreateEventTypesTable(db) err = sqlite3.CreateEventTypesTable(db)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create EventJSON table: %s", err)
}
tab, err = sqlite3.PrepareEventTypesTable(db) tab, err = sqlite3.PrepareEventTypesTable(db)
} }
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to create table: %s", err)
}
return tab, close return tab, close
} }
@ -63,23 +56,15 @@ func Test_EventTypesTable(t *testing.T) {
} }
eventTypeMap[eventType] = eventTypeNID eventTypeMap[eventType] = eventTypeNID
gotEventTypeNID, err = tab.SelectEventTypeNID(ctx, nil, eventType) gotEventTypeNID, err = tab.SelectEventTypeNID(ctx, nil, eventType)
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to get EventTypeNID: %s", err) assert.Equal(t, eventTypeNID, gotEventTypeNID)
}
if eventTypeNID != gotEventTypeNID {
t.Fatalf("expected eventTypeNID %d, but got %d", eventTypeNID, gotEventTypeNID)
}
} }
eventTypeNIDs, err := tab.BulkSelectEventTypeNID(ctx, nil, []string{"dummyEventType0", "dummyEventType3"}) eventTypeNIDs, err := tab.BulkSelectEventTypeNID(ctx, nil, []string{"dummyEventType0", "dummyEventType3"})
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to get EventStateKeyNIDs: %s", err)
}
// verify that BulkSelectEventTypeNID and InsertEventTypeNID return the same values // verify that BulkSelectEventTypeNID and InsertEventTypeNID return the same values
for eventType, nid := range eventTypeNIDs { for eventType, nid := range eventTypeNIDs {
if v, ok := eventTypeMap[eventType]; ok { if v, ok := eventTypeMap[eventType]; ok {
if v != nid { assert.Equal(t, v, nid)
t.Fatalf("EventTypeNID does not match: %d != %d", nid, v)
}
} else { } else {
t.Fatalf("unable to find %d in result set", nid) t.Fatalf("unable to find %d in result set", nid)
} }

View file

@ -0,0 +1,100 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
)
func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.Events
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateEventsTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareEventsTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateEventsTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareEventsTable(db)
}
assert.NoError(t, err)
return tab, close
}
func Test_EventsTable(t *testing.T) {
alice := test.NewUser()
room := test.NewRoom(t, alice)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateEventsTable(t, dbType)
defer close()
// create some dummy data
eventIDs := make([]string, 0, len(room.Events()))
wantStateAtEvent := make([]types.StateAtEvent, 0, len(room.Events()))
for _, ev := range room.Events() {
eventIDs = append(eventIDs, ev.EventID())
eventNID, snapNID, err := tab.InsertEvent(ctx, nil, 1, 1, 1, ev.EventID(), []byte(""), nil, 0, false)
assert.NoError(t, err)
gotEventNID, gotSnapNID, err := tab.SelectEvent(ctx, nil, ev.EventID())
assert.NoError(t, err)
assert.Equal(t, eventNID, gotEventNID)
assert.Equal(t, snapNID, gotSnapNID)
eventID, err := tab.SelectEventID(ctx, nil, eventNID)
assert.NoError(t, err)
assert.Equal(t, eventID, ev.EventID())
wantStateAtEvent = append(wantStateAtEvent, types.StateAtEvent{
Overwrite: false,
BeforeStateSnapshotNID: 0,
IsRejected: false,
StateEntry: types.StateEntry{
EventNID: eventNID,
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: 1,
EventStateKeyNID: 1,
},
},
})
}
stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs)
assert.NoError(t, err)
assert.Equal(t, len(stateEvents), len(eventIDs))
nids := make([]types.EventNID, 0, len(stateEvents))
for _, ev := range stateEvents {
nids = append(nids, ev.EventNID)
}
stateEvents2, err := tab.BulkSelectStateEventByNID(ctx, nil, nids, nil)
assert.NoError(t, err)
// somehow SQLite doesn't return the values ordered as requested by the query
assert.ElementsMatch(t, stateEvents, stateEvents2)
stateAtEvent, err := tab.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
assert.NoError(t, err)
assert.Equal(t, len(eventIDs), len(stateAtEvent))
assert.ElementsMatch(t, wantStateAtEvent, stateAtEvent)
evendNIDMap, err := tab.BulkSelectEventID(ctx, nil, nids)
assert.NoError(t, err)
t.Logf("%+v", evendNIDMap)
assert.Equal(t, len(evendNIDMap), len(nids))
})
}

View file

@ -35,7 +35,8 @@ type EventStateKeys interface {
type Events interface { type Events interface {
InsertEvent( InsertEvent(
ctx context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID,
eventStateKeyNID types.EventStateKeyNID, eventID string,
referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool,
) (types.EventNID, types.StateSnapshotNID, error) ) (types.EventNID, types.StateSnapshotNID, error)
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)