2020-07-15 11:02:34 +00:00
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
2021-06-07 08:17:46 +00:00
"github.com/lib/pq"
2020-07-15 11:02:34 +00:00
"github.com/matrix-org/dendrite/internal"
2020-09-24 10:10:14 +00:00
"github.com/matrix-org/dendrite/internal/sqlutil"
2020-07-15 11:02:34 +00:00
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
// oneTimeKeysSchema is the DDL that NewPostgresOneTimeKeysTable executes
// before preparing any statement. It creates the keyserver_one_time_keys
// table when it does not exist yet. The unique constraint over
// (user_id, device_id, key_id, algorithm) is the one upsertKeysSQL names in
// its ON CONFLICT clause.
var oneTimeKeysSchema = `
-- Stores one - time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
user_id TEXT NOT NULL ,
device_id TEXT NOT NULL ,
key_id TEXT NOT NULL ,
algorithm TEXT NOT NULL ,
ts_added_secs BIGINT NOT NULL ,
key_json TEXT NOT NULL ,
-- Clobber based on 4 - uple of user / device / key / algorithm .
CONSTRAINT keyserver_one_time_keys_unique UNIQUE ( user_id , device_id , key_id , algorithm )
) ;
`
// upsertKeysSQL inserts one key row. When a row with the same
// (user_id, device_id, key_id, algorithm) already exists, only its key_json
// is replaced; the stored ts_added_secs is left untouched.
const upsertKeysSQL = "INSERT INTO keyserver_one_time_keys" +
	" (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
	" VALUES ($1, $2, $3, $4, $5, $6)" +
	" ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique DO UPDATE SET key_json = $6"
// selectKeysSQL fetches the JSON of every key of one device whose
// "algorithm:key_id" identifier is contained in the array passed as $3. The
// same identifier is returned as the first column.
const selectKeysSQL = "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json" +
	" FROM keyserver_one_time_keys" +
	" WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
2020-07-15 11:02:34 +00:00
// selectKeysCountSQL counts the stored keys of one device, grouped by
// algorithm.
const selectKeysCountSQL = "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys" +
	" WHERE user_id=$1 AND device_id=$2" +
	" GROUP BY algorithm"
2020-07-21 13:47:53 +00:00
// deleteOneTimeKeySQL removes a single key, identified by user, device,
// algorithm and key ID.
const deleteOneTimeKeySQL = "DELETE FROM keyserver_one_time_keys" +
	" WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
// selectKeyByAlgorithmSQL picks at most one key of the given algorithm for a
// device. The query has no ORDER BY, so which key is picked is up to the
// database.
const selectKeyByAlgorithmSQL = "SELECT key_id, key_json FROM keyserver_one_time_keys" +
	" WHERE user_id = $1 AND device_id = $2 AND algorithm = $3" +
	" LIMIT 1"
2022-02-21 12:30:43 +00:00
// deleteOneTimeKeysSQL removes every key stored for one device of one user.
const deleteOneTimeKeysSQL = "DELETE FROM keyserver_one_time_keys" +
	" WHERE user_id = $1 AND device_id = $2"
