Change cryptoid references from pseudoids

This commit is contained in:
Devon Hudson 2023-11-17 17:34:01 -07:00
parent 3cbccb9ed7
commit b45e72830e
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
27 changed files with 219 additions and 214 deletions

View file

@ -100,7 +100,7 @@ type queryKeysRequest struct {
type uploadKeysCryptoIDsRequest struct {
DeviceKeys json.RawMessage `json:"device_keys"`
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
OneTimePseudoIDs map[string]json.RawMessage `json:"one_time_pseudoids"`
OneTimeCryptoIDs map[string]json.RawMessage `json:"one_time_cryptoids"`
}
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
@ -132,11 +132,11 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
},
}
}
if r.OneTimePseudoIDs != nil {
uploadReq.OneTimePseudoIDs = []api.OneTimePseudoIDs{
if r.OneTimeCryptoIDs != nil {
uploadReq.OneTimeCryptoIDs = []api.OneTimeCryptoIDs{
{
UserID: device.UserID,
KeyJSON: r.OneTimePseudoIDs,
KeyJSON: r.OneTimeCryptoIDs,
},
}
}
@ -144,7 +144,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
util.GetLogger(req.Context()).
WithField("device keys", r.DeviceKeys).
WithField("one-time keys", r.OneTimeKeys).
WithField("one-time pseudoids", r.OneTimePseudoIDs).
WithField("one-time cryptoids", r.OneTimeCryptoIDs).
Info("Uploading keys")
var uploadRes api.PerformUploadKeysResponse
@ -170,16 +170,16 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
if len(uploadRes.OneTimeKeyCounts) > 0 {
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
}
pseudoIDCount := make(map[string]int)
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
cryptoIDCount := make(map[string]int)
if len(uploadRes.OneTimeCryptoIDCounts) > 0 {
cryptoIDCount = uploadRes.OneTimeCryptoIDCounts[0].KeyCount
}
return util.JSONResponse{
Code: 200,
JSON: struct {
OTKCounts interface{} `json:"one_time_key_counts"`
OTPIDCounts interface{} `json:"one_time_pseudoid_counts"`
}{keyCount, pseudoIDCount},
OTKCounts interface{} `json:"one_time_key_counts"`
OTIDCounts interface{} `json:"one_time_cryptoid_counts"`
}{keyCount, cryptoIDCount},
}
}

View file

