diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 08e912a00..9bf438a23 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -103,6 +103,7 @@ func (d *Database) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) + eventStateKeys = util.UniqueStrings(eventStateKeys) nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) if err != nil { return nil, err @@ -112,15 +113,23 @@ func (d *Database) eventStateKeyNIDs( } // We received some nids, but are still missing some, work out which and create them if len(eventStateKeys) > len(result) { - for _, eventStateKey := range eventStateKeys { - if _, ok := result[eventStateKey]; ok { - continue + var nid types.EventStateKeyNID + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + for _, eventStateKey := range eventStateKeys { + if _, ok := result[eventStateKey]; ok { + continue + } + + nid, err = d.assignStateKeyNID(ctx, txn, eventStateKey) + if err != nil { + return err + } + result[eventStateKey] = nid } - nid, err := d.assignStateKeyNID(ctx, txn, eventStateKey) - if err != nil { - return result, err - } - result[eventStateKey] = nid + return nil + }) + if err != nil { + return nil, err } } return result, nil