2020-07-15 11:02:34 +00:00
// oneTimeKeysStatements holds the database handle together with the prepared
// statements for the keyserver_one_time_keys table. All fields are filled in
// by NewPostgresOneTimeKeysTable.
type oneTimeKeysStatements struct {
	// db is the handle the statements below were prepared on.
	db *sql.DB

	upsertKeysStmt           *sql.Stmt // prepared from upsertKeysSQL
	selectKeysStmt           *sql.Stmt // prepared from selectKeysSQL
	selectKeysCountStmt      *sql.Stmt // prepared from selectKeysCountSQL
	selectKeyByAlgorithmStmt *sql.Stmt // prepared from selectKeyByAlgorithmSQL
	deleteOneTimeKeyStmt     *sql.Stmt // prepared from deleteOneTimeKeySQL
	deleteOneTimeKeysStmt    *sql.Stmt // prepared from deleteOneTimeKeysSQL
}
func NewPostgresOneTimeKeysTable ( db * sql . DB ) ( tables . OneTimeKeys , error ) {
s := & oneTimeKeysStatements {
db : db ,
}
_ , err := db . Exec ( oneTimeKeysSchema )
if err != nil {
return nil , err
}
if s . upsertKeysStmt , err = db . Prepare ( upsertKeysSQL ) ; err != nil {
return nil , err
}
if s . selectKeysStmt , err = db . Prepare ( selectKeysSQL ) ; err != nil {
return nil , err
}
if s . selectKeysCountStmt , err = db . Prepare ( selectKeysCountSQL ) ; err != nil {
return nil , err
}
2020-07-21 13:47:53 +00:00
if s . selectKeyByAlgorithmStmt , err = db . Prepare ( selectKeyByAlgorithmSQL ) ; err != nil {
return nil , err
}
if s . deleteOneTimeKeyStmt , err = db . Prepare ( deleteOneTimeKeySQL ) ; err != nil {
return nil , err
}
2022-02-21 12:30:43 +00:00
if s . deleteOneTimeKeysStmt , err = db . Prepare ( deleteOneTimeKeysSQL ) ; err != nil {
return nil , err
}
2020-07-15 11:02:34 +00:00
return s , nil
}
func ( s * oneTimeKeysStatements ) SelectOneTimeKeys ( ctx context . Context , userID , deviceID string , keyIDsWithAlgorithms [ ] string ) ( map [ string ] json . RawMessage , error ) {
2021-06-07 08:17:46 +00:00
rows , err := s . selectKeysStmt . QueryContext ( ctx , userID , deviceID , pq . Array ( keyIDsWithAlgorithms ) )
2020-07-15 11:02:34 +00:00
if err != nil {
return nil , err
}
defer internal . CloseAndLogIfError ( ctx , rows , "selectKeysStmt: rows.close() failed" )
result := make ( map [ string ] json . RawMessage )
2021-06-07 08:17:46 +00:00
var (
algorithmWithID string
keyJSONStr string
)
2020-07-15 11:02:34 +00:00
for rows . Next ( ) {
2021-06-07 08:17:46 +00:00
if err := rows . Scan ( & algorithmWithID , & keyJSONStr ) ; err != nil {
2020-07-15 11:02:34 +00:00
return nil , err
}
2021-06-07 08:17:46 +00:00
result [ algorithmWithID ] = json . RawMessage ( keyJSONStr )
2020-07-15 11:02:34 +00:00
}
return result , rows . Err ( )
}
2020-08-03 11:29:58 +00:00
// CountOneTimeKeys returns how many one-time keys are stored for the given
// device, broken down by algorithm. Algorithms without any stored key do not
// appear in KeyCount.
//
// An error from the query, from scanning a row, or from iterating the result
// set is returned with a nil count.
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
	counts := &api.OneTimeKeysCount{
		DeviceID: deviceID,
		UserID:   userID,
		KeyCount: make(map[string]int),
	}
	rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
	for rows.Next() {
		var algorithm string
		var count int
		if err = rows.Scan(&algorithm, &count); err != nil {
			return nil, err
		}
		counts.KeyCount[algorithm] = count
	}
	// rows.Next also returns false when iteration fails part-way through, so
	// the error has to be checked or partial counts would look like success.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return counts, nil
}
2020-08-25 09:29:45 +00:00
func ( s * oneTimeKeysStatements ) InsertOneTimeKeys ( ctx context . Context , txn * sql . Tx , keys api . OneTimeKeys ) ( * api . OneTimeKeysCount , error ) {
2020-07-15 11:02:34 +00:00
now := time . Now ( ) . Unix ( )
counts := & api . OneTimeKeysCount {
DeviceID : keys . DeviceID ,
UserID : keys . UserID ,
KeyCount : make ( map [ string ] int ) ,
}
2020-08-25 09:29:45 +00:00
for keyIDWithAlgo , keyJSON := range keys . KeyJSON {
algo , keyID := keys . Split ( keyIDWithAlgo )
2020-09-24 10:10:14 +00:00
_ , err := sqlutil . TxStmt ( txn , s . upsertKeysStmt ) . ExecContext (
2020-08-25 09:29:45 +00:00
ctx , keys . UserID , keys . DeviceID , keyID , algo , now , string ( keyJSON ) ,
)
2020-07-15 11:02:34 +00:00
if err != nil {
2020-08-25 09:29:45 +00:00
return nil , err
2020-07-15 11:02:34 +00:00
}
2020-08-25 09:29:45 +00:00
}
2020-09-24 10:10:14 +00:00
rows , err := sqlutil . TxStmt ( txn , s . selectKeysCountStmt ) . QueryContext ( ctx , keys . UserID , keys . DeviceID )
2020-08-25 09:29:45 +00:00
if err != nil {
return nil , err
}
defer internal . CloseAndLogIfError ( ctx , rows , "selectKeysCountStmt: rows.close() failed" )
for rows . Next ( ) {
var algorithm string
var count int
if err = rows . Scan ( & algorithm , & count ) ; err != nil {
return nil , err
2020-07-15 11:02:34 +00:00
}
2020-08-25 09:29:45 +00:00
counts . KeyCount [ algorithm ] = count
}
2020-07-15 11:02:34 +00:00
2020-08-25 09:29:45 +00:00
return counts , rows . Err ( )
2020-07-15 11:02:34 +00:00
}
2020-07-21 13:47:53 +00:00
func ( s * oneTimeKeysStatements ) SelectAndDeleteOneTimeKey (
ctx context . Context , txn * sql . Tx , userID , deviceID , algorithm string ,
) ( map [ string ] json . RawMessage , error ) {
var keyID string
var keyJSON string
2020-09-24 10:10:14 +00:00
err := sqlutil . TxStmtContext ( ctx , txn , s . selectKeyByAlgorithmStmt ) . QueryRowContext ( ctx , userID , deviceID , algorithm ) . Scan ( & keyID , & keyJSON )
2020-07-21 13:47:53 +00:00
if err != nil {
if err == sql . ErrNoRows {
return nil , nil
}
return nil , err
}
2020-09-24 10:10:14 +00:00
_ , err = sqlutil . TxStmtContext ( ctx , txn , s . deleteOneTimeKeyStmt ) . ExecContext ( ctx , userID , deviceID , algorithm , keyID )
2020-07-21 13:47:53 +00:00
return map [ string ] json . RawMessage {
algorithm + ":" + keyID : json . RawMessage ( keyJSON ) ,
} , err
}
2022-02-21 12:30:43 +00:00
// DeleteOneTimeKeys removes every one-time key stored for the given device of
// the given user, running the delete through txn via sqlutil.TxStmt.
func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
	stmt := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt)
	if _, err := stmt.ExecContext(ctx, userID, deviceID); err != nil {
		return err
	}
	return nil
}