@ -320,7 +320,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/send_pdus/{txnID}",
httputil.MakeAuthAPI("send_pdus", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/sendPDUs")
logrus.Info("Processing request to /org.matrix.msc4080/send_pdus")
if r := rateLimits.Limit(req, device); r != nil {
return *r
}

View file

@ -100,7 +100,7 @@ func SendEvent(
}
// Translate user ID state keys to room keys in pseudo ID rooms
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil {
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
parsedRoomID, innerErr := spec.NewRoomID(roomID)
if innerErr != nil {
return util.JSONResponse{
@ -154,7 +154,7 @@ func SendEvent(
}
// for power level events we need to replace the userID with the pseudoID
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels {
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
err = updatePowerLevels(req, r, roomID, rsAPI)
if err != nil {
return util.JSONResponse{
@ -299,7 +299,7 @@ func SendEventCryptoIDs(
}
// Translate user ID state keys to room keys in pseudo ID rooms
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil {
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
parsedRoomID, innerErr := spec.NewRoomID(roomID)
if innerErr != nil {
return util.JSONResponse{
@ -345,7 +345,7 @@ func SendEventCryptoIDs(
}
// for power level events we need to replace the userID with the pseudoID
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels {
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
err = updatePowerLevels(req, r, roomID, rsAPI)
if err != nil {
return util.JSONResponse{

View file

@ -214,7 +214,7 @@ func OnIncomingStateTypeRequest(
}
// Translate user ID state keys to room keys in pseudo ID rooms
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs {
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs || roomVer == gomatrixserverlib.RoomVersionCryptoIDs {
parsedRoomID, err := spec.NewRoomID(roomID)
if err != nil {
return util.JSONResponse{

View file

@ -314,7 +314,7 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
}
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.usAPI.ClaimOneTimePseudoID(ctx, roomID, userID)
return r.usAPI.ClaimOneTimeCryptoID(ctx, roomID, userID)
}
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
@ -328,7 +328,7 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
roomVersion = roomInfo.RoomVersion
}
}
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
privKey, err := r.GetOrCreateUserRoomPrivateKey(ctx, senderID, roomID)
if err != nil {
return fclient.SigningIdentity{}, err

View file

@ -445,7 +445,7 @@ func (r *Inputer) processRoomEvent(
}
// TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one.
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember {
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && event.Type() == spec.MRoomMember {
mapping := gomatrixserverlib.MemberContent{}
if err = json.Unmarshal(event.Content(), &mapping); err != nil {
return err

View file

@ -69,7 +69,7 @@ func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.Us
return nil, spec.InternalServerError{Err: err.Error()}
}
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
util.GetLogger(ctx).Infof("StoreUserRoomPublicKey - SenderID: %s UserID: %s RoomID: %s", senderID, userID.String(), roomID.String())
bytes := spec.Base64Bytes{}
err = bytes.Decode(string(senderID))
@ -152,7 +152,7 @@ func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.Us
}
// If we are creating a room with pseudo IDs, create and sign the MXIDMapping
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
mapping := &gomatrixserverlib.MXIDMapping{
UserRoomKey: senderID,
UserID: userID.String(),

View file

@ -577,7 +577,7 @@ func (r *Joiner) performJoinRoomByIDCryptoIDs(
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
if err == nil && info != nil {
switch info.RoomVersion {
case gomatrixserverlib.RoomVersionPseudoIDs:
case gomatrixserverlib.RoomVersionCryptoIDs:
senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
if queryErr == nil {
checkInvitePending = true
@ -664,7 +664,7 @@ func (r *Joiner) performJoinRoomByIDCryptoIDs(
identity := r.Cfg.Matrix.SigningIdentity
// at this point we know we have an existing room
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
mapping := &gomatrixserverlib.MXIDMapping{
UserRoomKey: senderID,
UserID: userID.String(),

View file

@ -1044,7 +1044,7 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID,
}
switch version {
case gomatrixserverlib.RoomVersionPseudoIDs:
case gomatrixserverlib.RoomVersionPseudoIDs, gomatrixserverlib.RoomVersionCryptoIDs:
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
if err != nil {
return nil, err

View file

@ -17,7 +17,7 @@ type RoomServer struct {
func (c *RoomServer) Defaults(opts DefaultOpts) {
//c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionPseudoIDs
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionCryptoIDs
if opts.Generate {
if !opts.SingleDatabase {
c.Database.ConnectionString = "file:roomserver.db"

View file

@ -46,13 +46,13 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
return nil
}
// OTPseudoIDCounts adds one-time pseudoID counts to the /sync response
func OTPseudoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
count, err := keyAPI.QueryOneTimePseudoIDs(ctx, userID)
// OTCryptoIDCounts adds one-time pseudoID counts to the /sync response
func OTCryptoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
count, err := keyAPI.QueryOneTimeCryptoIDs(ctx, userID)
if err != nil {
return err
}
res.OTPseudoIDsCount = count.KeyCount
res.OTCryptoIDsCount = count.KeyCount
return nil
}

View file

@ -51,8 +51,8 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
return nil
}
func (a *mockKeyAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
return userapi.OneTimePseudoIDsCount{}, nil
func (a *mockKeyAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
return userapi.OneTimeCryptoIDsCount{}, nil
}
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
return nil

View file

@ -41,7 +41,7 @@ func (p *DeviceListStreamProvider) IncrementalSync(
req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
return from
}
err = internal.OTPseudoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
err = internal.OTCryptoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
if err != nil {
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
return from

View file

@ -280,7 +280,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
}
err = internal.OTPseudoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
err = internal.OTCryptoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
}

View file

@ -112,8 +112,8 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
return nil
}
func (a *syncUserAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
return userapi.OneTimePseudoIDsCount{}, nil
func (a *syncUserAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
return userapi.OneTimeCryptoIDsCount{}, nil
}
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {

View file

@ -153,7 +153,7 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDFor
// TODO: Set Signatures & Hashes fields
}
if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
if format != FormatSyncFederation && (se.Version() == gomatrixserverlib.RoomVersionPseudoIDs || se.Version() == gomatrixserverlib.RoomVersionCryptoIDs) {
err := updatePseudoIDs(&ce, se, userIDForSender, format)
if err != nil {
return nil, err
@ -304,7 +304,7 @@ func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomS
return nil, err
}
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation {
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != FormatSyncFederation {
for i, ev := range inviteStateEvents {
userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID))
if userIDErr != nil {

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/api"
@ -365,7 +366,7 @@ type Response struct {
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"`
OTCryptoIDsCount map[string]int `json:"one_time_cryptoids_count,omitempty"`
}
func (r Response) MarshalJSON() ([]byte, error) {
@ -428,7 +429,7 @@ func NewResponse() *Response {
res.DeviceLists = &DeviceLists{}
res.ToDevice = &ToDeviceResponse{}
res.DeviceListsOTKCount = map[string]int{}
res.OTPseudoIDsCount = map[string]int{}
res.OTCryptoIDsCount = map[string]int{}
return &res
}
@ -532,7 +533,7 @@ type InviteResponse struct {
InviteState struct {
Events []json.RawMessage `json:"events"`
} `json:"invite_state"`
OneTimePseudoID string `json:"one_time_pseudoid,omitempty"`
OneTimeCryptoID string `json:"one_time_cryptoid,omitempty"`
}
// NewInviteResponse creates an empty response with initialised arrays.
@ -540,13 +541,17 @@ func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *t
res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{}
res.OneTimePseudoID = *event.PDU.StateKey()
logrus.Infof("Room version: %s", event.Version())
if event.Version() == gomatrixserverlib.RoomVersionCryptoIDs {
logrus.Infof("Setting invite cryptoID to %s", *event.PDU.StateKey())
res.OneTimeCryptoID = *event.PDU.StateKey()
}
// First see if there's invite_room_state in the unsigned key of the invite.
// If there is then unmarshal it into the response. This will contain the
// partial room state such as join rules, room name etc.
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation {
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != synctypes.FormatSyncFederation {
updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, inviteRoomState, event.PDU, event.RoomID(), eventFormat)

View file

@ -51,7 +51,7 @@ type AppserviceUserAPI interface {
type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
}
// api functions required by the media api
@ -670,7 +670,7 @@ type UploadDeviceKeysAPI interface {
type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
QueryOneTimePseudoIDs(ctx context.Context, userID string) (OneTimePseudoIDsCount, *KeyError)
QueryOneTimeCryptoIDs(ctx context.Context, userID string) (OneTimeCryptoIDsCount, *KeyError)
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
}
@ -774,7 +774,7 @@ type OneTimeKeys struct {
KeyJSON map[string]json.RawMessage
}
type OneTimePseudoIDs struct {
type OneTimeCryptoIDs struct {
// The user who owns this device
UserID string
// A map of algorithm:key_id => key JSON
@ -788,7 +788,7 @@ func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
}
// Split a key in KeyJSON into algorithm and key ID
func (k *OneTimePseudoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
func (k *OneTimeCryptoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1]
}
@ -807,7 +807,7 @@ type OneTimeKeysCount struct {
KeyCount map[string]int
}
type OneTimePseudoIDsCount struct {
type OneTimeCryptoIDsCount struct {
// The user who owns this device
UserID string
// algorithm to count e.g:
@ -823,7 +823,7 @@ type PerformUploadKeysRequest struct {
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
OneTimePseudoIDs []OneTimePseudoIDs
OneTimeCryptoIDs []OneTimeCryptoIDs
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
@ -838,7 +838,7 @@ type PerformUploadKeysResponse struct {
// A map of user_id -> device_id -> Error for tracking failures.
KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount
OneTimePseudoIDCounts []OneTimePseudoIDsCount
OneTimeCryptoIDCounts []OneTimeCryptoIDsCount
}
// PerformDeleteKeysRequest asks the keyserver to forget about certain

View file

@ -57,21 +57,21 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
}
if len(req.OneTimePseudoIDs) > 0 {
a.uploadOneTimePseudoIDs(ctx, req, res)
if len(req.OneTimeCryptoIDs) > 0 {
a.uploadOneTimeCryptoIDs(ctx, req, res)
}
logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts)
logrus.Infof("One time cryptoIDs count before: %v", res.OneTimeCryptoIDCounts)
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
}
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
otpIDs, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
otpIDs, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
if err != nil {
return err
}
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts)
res.OneTimeCryptoIDCounts = []api.OneTimeCryptoIDsCount{*otpIDs}
logrus.Infof("One time cryptoIDs count after: %v", res.OneTimeCryptoIDCounts)
return nil
}
@ -193,11 +193,11 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
return nil
}
func (a *UserInternalAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (api.OneTimePseudoIDsCount, *api.KeyError) {
count, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, userID)
func (a *UserInternalAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (api.OneTimeCryptoIDsCount, *api.KeyError) {
count, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, userID)
if err != nil {
return api.OneTimePseudoIDsCount{}, &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
return api.OneTimeCryptoIDsCount{}, &api.KeyError{
Err: fmt.Sprintf("Failed to query OTID counts: %s", err),
}
}
return *count, nil
@ -796,26 +796,26 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
}
func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
func (a *UserInternalAPI) uploadOneTimeCryptoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
}
}
if len(req.OneTimePseudoIDs) == 0 {
counts, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
if len(req.OneTimeCryptoIDs) == 0 {
counts, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.KeyDatabase.OneTimePseudoIDsCount: %s", err),
Err: fmt.Sprintf("a.KeyDatabase.OneTimeCryptoIDsCount: %s", err),
}
}
if counts != nil {
logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts)
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
logrus.Infof("Uploading one-time cryptoIDs: early result count: %v", *counts)
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
}
return
}
for _, key := range req.OneTimePseudoIDs {
for _, key := range req.OneTimeCryptoIDs {
// grab existing keys based on (user/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
i := 0
@ -823,10 +823,10 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
existingKeys, err := a.KeyDatabase.ExistingOneTimePseudoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
existingKeys, err := a.KeyDatabase.ExistingOneTimeCryptoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time pseudoIDs: " + err.Error(),
Err: "failed to query existing one-time cryptoIDs: " + err.Error(),
})
continue
}
@ -834,22 +834,22 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
// if keys exist and the JSON doesn't match, error out as the key already exists
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time pseudoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time cryptoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
})
continue
}
}
// store one-time keys
counts, err := a.KeyDatabase.StoreOneTimePseudoIDs(ctx, key)
counts, err := a.KeyDatabase.StoreOneTimeCryptoIDs(ctx, key)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time pseudoIDs: %s", req.UserID, req.DeviceID, err.Error()),
Err: fmt.Sprintf("%s device %s : failed to store one-time cryptoIDs: %s", req.UserID, req.DeviceID, err.Error()),
})
continue
}
// collect counts
logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts)
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
logrus.Infof("Uploading one-time cryptoIDs: result count: %v", *counts)
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
}
}
@ -857,16 +857,16 @@ type Ed25519Key struct {
Key spec.Base64Bytes `json:"key"`
}
func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
pseudoID, err := a.KeyDatabase.ClaimOneTimePseudoID(ctx, userID, "ed25519")
func (a *UserInternalAPI) ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
cryptoID, err := a.KeyDatabase.ClaimOneTimeCryptoID(ctx, userID, "ed25519")
if err != nil {
return "", err
}
logrus.Infof("Claimed one time pseuodID: %s", pseudoID)
logrus.Infof("Claimed one time cryptoID: %s", cryptoID)
if pseudoID != nil {
for key, value := range pseudoID.KeyJSON {
if cryptoID != nil {
for key, value := range cryptoID.KeyJSON {
keyParts := strings.Split(key, ":")
if keyParts[0] == "ed25519" {
var key_bytes Ed25519Key
@ -885,7 +885,7 @@ func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.
}
}
return "", fmt.Errorf("failed claiming a valid one time pseudoID for this user: %s", userID.String())
return "", fmt.Errorf("failed claiming a valid one time cryptoID for this user: %s", userID.String())
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {

View file

@ -175,10 +175,10 @@ type KeyDatabase interface {
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error)
ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -27,78 +27,78 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimePseudoIDsSchema = `
-- Stores one-time pseudoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
var oneTimeCryptoIDsSchema = `
-- Stores one-time cryptoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
user_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- Clobber based on 3-uple of user/key/algorithm.
CONSTRAINT keyserver_one_time_pseudoids_unique UNIQUE (user_id, key_id, algorithm)
CONSTRAINT keyserver_one_time_cryptoids_unique UNIQUE (user_id, key_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
`
const upsertPseudoIDsSQL = "" +
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
const upsertCryptoIDsSQL = "" +
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_one_time_pseudoids_unique" +
" ON CONFLICT ON CONSTRAINT keyserver_one_time_cryptoids_unique" +
" DO UPDATE SET key_json = $5"
const selectOneTimePseudoIDsSQL = "" +
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
const selectOneTimeCryptoIDsSQL = "" +
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
const selectPseudoIDsCountSQL = "" +
const selectCryptoIDsCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
" x GROUP BY algorithm"
const deleteOneTimePseudoIDSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const deleteOneTimeCryptoIDSQL = "" +
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const selectPseudoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const selectCryptoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const deleteOneTimePseudoIDsSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
const deleteOneTimeCryptoIDsSQL = "" +
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
type oneTimePseudoIDsStatements struct {
type oneTimeCryptoIDsStatements struct {
db *sql.DB
upsertPseudoIDsStmt *sql.Stmt
selectPseudoIDsStmt *sql.Stmt
selectPseudoIDsCountStmt *sql.Stmt
selectPseudoIDByAlgorithmStmt *sql.Stmt
deleteOneTimePseudoIDStmt *sql.Stmt
deleteOneTimePseudoIDsStmt *sql.Stmt
upsertCryptoIDsStmt *sql.Stmt
selectCryptoIDsStmt *sql.Stmt
selectCryptoIDsCountStmt *sql.Stmt
selectCryptoIDByAlgorithmStmt *sql.Stmt
deleteOneTimeCryptoIDStmt *sql.Stmt
deleteOneTimeCryptoIDsStmt *sql.Stmt
}
func NewPostgresOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
s := &oneTimePseudoIDsStatements{
func NewPostgresOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
s := &oneTimeCryptoIDsStatements{
db: db,
}
_, err := db.Exec(oneTimePseudoIDsSchema)
_, err := db.Exec(oneTimeCryptoIDsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
}.Prepare(db)
}
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
result := make(map[string]json.RawMessage)
var (
@ -114,16 +114,16 @@ func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context,
return result, rows.Err()
}
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
counts := &api.OneTimePseudoIDsCount{
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
counts := &api.OneTimeCryptoIDsCount{
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
@ -135,26 +135,26 @@ func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context,
return counts, nil
}
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) {
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) {
now := time.Now().Unix()
counts := &api.OneTimePseudoIDsCount{
counts := &api.OneTimeCryptoIDsCount{
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
@ -167,25 +167,25 @@ func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context,
return counts, rows.Err()
}
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
ctx context.Context, txn *sql.Tx, userID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
return err
}

View file

@ -149,7 +149,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
otpid, err := NewPostgresOneTimePseudoIDsTable(db)
otpid, err := NewPostgresOneTimeCryptoIDsTable(db)
if err != nil {
return nil, err
}
@ -176,7 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
OneTimePseudoIDsTable: otpid,
OneTimeCryptoIDsTable: otpid,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,

View file

@ -65,7 +65,7 @@ type Database struct {
type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys
OneTimePseudoIDsTable tables.OneTimePseudoIDs
OneTimeCryptoIDsTable tables.OneTimeCryptoIDs
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
@ -946,31 +946,31 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *KeyDatabase) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
return d.OneTimePseudoIDsTable.SelectOneTimePseudoIDs(ctx, userID, keyIDsWithAlgorithms)
func (d *KeyDatabase) ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
return d.OneTimeCryptoIDsTable.SelectOneTimeCryptoIDs(ctx, userID, keyIDsWithAlgorithms)
}
func (d *KeyDatabase) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (counts *api.OneTimePseudoIDsCount, err error) {
func (d *KeyDatabase) StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (counts *api.OneTimeCryptoIDsCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimePseudoIDsTable.InsertOneTimePseudoIDs(ctx, txn, keys)
counts, err = d.OneTimeCryptoIDsTable.InsertOneTimeCryptoIDs(ctx, txn, keys)
return err
})
return
}
func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
func (d *KeyDatabase) OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
return d.OneTimeCryptoIDsTable.CountOneTimeCryptoIDs(ctx, userID)
}
func (d *KeyDatabase) ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) {
var result *api.OneTimePseudoIDs
func (d *KeyDatabase) ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error) {
var result *api.OneTimeCryptoIDs
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
keyJSON, err := d.OneTimePseudoIDsTable.SelectAndDeleteOneTimePseudoID(ctx, txn, userID.String(), algorithm)
keyJSON, err := d.OneTimeCryptoIDsTable.SelectAndDeleteOneTimeCryptoID(ctx, txn, userID.String(), algorithm)
if err != nil {
return err
}
if keyJSON != nil {
result = &api.OneTimePseudoIDs{
result = &api.OneTimeCryptoIDs{
UserID: userID.String(),
KeyJSON: keyJSON,
}

View file

@ -27,9 +27,9 @@ import (
"github.com/sirupsen/logrus"
)
var oneTimePseudoIDsSchema = `
-- Stores one-time pseudoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
var oneTimeCryptoIDsSchema = `
-- Stores one-time cryptoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
user_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
@ -39,66 +39,66 @@ CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
UNIQUE (user_id, key_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
`
const upsertPseudoIDsSQL = "" +
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
const upsertCryptoIDsSQL = "" +
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, key_id, algorithm)" +
" DO UPDATE SET key_json = $5"
const selectOneTimePseudoIDsSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1"
const selectOneTimeCryptoIDsSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1"
const selectPseudoIDsCountSQL = "" +
const selectCryptoIDsCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
" x GROUP BY algorithm"
const deleteOneTimePseudoIDSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const deleteOneTimeCryptoIDSQL = "" +
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const selectPseudoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const selectCryptoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const deleteOneTimePseudoIDsSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
const deleteOneTimeCryptoIDsSQL = "" +
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
type oneTimePseudoIDsStatements struct {
type oneTimeCryptoIDsStatements struct {
db *sql.DB
upsertPseudoIDsStmt *sql.Stmt
selectPseudoIDsStmt *sql.Stmt
selectPseudoIDsCountStmt *sql.Stmt
selectPseudoIDByAlgorithmStmt *sql.Stmt
deleteOneTimePseudoIDStmt *sql.Stmt
deleteOneTimePseudoIDsStmt *sql.Stmt
upsertCryptoIDsStmt *sql.Stmt
selectCryptoIDsStmt *sql.Stmt
selectCryptoIDsCountStmt *sql.Stmt
selectCryptoIDByAlgorithmStmt *sql.Stmt
deleteOneTimeCryptoIDStmt *sql.Stmt
deleteOneTimeCryptoIDsStmt *sql.Stmt
}
func NewSqliteOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
s := &oneTimePseudoIDsStatements{
func NewSqliteOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
s := &oneTimeCryptoIDsStatements{
db: db,
}
_, err := db.Exec(oneTimePseudoIDsSchema)
_, err := db.Exec(oneTimeCryptoIDsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
}.Prepare(db)
}
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID)
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
for _, ka := range keyIDsWithAlgorithms {
@ -121,16 +121,16 @@ func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context,
return result, rows.Err()
}
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
counts := &api.OneTimePseudoIDsCount{
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
counts := &api.OneTimeCryptoIDsCount{
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
@ -142,28 +142,28 @@ func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context,
return counts, nil
}
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(
ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs,
) (*api.OneTimePseudoIDsCount, error) {
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(
ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs,
) (*api.OneTimeCryptoIDsCount, error) {
now := time.Now().Unix()
counts := &api.OneTimePseudoIDsCount{
counts := &api.OneTimeCryptoIDsCount{
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
@ -176,25 +176,25 @@ func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(
return counts, rows.Err()
}
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
ctx context.Context, txn *sql.Tx, userID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
logrus.Warnf("No rows found for one time pseudoIDs")
logrus.Warnf("No rows found for one time cryptoIDs")
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
if err != nil {
return nil, err
}
if keyJSON == "" {
logrus.Warnf("Empty key JSON for one time pseudoIDs")
logrus.Warnf("Empty key JSON for one time cryptoIDs")
return nil, nil
}
return map[string]json.RawMessage{
@ -202,7 +202,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
}, err
}
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
return err
}

View file

@ -146,7 +146,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
otpid, err := NewSqliteOneTimePseudoIDsTable(db)
otpid, err := NewSqliteOneTimeCryptoIDsTable(db)
if err != nil {
return nil, err
}
@ -173,7 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
OneTimePseudoIDsTable: otpid,
OneTimeCryptoIDsTable: otpid,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,

View file

@ -760,29 +760,29 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
})
}
func TestOneTimePseudoIDs(t *testing.T) {
func TestOneTimeCryptoIDs(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
userID := "@alice:localhost"
otk := api.OneTimePseudoIDs{
otk := api.OneTimeCryptoIDs{
UserID: userID,
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
}
// Add a one time pseudoID to the DB
_, err := db.StoreOneTimePseudoIDs(ctx, otk)
_, err := db.StoreOneTimeCryptoIDs(ctx, otk)
MustNotError(t, err)
// Check the count of one time pseudoIDs is correct
count, err := db.OneTimePseudoIDsCount(ctx, userID)
count, err := db.OneTimeCryptoIDsCount(ctx, userID)
MustNotError(t, err)
if count.KeyCount["pseudoid_curve25519"] != 1 {
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
}
// Check the actual pseudoid contents are correct
keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
keysJSON, err := db.ExistingOneTimeCryptoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
MustNotError(t, err)
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
MustNotError(t, err)

View file

@ -168,12 +168,12 @@ type OneTimeKeys interface {
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
}
type OneTimePseudoIDs interface {
SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
SelectAndDeleteOneTimePseudoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error
type OneTimeCryptoIDs interface {
SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
SelectAndDeleteOneTimeCryptoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error
}
type DeviceKeys interface {