mirror of
https://github.com/matrix-org/dendrite
synced 2025-01-18 16:04:02 +00:00
72285b2659
Sister PR to https://github.com/matrix-org/gomatrixserverlib/pull/364 Read this commit by commit to avoid going insane.
89 lines
2.7 KiB
Go
89 lines
2.7 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
"github.com/matrix-org/dendrite/userapi/api"
|
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// openIDTokenSchema is the DDL that NewPostgresOpenIDTable executes to create
// the userapi_openid_tokens table when it does not already exist. Each row
// holds one issued token, the localpart and server name of the account it was
// issued for, and the expiry time as a unix timestamp in milliseconds.
const openIDTokenSchema = `
-- Stores data about openid tokens issued for accounts.
CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
-- The value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_at_ms BIGINT NOT NULL
);
`
|
|
|
|
// insertOpenIDTokenSQL adds one row to userapi_openid_tokens. The placeholders
// are, in order: token, localpart, server name, expiry in milliseconds.
const insertOpenIDTokenSQL = "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
|
|
|
// selectOpenIDTokenSQL fetches the localpart, server name and expiry of the
// row whose token equals the single placeholder.
const selectOpenIDTokenSQL = "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
|
|
|
type openIDTokenStatements struct {
|
|
insertTokenStmt *sql.Stmt
|
|
selectTokenStmt *sql.Stmt
|
|
serverName spec.ServerName
|
|
}
|
|
|
|
func NewPostgresOpenIDTable(db *sql.DB, serverName spec.ServerName) (tables.OpenIDTable, error) {
|
|
s := &openIDTokenStatements{
|
|
serverName: serverName,
|
|
}
|
|
_, err := db.Exec(openIDTokenSchema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s, sqlutil.StatementList{
|
|
{&s.insertTokenStmt, insertOpenIDTokenSQL},
|
|
{&s.selectTokenStmt, selectOpenIDTokenSQL},
|
|
}.Prepare(db)
|
|
}
|
|
|
|
// insertToken inserts a new OpenID Connect token to the DB.
|
|
// Returns new token, otherwise returns error if the token already exists.
|
|
func (s *openIDTokenStatements) InsertOpenIDToken(
|
|
ctx context.Context,
|
|
txn *sql.Tx,
|
|
token, localpart string, serverName spec.ServerName,
|
|
expiresAtMS int64,
|
|
) (err error) {
|
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
|
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
|
|
return
|
|
}
|
|
|
|
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
|
// Returns the existing token's attributes, or err if no token is found
|
|
func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
|
ctx context.Context,
|
|
token string,
|
|
) (*api.OpenIDTokenAttributes, error) {
|
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
|
var localpart string
|
|
var serverName spec.ServerName
|
|
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
|
&localpart, &serverName,
|
|
&openIDTokenAttrs.ExpiresAtMS,
|
|
)
|
|
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
|
if err != nil {
|
|
if err != sql.ErrNoRows {
|
|
log.WithError(err).Error("Unable to retrieve token from the db")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return &openIDTokenAttrs, nil
|
|
}
|