SSO Support (#4881)

* Added OAUTH2 OIDC support

* Fixes and improvements based on review feedback

* use derive_new::new instead of TypedBuilder

* merge migrations into a single file

* fixes based on review feedback

* remove unnecessary hostname_ui config

* improvement based on review feedback

* improvements based on review feedback

* delete user oauth accounts at account deletion

* fixes and improvements based on review feedback

* removed auto_approve_application

* support registration application with sso

* improvements based on review feedback

* making the TokenResponse an internal struct as it should be

* remove duplicate struct

* prevent oauth linking to unverified accounts

* switched to manually entered username and removed the oauth name claim

* fix cargo fmt

* fix compile error

* improvements based on review feedback

* fixes and improvements based on review feedback

---------

Co-authored-by: privacyguard <privacyguard@users.noreply.github.com>
This commit is contained in:
privacyguard 2024-09-18 15:52:33 +03:00 committed by GitHub
parent 6454a4d43d
commit b26aaac523
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
42 changed files with 1235 additions and 166 deletions

7
Cargo.lock generated
View file

@ -2545,6 +2545,9 @@ dependencies = [
"lemmy_db_views_actor", "lemmy_db_views_actor",
"lemmy_utils", "lemmy_utils",
"moka", "moka",
"serde",
"serde_json",
"serde_with",
"tracing", "tracing",
"url", "url",
"uuid", "uuid",
@ -3314,9 +3317,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.20.0" version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ea5043e58958ee56f3e15a90aee535795cd7dfd319846288d93c5b57d85cbe" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]] [[package]]
name = "overload" name = "overload"

View file

@ -28,11 +28,13 @@ pub async fn change_password(
} }
// Check the old password // Check the old password
let valid: bool = verify( let valid: bool = if let Some(password_encrypted) = &local_user_view.local_user.password_encrypted
&data.old_password, {
&local_user_view.local_user.password_encrypted, verify(&data.old_password, password_encrypted).unwrap_or(false)
) } else {
.unwrap_or(false); data.old_password.is_empty()
};
if !valid { if !valid {
Err(LemmyErrorType::IncorrectLogin)? Err(LemmyErrorType::IncorrectLogin)?
} }

View file

@ -1,4 +1,4 @@
use crate::{check_totp_2fa_valid, local_user::check_email_verified}; use crate::check_totp_2fa_valid;
use actix_web::{ use actix_web::{
web::{Data, Json}, web::{Data, Json},
HttpRequest, HttpRequest,
@ -8,12 +8,7 @@ use lemmy_api_common::{
claims::Claims, claims::Claims,
context::LemmyContext, context::LemmyContext,
person::{Login, LoginResponse}, person::{Login, LoginResponse},
utils::check_user_valid, utils::{check_email_verified, check_registration_application, check_user_valid},
};
use lemmy_db_schema::{
source::{local_site::LocalSite, registration_application::RegistrationApplication},
utils::DbPool,
RegistrationMode,
}; };
use lemmy_db_views::structs::{LocalUserView, SiteView}; use lemmy_db_views::structs::{LocalUserView, SiteView};
use lemmy_utils::error::{LemmyErrorType, LemmyResult}; use lemmy_utils::error::{LemmyErrorType, LemmyResult};
@ -34,11 +29,12 @@ pub async fn login(
.ok_or(LemmyErrorType::IncorrectLogin)?; .ok_or(LemmyErrorType::IncorrectLogin)?;
// Verify the password // Verify the password
let valid: bool = verify( let valid: bool = local_user_view
&data.password, .local_user
&local_user_view.local_user.password_encrypted, .password_encrypted
) .as_ref()
.unwrap_or(false); .and_then(|password_encrypted| verify(&data.password, password_encrypted).ok())
.unwrap_or(false);
if !valid { if !valid {
Err(LemmyErrorType::IncorrectLogin)? Err(LemmyErrorType::IncorrectLogin)?
} }
@ -65,28 +61,3 @@ pub async fn login(
registration_created: false, registration_created: false,
})) }))
} }
async fn check_registration_application(
local_user_view: &LocalUserView,
local_site: &LocalSite,
pool: &mut DbPool<'_>,
) -> LemmyResult<()> {
if (local_site.registration_mode == RegistrationMode::RequireApplication
|| local_site.registration_mode == RegistrationMode::Closed)
&& !local_user_view.local_user.accepted_application
&& !local_user_view.local_user.admin
{
// Fetch the registration application. If no admin id is present its still pending. Otherwise it
// was processed (either accepted or denied).
let local_user_id = local_user_view.local_user.id;
let registration = RegistrationApplication::find_by_local_user_id(pool, local_user_id)
.await?
.ok_or(LemmyErrorType::CouldntFindRegistrationApplication)?;
if registration.admin_id.is_some() {
Err(LemmyErrorType::RegistrationDenied(registration.deny_reason))?
} else {
Err(LemmyErrorType::RegistrationApplicationIsPending)?
}
}
Ok(())
}

View file

@ -1,6 +1,3 @@
use lemmy_db_views::structs::{LocalUserView, SiteView};
use lemmy_utils::{error::LemmyResult, LemmyErrorType};
pub mod add_admin; pub mod add_admin;
pub mod ban_person; pub mod ban_person;
pub mod block; pub mod block;
@ -20,15 +17,3 @@ pub mod save_settings;
pub mod update_totp; pub mod update_totp;
pub mod validate_auth; pub mod validate_auth;
pub mod verify_email; pub mod verify_email;
/// Check if the user's email is verified if email verification is turned on
/// However, skip checking verification if the user is an admin
fn check_email_verified(local_user_view: &LocalUserView, site_view: &SiteView) -> LemmyResult<()> {
if !local_user_view.local_user.admin
&& site_view.local_site.require_email_verification
&& !local_user_view.local_user.email_verified
{
Err(LemmyErrorType::EmailNotVerified)?
}
Ok(())
}

View file

@ -1,9 +1,8 @@
use crate::local_user::check_email_verified;
use actix_web::web::{Data, Json}; use actix_web::web::{Data, Json};
use lemmy_api_common::{ use lemmy_api_common::{
context::LemmyContext, context::LemmyContext,
person::PasswordReset, person::PasswordReset,
utils::send_password_reset_email, utils::{check_email_verified, send_password_reset_email},
SuccessResponse, SuccessResponse,
}; };
use lemmy_db_views::structs::{LocalUserView, SiteView}; use lemmy_db_views::structs::{LocalUserView, SiteView};

View file

@ -7,6 +7,7 @@ use lemmy_db_schema::{
local_site_url_blocklist::LocalSiteUrlBlocklist, local_site_url_blocklist::LocalSiteUrlBlocklist,
local_user::{LocalUser, LocalUserUpdateForm}, local_user::{LocalUser, LocalUserUpdateForm},
moderator::{ModAdd, ModAddForm}, moderator::{ModAdd, ModAddForm},
oauth_provider::OAuthProvider,
tagline::Tagline, tagline::Tagline,
}, },
traits::Crud, traits::Crud,
@ -63,6 +64,7 @@ pub async fn leave_admin(
let taglines = Tagline::get_all(&mut context.pool(), site_view.local_site.id).await?; let taglines = Tagline::get_all(&mut context.pool(), site_view.local_site.id).await?;
let custom_emojis = let custom_emojis =
CustomEmojiView::get_all(&mut context.pool(), site_view.local_site.id).await?; CustomEmojiView::get_all(&mut context.pool(), site_view.local_site.id).await?;
let oauth_providers = OAuthProvider::get_all_public(&mut context.pool()).await?;
let blocked_urls = LocalSiteUrlBlocklist::get_all(&mut context.pool()).await?; let blocked_urls = LocalSiteUrlBlocklist::get_all(&mut context.pool()).await?;
Ok(Json(GetSiteResponse { Ok(Json(GetSiteResponse {
@ -74,6 +76,8 @@ pub async fn leave_admin(
discussion_languages, discussion_languages,
taglines, taglines,
custom_emojis, custom_emojis,
oauth_providers: Some(oauth_providers),
admin_oauth_providers: None,
blocked_urls, blocked_urls,
})) }))
} }

View file

@ -7,6 +7,7 @@ pub mod community;
#[cfg(feature = "full")] #[cfg(feature = "full")]
pub mod context; pub mod context;
pub mod custom_emoji; pub mod custom_emoji;
pub mod oauth_provider;
pub mod person; pub mod person;
pub mod post; pub mod post;
pub mod private_message; pub mod private_message;

View file

@ -0,0 +1,69 @@
use lemmy_db_schema::newtypes::OAuthProviderId;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
#[cfg(feature = "full")]
use ts_rs::TS;
use url::Url;
#[derive(Debug, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "full", derive(TS))]
#[cfg_attr(feature = "full", ts(export))]
/// Create an external auth method.
pub struct CreateOAuthProvider {
pub display_name: String,
pub issuer: String,
pub authorization_endpoint: String,
pub token_endpoint: String,
pub userinfo_endpoint: String,
pub id_claim: String,
pub client_id: String,
pub client_secret: String,
pub scopes: String,
pub auto_verify_email: bool,
pub account_linking_enabled: bool,
pub enabled: bool,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "full", derive(TS))]
#[cfg_attr(feature = "full", ts(export))]
/// Edit an external auth method.
pub struct EditOAuthProvider {
pub id: OAuthProviderId,
pub display_name: Option<String>,
pub authorization_endpoint: Option<String>,
pub token_endpoint: Option<String>,
pub userinfo_endpoint: Option<String>,
pub id_claim: Option<String>,
pub client_secret: Option<String>,
pub scopes: Option<String>,
pub auto_verify_email: Option<bool>,
pub account_linking_enabled: Option<bool>,
pub enabled: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
#[cfg_attr(feature = "full", derive(TS))]
#[cfg_attr(feature = "full", ts(export))]
/// Delete an external auth method.
pub struct DeleteOAuthProvider {
pub id: OAuthProviderId,
}
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[cfg_attr(feature = "full", derive(TS))]
#[cfg_attr(feature = "full", ts(export))]
/// Logging in with an OAuth 2.0 authorization
pub struct AuthenticateWithOauth {
pub code: String,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub oauth_provider_id: OAuthProviderId,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub redirect_uri: Url,
pub show_nsfw: Option<bool>,
/// Username is mandatory at registration time
pub username: Option<String>,
/// An answer is mandatory if require application is enabled on the server
pub answer: Option<String>,
}

View file

@ -44,6 +44,7 @@ pub fn client_builder(settings: &Settings) -> ClientBuilder {
.user_agent(user_agent.clone()) .user_agent(user_agent.clone())
.timeout(REQWEST_TIMEOUT) .timeout(REQWEST_TIMEOUT)
.connect_timeout(REQWEST_TIMEOUT) .connect_timeout(REQWEST_TIMEOUT)
.use_rustls_tls()
} }
/// Fetches metadata for the given link and optionally generates thumbnail. /// Fetches metadata for the given link and optionally generates thumbnail.

View file

@ -16,6 +16,7 @@ use lemmy_db_schema::{
instance::Instance, instance::Instance,
language::Language, language::Language,
local_site_url_blocklist::LocalSiteUrlBlocklist, local_site_url_blocklist::LocalSiteUrlBlocklist,
oauth_provider::{OAuthProvider, PublicOAuthProvider},
person::Person, person::Person,
tagline::Tagline, tagline::Tagline,
}, },
@ -200,6 +201,7 @@ pub struct CreateSite {
pub blocked_instances: Option<Vec<String>>, pub blocked_instances: Option<Vec<String>>,
pub taglines: Option<Vec<String>>, pub taglines: Option<Vec<String>>,
pub registration_mode: Option<RegistrationMode>, pub registration_mode: Option<RegistrationMode>,
pub oauth_registration: Option<bool>,
pub content_warning: Option<String>, pub content_warning: Option<String>,
pub default_post_listing_mode: Option<PostListingMode>, pub default_post_listing_mode: Option<PostListingMode>,
} }
@ -282,6 +284,8 @@ pub struct EditSite {
/// A list of taglines shown at the top of the front page. /// A list of taglines shown at the top of the front page.
pub taglines: Option<Vec<String>>, pub taglines: Option<Vec<String>>,
pub registration_mode: Option<RegistrationMode>, pub registration_mode: Option<RegistrationMode>,
/// Whether or not external auth methods can auto-register users.
pub oauth_registration: Option<bool>,
/// Whether to email admins for new reports. /// Whether to email admins for new reports.
pub reports_email_admins: Option<bool>, pub reports_email_admins: Option<bool>,
/// If present, nsfw content is visible by default. Should be displayed by frontends/clients /// If present, nsfw content is visible by default. Should be displayed by frontends/clients
@ -316,6 +320,9 @@ pub struct GetSiteResponse {
pub taglines: Vec<Tagline>, pub taglines: Vec<Tagline>,
/// A list of custom emojis your site supports. /// A list of custom emojis your site supports.
pub custom_emojis: Vec<CustomEmojiView>, pub custom_emojis: Vec<CustomEmojiView>,
/// A list of external auth methods your site supports.
pub oauth_providers: Option<Vec<PublicOAuthProvider>>,
pub admin_oauth_providers: Option<Vec<OAuthProvider>>,
pub blocked_urls: Vec<LocalSiteUrlBlocklist>, pub blocked_urls: Vec<LocalSiteUrlBlocklist>,
} }

View file

@ -23,18 +23,21 @@ use lemmy_db_schema::{
local_site::LocalSite, local_site::LocalSite,
local_site_rate_limit::LocalSiteRateLimit, local_site_rate_limit::LocalSiteRateLimit,
local_site_url_blocklist::LocalSiteUrlBlocklist, local_site_url_blocklist::LocalSiteUrlBlocklist,
oauth_account::OAuthAccount,
password_reset_request::PasswordResetRequest, password_reset_request::PasswordResetRequest,
person::{Person, PersonUpdateForm}, person::{Person, PersonUpdateForm},
person_block::PersonBlock, person_block::PersonBlock,
post::{Post, PostRead}, post::{Post, PostRead},
registration_application::RegistrationApplication,
site::Site, site::Site,
}, },
traits::Crud, traits::Crud,
utils::DbPool, utils::DbPool,
RegistrationMode,
}; };
use lemmy_db_views::{ use lemmy_db_views::{
comment_view::CommentQuery, comment_view::CommentQuery,
structs::{LocalImageView, LocalUserView}, structs::{LocalImageView, LocalUserView, SiteView},
}; };
use lemmy_db_views_actor::structs::{ use lemmy_db_views_actor::structs::{
CommunityModeratorView, CommunityModeratorView,
@ -192,6 +195,46 @@ pub fn check_user_valid(person: &Person) -> LemmyResult<()> {
} }
} }
/// Check if the user's email is verified if email verification is turned on
/// However, skip checking verification if the user is an admin
pub fn check_email_verified(
local_user_view: &LocalUserView,
site_view: &SiteView,
) -> LemmyResult<()> {
if !local_user_view.local_user.admin
&& site_view.local_site.require_email_verification
&& !local_user_view.local_user.email_verified
{
Err(LemmyErrorType::EmailNotVerified)?
}
Ok(())
}
pub async fn check_registration_application(
local_user_view: &LocalUserView,
local_site: &LocalSite,
pool: &mut DbPool<'_>,
) -> LemmyResult<()> {
if (local_site.registration_mode == RegistrationMode::RequireApplication
|| local_site.registration_mode == RegistrationMode::Closed)
&& !local_user_view.local_user.accepted_application
&& !local_user_view.local_user.admin
{
// Fetch the registration application. If no admin id is present its still pending. Otherwise it
// was processed (either accepted or denied).
let local_user_id = local_user_view.local_user.id;
let registration = RegistrationApplication::find_by_local_user_id(pool, local_user_id)
.await?
.ok_or(LemmyErrorType::CouldntFindRegistrationApplication)?;
if registration.admin_id.is_some() {
Err(LemmyErrorType::RegistrationDenied(registration.deny_reason))?
} else {
Err(LemmyErrorType::RegistrationApplicationIsPending)?
}
}
Ok(())
}
/// Checks that a normal user action (eg posting or voting) is allowed in a given community. /// Checks that a normal user action (eg posting or voting) is allowed in a given community.
/// ///
/// In particular it checks that neither the user nor community are banned or deleted, and that /// In particular it checks that neither the user nor community are banned or deleted, and that
@ -852,6 +895,11 @@ pub async fn purge_user_account(person_id: PersonId, context: &LemmyContext) ->
// Leave communities they mod // Leave communities they mod
CommunityModerator::leave_all_communities(pool, person_id).await?; CommunityModerator::leave_all_communities(pool, person_id).await?;
// Delete the oauth accounts linked to the local user
if let Ok(Some(local_user)) = LocalUserView::read_person(pool, person_id).await {
OAuthAccount::delete_user_accounts(pool, local_user.local_user.id).await?;
}
Person::delete_account(pool, person_id).await?; Person::delete_account(pool, person_id).await?;
Ok(()) Ok(())

View file

@ -29,6 +29,9 @@ moka.workspace = true
anyhow.workspace = true anyhow.workspace = true
webmention = "0.6.0" webmention = "0.6.0"
accept-language = "3.1.0" accept-language = "3.1.0"
serde_json = { workspace = true }
serde = { workspace = true }
serde_with = { workspace = true }
[package.metadata.cargo-machete] [package.metadata.cargo-machete]
ignored = ["futures"] ignored = ["futures"]

View file

@ -1,6 +1,7 @@
pub mod comment; pub mod comment;
pub mod community; pub mod community;
pub mod custom_emoji; pub mod custom_emoji;
pub mod oauth_provider;
pub mod post; pub mod post;
pub mod private_message; pub mod private_message;
pub mod site; pub mod site;

View file

@ -0,0 +1,42 @@
use activitypub_federation::config::Data;
use actix_web::web::Json;
use lemmy_api_common::{
context::LemmyContext,
oauth_provider::CreateOAuthProvider,
utils::is_admin,
};
use lemmy_db_schema::{
source::oauth_provider::{OAuthProvider, OAuthProviderInsertForm},
traits::Crud,
};
use lemmy_db_views::structs::LocalUserView;
use lemmy_utils::error::LemmyError;
use url::Url;
#[tracing::instrument(skip(context))]
pub async fn create_oauth_provider(
data: Json<CreateOAuthProvider>,
context: Data<LemmyContext>,
local_user_view: LocalUserView,
) -> Result<Json<OAuthProvider>, LemmyError> {
// Make sure user is an admin
is_admin(&local_user_view)?;
let cloned_data = data.clone();
let oauth_provider_form = OAuthProviderInsertForm {
display_name: cloned_data.display_name,
issuer: Url::parse(&cloned_data.issuer)?.into(),
authorization_endpoint: Url::parse(&cloned_data.authorization_endpoint)?.into(),
token_endpoint: Url::parse(&cloned_data.token_endpoint)?.into(),
userinfo_endpoint: Url::parse(&cloned_data.userinfo_endpoint)?.into(),
id_claim: cloned_data.id_claim,
client_id: data.client_id.to_string(),
client_secret: data.client_secret.to_string(),
scopes: data.scopes.to_string(),
auto_verify_email: data.auto_verify_email,
account_linking_enabled: data.account_linking_enabled,
enabled: data.enabled,
};
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;
Ok(Json(oauth_provider))
}

View file

@ -0,0 +1,25 @@
use activitypub_federation::config::Data;
use actix_web::web::Json;
use lemmy_api_common::{
context::LemmyContext,
oauth_provider::DeleteOAuthProvider,
utils::is_admin,
SuccessResponse,
};
use lemmy_db_schema::{source::oauth_provider::OAuthProvider, traits::Crud};
use lemmy_db_views::structs::LocalUserView;
use lemmy_utils::error::{LemmyError, LemmyErrorExt, LemmyErrorType};
#[tracing::instrument(skip(context))]
pub async fn delete_oauth_provider(
data: Json<DeleteOAuthProvider>,
context: Data<LemmyContext>,
local_user_view: LocalUserView,
) -> Result<Json<SuccessResponse>, LemmyError> {
// Make sure user is an admin
is_admin(&local_user_view)?;
OAuthProvider::delete(&mut context.pool(), data.id)
.await
.with_lemmy_type(LemmyErrorType::CouldntDeleteOauthProvider)?;
Ok(Json(SuccessResponse::default()))
}

View file

@ -0,0 +1,3 @@
pub mod create;
pub mod delete;
pub mod update;

View file

@ -0,0 +1,44 @@
use activitypub_federation::config::Data;
use actix_web::web::Json;
use lemmy_api_common::{context::LemmyContext, oauth_provider::EditOAuthProvider, utils::is_admin};
use lemmy_db_schema::{
source::oauth_provider::{OAuthProvider, OAuthProviderUpdateForm},
traits::Crud,
utils::{diesel_required_string_update, diesel_required_url_update, naive_now},
};
use lemmy_db_views::structs::LocalUserView;
use lemmy_utils::{error::LemmyError, LemmyErrorType};
#[tracing::instrument(skip(context))]
pub async fn update_oauth_provider(
data: Json<EditOAuthProvider>,
context: Data<LemmyContext>,
local_user_view: LocalUserView,
) -> Result<Json<OAuthProvider>, LemmyError> {
// Make sure user is an admin
is_admin(&local_user_view)?;
let cloned_data = data.clone();
let oauth_provider_form = OAuthProviderUpdateForm {
display_name: diesel_required_string_update(cloned_data.display_name.as_deref()),
authorization_endpoint: diesel_required_url_update(
cloned_data.authorization_endpoint.as_deref(),
)?,
token_endpoint: diesel_required_url_update(cloned_data.token_endpoint.as_deref())?,
userinfo_endpoint: diesel_required_url_update(cloned_data.userinfo_endpoint.as_deref())?,
id_claim: diesel_required_string_update(data.id_claim.as_deref()),
client_secret: diesel_required_string_update(data.client_secret.as_deref()),
scopes: diesel_required_string_update(data.scopes.as_deref()),
auto_verify_email: data.auto_verify_email,
account_linking_enabled: data.account_linking_enabled,
enabled: data.enabled,
updated: Some(Some(naive_now())),
};
let update_result =
OAuthProvider::update(&mut context.pool(), data.id, &oauth_provider_form).await?;
let oauth_provider = OAuthProvider::read(&mut context.pool(), update_result.id)
.await?
.ok_or(LemmyErrorType::CouldntFindOauthProvider)?;
Ok(Json(oauth_provider))
}

View file

@ -591,6 +591,7 @@ mod tests {
blocked_instances: None, blocked_instances: None,
taglines: None, taglines: None,
registration_mode: site_registration_mode, registration_mode: site_registration_mode,
oauth_registration: None,
content_warning: None, content_warning: None,
default_post_listing_mode: None, default_post_listing_mode: None,
} }

View file

@ -9,6 +9,7 @@ use lemmy_db_schema::source::{
instance_block::InstanceBlock, instance_block::InstanceBlock,
language::Language, language::Language,
local_site_url_blocklist::LocalSiteUrlBlocklist, local_site_url_blocklist::LocalSiteUrlBlocklist,
oauth_provider::OAuthProvider,
person_block::PersonBlock, person_block::PersonBlock,
tagline::Tagline, tagline::Tagline,
}; };
@ -45,6 +46,10 @@ pub async fn get_site(
let custom_emojis = let custom_emojis =
CustomEmojiView::get_all(&mut context.pool(), site_view.local_site.id).await?; CustomEmojiView::get_all(&mut context.pool(), site_view.local_site.id).await?;
let blocked_urls = LocalSiteUrlBlocklist::get_all(&mut context.pool()).await?; let blocked_urls = LocalSiteUrlBlocklist::get_all(&mut context.pool()).await?;
let admin_oauth_providers = OAuthProvider::get_all(&mut context.pool()).await?;
let oauth_providers =
OAuthProvider::convert_providers_to_public(admin_oauth_providers.clone());
Ok(GetSiteResponse { Ok(GetSiteResponse {
site_view, site_view,
admins, admins,
@ -55,13 +60,15 @@ pub async fn get_site(
taglines, taglines,
custom_emojis, custom_emojis,
blocked_urls, blocked_urls,
oauth_providers: Some(oauth_providers),
admin_oauth_providers: Some(admin_oauth_providers),
}) })
}) })
.await .await
.map_err(|e| anyhow::anyhow!("Failed to construct site response: {e}"))?; .map_err(|e| anyhow::anyhow!("Failed to construct site response: {e}"))?;
// Build the local user with parallel queries and add it to site response // Build the local user with parallel queries and add it to site response
site_response.my_user = if let Some(local_user_view) = local_user_view { site_response.my_user = if let Some(ref local_user_view) = local_user_view {
let person_id = local_user_view.person.id; let person_id = local_user_view.person.id;
let local_user_id = local_user_view.local_user.id; let local_user_id = local_user_view.local_user.id;
let pool = &mut context.pool(); let pool = &mut context.pool();
@ -84,7 +91,7 @@ pub async fn get_site(
.with_lemmy_type(LemmyErrorType::SystemErrLogin)?; .with_lemmy_type(LemmyErrorType::SystemErrLogin)?;
Some(MyUserInfo { Some(MyUserInfo {
local_user_view, local_user_view: local_user_view.clone(),
follows, follows,
moderates, moderates,
community_blocks, community_blocks,
@ -96,5 +103,13 @@ pub async fn get_site(
None None
}; };
// filter oauth_providers for public access
if !local_user_view
.map(|l| l.local_user.admin)
.unwrap_or_default()
{
site_response.admin_oauth_providers = None;
}
Ok(Json(site_response)) Ok(Json(site_response))
} }

View file

@ -119,6 +119,7 @@ pub async fn update_site(
captcha_difficulty: data.captcha_difficulty.clone(), captcha_difficulty: data.captcha_difficulty.clone(),
reports_email_admins: data.reports_email_admins, reports_email_admins: data.reports_email_admins,
default_post_listing_mode: data.default_post_listing_mode, default_post_listing_mode: data.default_post_listing_mode,
oauth_registration: data.oauth_registration,
..Default::default() ..Default::default()
}; };
@ -278,6 +279,7 @@ mod tests {
None::<bool>, None::<bool>,
None::<String>, None::<String>,
None::<RegistrationMode>, None::<RegistrationMode>,
None::<bool>,
), ),
), ),
( (
@ -301,6 +303,7 @@ mod tests {
None::<bool>, None::<bool>,
None::<String>, None::<String>,
None::<RegistrationMode>, None::<RegistrationMode>,
None::<bool>,
), ),
), ),
( (
@ -324,6 +327,7 @@ mod tests {
None::<bool>, None::<bool>,
None::<String>, None::<String>,
None::<RegistrationMode>, None::<RegistrationMode>,
None::<bool>,
), ),
), ),
( (
@ -347,6 +351,7 @@ mod tests {
Some(true), Some(true),
None::<String>, None::<String>,
None::<RegistrationMode>, None::<RegistrationMode>,
None::<bool>,
), ),
), ),
( (
@ -370,6 +375,7 @@ mod tests {
Some(true), Some(true),
None::<String>, None::<String>,
None::<RegistrationMode>, None::<RegistrationMode>,
None::<bool>,
), ),
), ),
( (
@ -393,6 +399,7 @@ mod tests {
None::<bool>, None::<bool>,
None::<String>, None::<String>,
Some(RegistrationMode::RequireApplication), Some(RegistrationMode::RequireApplication),
None::<bool>,
), ),
), ),
]; ];
@ -447,6 +454,7 @@ mod tests {
None::<bool>, None::<bool>,
None::<String>, None::<String>,
None::<RegistrationMode>, None::<RegistrationMode>,
None::<bool>,
), ),
), ),
( (
@ -469,6 +477,7 @@ mod tests {
Some(true), Some(true),
Some(String::new()), Some(String::new()),
Some(RegistrationMode::Open), Some(RegistrationMode::Open),
None::<bool>,
), ),
), ),
( (
@ -491,6 +500,7 @@ mod tests {
None::<bool>, None::<bool>,
None::<String>, None::<String>,
None::<RegistrationMode>, None::<RegistrationMode>,
None::<bool>,
), ),
), ),
( (
@ -513,6 +523,7 @@ mod tests {
None::<bool>, None::<bool>,
None::<String>, None::<String>,
Some(RegistrationMode::RequireApplication), Some(RegistrationMode::RequireApplication),
None::<bool>,
), ),
), ),
]; ];
@ -561,6 +572,7 @@ mod tests {
site_is_federated: Option<bool>, site_is_federated: Option<bool>,
site_application_question: Option<String>, site_application_question: Option<String>,
site_registration_mode: Option<RegistrationMode>, site_registration_mode: Option<RegistrationMode>,
site_oauth_registration: Option<bool>,
) -> EditSite { ) -> EditSite {
EditSite { EditSite {
name: site_name, name: site_name,
@ -607,6 +619,7 @@ mod tests {
reports_email_admins: None, reports_email_admins: None,
content_warning: None, content_warning: None,
default_post_listing_mode: None, default_post_listing_mode: None,
oauth_registration: site_oauth_registration,
} }
} }
} }

View file

@ -3,8 +3,12 @@ use actix_web::{web::Json, HttpRequest};
use lemmy_api_common::{ use lemmy_api_common::{
claims::Claims, claims::Claims,
context::LemmyContext, context::LemmyContext,
oauth_provider::AuthenticateWithOauth,
person::{LoginResponse, Register}, person::{LoginResponse, Register},
utils::{ utils::{
check_email_verified,
check_registration_application,
check_user_valid,
generate_inbox_url, generate_inbox_url,
generate_local_apub_endpoint, generate_local_apub_endpoint,
generate_shared_inbox_url, generate_shared_inbox_url,
@ -18,11 +22,15 @@ use lemmy_api_common::{
}; };
use lemmy_db_schema::{ use lemmy_db_schema::{
aggregates::structs::PersonAggregates, aggregates::structs::PersonAggregates,
newtypes::{InstanceId, OAuthProviderId},
source::{ source::{
captcha_answer::{CaptchaAnswer, CheckCaptchaAnswer}, captcha_answer::{CaptchaAnswer, CheckCaptchaAnswer},
language::Language, language::Language,
local_site::LocalSite,
local_user::{LocalUser, LocalUserInsertForm}, local_user::{LocalUser, LocalUserInsertForm},
local_user_vote_display_mode::LocalUserVoteDisplayMode, local_user_vote_display_mode::LocalUserVoteDisplayMode,
oauth_account::{OAuthAccount, OAuthAccountInsertForm},
oauth_provider::OAuthProvider,
person::{Person, PersonInsertForm}, person::{Person, PersonInsertForm},
registration_application::{RegistrationApplication, RegistrationApplicationInsertForm}, registration_application::{RegistrationApplication, RegistrationApplicationInsertForm},
}, },
@ -31,15 +39,27 @@ use lemmy_db_schema::{
}; };
use lemmy_db_views::structs::{LocalUserView, SiteView}; use lemmy_db_views::structs::{LocalUserView, SiteView};
use lemmy_utils::{ use lemmy_utils::{
error::{LemmyErrorExt, LemmyErrorType, LemmyResult}, error::{LemmyError, LemmyErrorExt, LemmyErrorType, LemmyResult},
utils::{ utils::{
slurs::{check_slurs, check_slurs_opt}, slurs::{check_slurs, check_slurs_opt},
validation::is_valid_actor_name, validation::is_valid_actor_name,
}, },
}; };
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::collections::HashSet; use std::collections::HashSet;
#[tracing::instrument(skip(context))] #[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
/// Response from OAuth token endpoint
struct TokenResponse {
pub access_token: String,
pub token_type: String,
pub expires_in: Option<i64>,
pub refresh_token: Option<String>,
pub scope: Option<String>,
}
pub async fn register( pub async fn register(
data: Json<Register>, data: Json<Register>,
req: HttpRequest, req: HttpRequest,
@ -61,8 +81,9 @@ pub async fn register(
Err(LemmyErrorType::EmailRequired)? Err(LemmyErrorType::EmailRequired)?
} }
if local_site.site_setup && require_registration_application && data.answer.is_none() { // make sure the registration answer is provided when the registration application is required
Err(LemmyErrorType::RegistrationApplicationAnswerRequired)? if local_site.site_setup {
validate_registration_answer(require_registration_application, &data.answer)?;
} }
// Make sure passwords match // Make sure passwords match
@ -93,13 +114,9 @@ pub async fn register(
check_slurs(&data.username, &slur_regex)?; check_slurs(&data.username, &slur_regex)?;
check_slurs_opt(&data.answer, &slur_regex)?; check_slurs_opt(&data.answer, &slur_regex)?;
let actor_keypair = generate_actor_keypair()?; if Person::is_username_taken(&mut context.pool(), &data.username).await? {
is_valid_actor_name(&data.username, local_site.actor_name_max_length as usize)?; return Err(LemmyErrorType::UsernameAlreadyExists)?;
let actor_id = generate_local_apub_endpoint( }
EndpointType::Person,
&data.username,
&context.settings().get_protocol_and_hostname(),
)?;
if let Some(email) = &data.email { if let Some(email) = &data.email {
if LocalUser::is_email_taken(&mut context.pool(), email).await? { if LocalUser::is_email_taken(&mut context.pool(), email).await? {
@ -108,49 +125,28 @@ pub async fn register(
} }
// We have to create both a person, and local_user // We have to create both a person, and local_user
let inserted_person = create_person(
// Register the new person data.username.clone(),
let person_form = PersonInsertForm { &local_site,
actor_id: Some(actor_id.clone()), site_view.site.instance_id,
inbox_url: Some(generate_inbox_url(&actor_id)?), &context,
shared_inbox_url: Some(generate_shared_inbox_url(context.settings())?), )
private_key: Some(actor_keypair.private_key), .await?;
..PersonInsertForm::new(
data.username.clone(),
actor_keypair.public_key,
site_view.site.instance_id,
)
};
// insert the person
let inserted_person = Person::create(&mut context.pool(), &person_form)
.await
.with_lemmy_type(LemmyErrorType::UserAlreadyExists)?;
// Automatically set their application as accepted, if they created this with open registration. // Automatically set their application as accepted, if they created this with open registration.
// Also fixes a bug which allows users to log in when registrations are changed to closed. // Also fixes a bug which allows users to log in when registrations are changed to closed.
let accepted_application = Some(!require_registration_application); let accepted_application = Some(!require_registration_application);
// Get the user's preferred language using the Accept-Language header
let language_tags: Vec<String> = req
.headers()
.get("Accept-Language")
.map(|hdr| accept_language::parse(hdr.to_str().unwrap_or_default()))
.iter()
.flatten()
// Remove the optional region code
.map(|lang_str| lang_str.split('-').next().unwrap_or_default().to_string())
.collect();
// Show nsfw content if param is true, or if content_warning exists // Show nsfw content if param is true, or if content_warning exists
let show_nsfw = data let show_nsfw = data
.show_nsfw .show_nsfw
.unwrap_or(site_view.site.content_warning.is_some()); .unwrap_or(site_view.site.content_warning.is_some());
let language_tags = get_language_tags(&req);
// Create the local user // Create the local user
let local_user_form = LocalUserInsertForm { let local_user_form = LocalUserInsertForm {
email: data.email.as_deref().map(str::to_lowercase), email: data.email.as_deref().map(str::to_lowercase),
password_encrypted: data.password.to_string(),
show_nsfw: Some(show_nsfw), show_nsfw: Some(show_nsfw),
accepted_application, accepted_application,
default_listing_type: Some(local_site.default_post_listing_type), default_listing_type: Some(local_site.default_post_listing_type),
@ -158,21 +154,10 @@ pub async fn register(
interface_language: language_tags.first().cloned(), interface_language: language_tags.first().cloned(),
// If its the initial site setup, they are an admin // If its the initial site setup, they are an admin
admin: Some(!local_site.site_setup), admin: Some(!local_site.site_setup),
..LocalUserInsertForm::new(inserted_person.id, data.password.to_string()) ..LocalUserInsertForm::new(inserted_person.id, Some(data.password.to_string()))
}; };
let all_languages = Language::read_all(&mut context.pool()).await?; let inserted_local_user = create_local_user(&context, language_tags, &local_user_form).await?;
// use hashset to avoid duplicates
let mut language_ids = HashSet::new();
for l in language_tags {
if let Some(found) = all_languages.iter().find(|all| all.code == l) {
language_ids.insert(found.id);
}
}
let language_ids = language_ids.into_iter().collect();
let inserted_local_user =
LocalUser::create(&mut context.pool(), &local_user_form, language_ids).await?;
if local_site.site_setup && require_registration_application { if local_site.site_setup && require_registration_application {
// Create the registration application // Create the registration application
@ -205,29 +190,13 @@ pub async fn register(
let jwt = Claims::generate(inserted_local_user.id, req, &context).await?; let jwt = Claims::generate(inserted_local_user.id, req, &context).await?;
login_response.jwt = Some(jwt); login_response.jwt = Some(jwt);
} else { } else {
if local_site.require_email_verification { login_response.verify_email_sent = send_verification_email_if_required(
let local_user_view = LocalUserView { &context,
local_user: inserted_local_user, &local_site,
local_user_vote_display_mode: LocalUserVoteDisplayMode::default(), &inserted_local_user,
person: inserted_person, &inserted_person,
counts: PersonAggregates::default(), )
}; .await?;
// we check at the beginning of this method that email is set
let email = local_user_view
.local_user
.email
.clone()
.expect("email was provided");
send_verification_email(
&local_user_view,
&email,
&mut context.pool(),
context.settings(),
)
.await?;
login_response.verify_email_sent = true;
}
if require_registration_application { if require_registration_application {
login_response.registration_created = true; login_response.registration_created = true;
@ -236,3 +205,390 @@ pub async fn register(
Ok(Json(login_response)) Ok(Json(login_response))
} }
#[tracing::instrument(skip(context))]
pub async fn authenticate_with_oauth(
data: Json<AuthenticateWithOauth>,
req: HttpRequest,
context: Data<LemmyContext>,
) -> LemmyResult<Json<LoginResponse>> {
let site_view = SiteView::read_local(&mut context.pool()).await?;
let local_site = site_view.local_site.clone();
// validate inputs
if data.oauth_provider_id == OAuthProviderId(0) || data.code.is_empty() || data.code.len() > 300 {
return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
}
// validate the redirect_uri
let redirect_uri = &data.redirect_uri;
if redirect_uri.host_str().unwrap_or("").is_empty()
|| !redirect_uri.path().eq(&String::from("/oauth/callback"))
|| !redirect_uri.query().unwrap_or("").is_empty()
{
Err(LemmyErrorType::OauthAuthorizationInvalid)?
}
// Fetch the OAUTH provider and make sure it's enabled
let oauth_provider_id = data.oauth_provider_id;
let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id)
.await
.ok()
.flatten()
.ok_or(LemmyErrorType::OauthAuthorizationInvalid)?;
if !oauth_provider.enabled {
return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
}
let token_response =
oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str())
.await?;
let user_info = oidc_get_user_info(
&context,
&oauth_provider,
token_response.access_token.as_str(),
)
.await?;
let oauth_user_id = read_user_info(&user_info, oauth_provider.id_claim.as_str())?;
let mut login_response = LoginResponse {
jwt: None,
registration_created: false,
verify_email_sent: false,
};
// Lookup user by oauth_user_id
let mut local_user_view =
LocalUserView::find_by_oauth_id(&mut context.pool(), oauth_provider.id, &oauth_user_id).await?;
let local_user: LocalUser;
if let Some(user_view) = local_user_view {
// user found by oauth_user_id => Login user
local_user = user_view.clone().local_user;
check_user_valid(&user_view.person)?;
check_email_verified(&user_view, &site_view)?;
check_registration_application(&user_view, &site_view.local_site, &mut context.pool()).await?;
} else {
// user has never previously registered using oauth
// prevent registration if registration is closed
if local_site.registration_mode == RegistrationMode::Closed {
Err(LemmyErrorType::RegistrationClosed)?
}
// prevent registration if registration is closed for OAUTH providers
if !local_site.oauth_registration {
return Err(LemmyErrorType::OauthRegistrationClosed)?;
}
// Extract the OAUTH email claim from the returned user_info
let email = read_user_info(&user_info, "email")?;
let require_registration_application =
local_site.registration_mode == RegistrationMode::RequireApplication;
// Lookup user by OAUTH email and link accounts
local_user_view = LocalUserView::find_by_email(&mut context.pool(), &email).await?;
let person;
if let Some(user_view) = local_user_view {
// user found by email => link and login if linking is allowed
// we only allow linking by email when email_verification is required otherwise emails cannot
// be trusted
if oauth_provider.account_linking_enabled && site_view.local_site.require_email_verification {
// WARNING:
// If an admin switches the require_email_verification config from false to true,
// users who signed up before the switch could have accounts with unverified emails falsely
// marked as verified.
check_user_valid(&user_view.person)?;
check_email_verified(&user_view, &site_view)?;
check_registration_application(&user_view, &site_view.local_site, &mut context.pool())
.await?;
// Link with OAUTH => Login user
let oauth_account_form =
OAuthAccountInsertForm::new(user_view.local_user.id, oauth_provider.id, oauth_user_id);
OAuthAccount::create(&mut context.pool(), &oauth_account_form)
.await
.map_err(|_| LemmyErrorType::OauthLoginFailed)?;
local_user = user_view.local_user.clone();
} else {
return Err(LemmyErrorType::EmailAlreadyExists)?;
}
} else {
// No user was found by email => Register as new user
// make sure the registration answer is provided when the registration application is required
validate_registration_answer(require_registration_application, &data.answer)?;
// make sure the username is provided
let username = data
.username
.as_ref()
.ok_or(LemmyErrorType::RegistrationUsernameRequired)?;
let slur_regex = local_site_to_slur_regex(&local_site);
check_slurs(username, &slur_regex)?;
check_slurs_opt(&data.answer, &slur_regex)?;
if Person::is_username_taken(&mut context.pool(), username).await? {
return Err(LemmyErrorType::UsernameAlreadyExists)?;
}
// We have to create a person, a local_user, and an oauth_account
person = create_person(
username.clone(),
&local_site,
site_view.site.instance_id,
&context,
)
.await?;
// Show nsfw content if param is true, or if content_warning exists
let show_nsfw = data
.show_nsfw
.unwrap_or(site_view.site.content_warning.is_some());
let language_tags = get_language_tags(&req);
// Create the local user
let local_user_form = LocalUserInsertForm {
email: Some(str::to_lowercase(&email)),
show_nsfw: Some(show_nsfw),
accepted_application: Some(!require_registration_application),
email_verified: Some(oauth_provider.auto_verify_email),
post_listing_mode: Some(local_site.default_post_listing_mode),
interface_language: language_tags.first().cloned(),
// If its the initial site setup, they are an admin
admin: Some(!local_site.site_setup),
..LocalUserInsertForm::new(person.id, None)
};
local_user = create_local_user(&context, language_tags, &local_user_form).await?;
// Create the oauth account
let oauth_account_form =
OAuthAccountInsertForm::new(local_user.id, oauth_provider.id, oauth_user_id);
OAuthAccount::create(&mut context.pool(), &oauth_account_form)
.await
.map_err(|_| LemmyErrorType::IncorrectLogin)?;
// prevent sign in until application is accepted
if local_site.site_setup
&& require_registration_application
&& !local_user.accepted_application
&& !local_user.admin
{
// Create the registration application
RegistrationApplication::create(
&mut context.pool(),
&RegistrationApplicationInsertForm {
local_user_id: local_user.id,
answer: data.answer.clone().expect("must have an answer"),
},
)
.await?;
login_response.registration_created = true;
}
// Check email is verified when required
login_response.verify_email_sent =
send_verification_email_if_required(&context, &local_site, &local_user, &person).await?;
}
}
if !login_response.registration_created && !login_response.verify_email_sent {
let jwt = Claims::generate(local_user.id, req, &context).await?;
login_response.jwt = Some(jwt);
}
return Ok(Json(login_response));
}
async fn create_person(
username: String,
local_site: &LocalSite,
instance_id: InstanceId,
context: &Data<LemmyContext>,
) -> Result<Person, LemmyError> {
let actor_keypair = generate_actor_keypair()?;
is_valid_actor_name(&username, local_site.actor_name_max_length as usize)?;
let actor_id = generate_local_apub_endpoint(
EndpointType::Person,
&username,
&context.settings().get_protocol_and_hostname(),
)?;
// Register the new person
let person_form = PersonInsertForm {
actor_id: Some(actor_id.clone()),
inbox_url: Some(generate_inbox_url(&actor_id)?),
shared_inbox_url: Some(generate_shared_inbox_url(context.settings())?),
private_key: Some(actor_keypair.private_key),
..PersonInsertForm::new(username.clone(), actor_keypair.public_key, instance_id)
};
// insert the person
let inserted_person = Person::create(&mut context.pool(), &person_form)
.await
.with_lemmy_type(LemmyErrorType::UserAlreadyExists)?;
Ok(inserted_person)
}
fn get_language_tags(req: &HttpRequest) -> Vec<String> {
req
.headers()
.get("Accept-Language")
.map(|hdr| accept_language::parse(hdr.to_str().unwrap_or_default()))
.iter()
.flatten()
// Remove the optional region code
.map(|lang_str| lang_str.split('-').next().unwrap_or_default().to_string())
.collect::<Vec<String>>()
}
async fn create_local_user(
context: &Data<LemmyContext>,
language_tags: Vec<String>,
local_user_form: &LocalUserInsertForm,
) -> Result<LocalUser, LemmyError> {
let all_languages = Language::read_all(&mut context.pool()).await?;
// use hashset to avoid duplicates
let mut language_ids = HashSet::new();
for l in language_tags {
if let Some(found) = all_languages.iter().find(|all| all.code == l) {
language_ids.insert(found.id);
}
}
let language_ids = language_ids.into_iter().collect();
let inserted_local_user =
LocalUser::create(&mut context.pool(), local_user_form, language_ids).await?;
Ok(inserted_local_user)
}
async fn send_verification_email_if_required(
context: &Data<LemmyContext>,
local_site: &LocalSite,
local_user: &LocalUser,
person: &Person,
) -> LemmyResult<bool> {
let mut sent = false;
if !local_user.admin && local_site.require_email_verification && !local_user.email_verified {
let local_user_view = LocalUserView {
local_user: local_user.clone(),
local_user_vote_display_mode: LocalUserVoteDisplayMode::default(),
person: person.clone(),
counts: PersonAggregates::default(),
};
send_verification_email(
&local_user_view,
&local_user
.email
.clone()
.expect("invalid verification email"),
&mut context.pool(),
context.settings(),
)
.await?;
sent = true;
}
Ok(sent)
}
fn validate_registration_answer(
require_registration_application: bool,
answer: &Option<String>,
) -> LemmyResult<()> {
if require_registration_application && answer.is_none() {
Err(LemmyErrorType::RegistrationApplicationAnswerRequired)?
}
Ok(())
}
async fn oauth_request_access_token(
context: &Data<LemmyContext>,
oauth_provider: &OAuthProvider,
code: &str,
redirect_uri: &str,
) -> LemmyResult<TokenResponse> {
// Request an Access Token from the OAUTH provider
let response = context
.client()
.post(oauth_provider.token_endpoint.as_str())
.header("Accept", "application/json")
.form(&[
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", redirect_uri),
("client_id", &oauth_provider.client_id),
("client_secret", &oauth_provider.client_secret),
])
.send()
.await;
let response = response.map_err(|_| LemmyErrorType::OauthLoginFailed)?;
if !response.status().is_success() {
Err(LemmyErrorType::OauthLoginFailed)?;
}
// Extract the access token
let token_response = response
.json::<TokenResponse>()
.await
.map_err(|_| LemmyErrorType::OauthLoginFailed)?;
Ok(token_response)
}
async fn oidc_get_user_info(
context: &Data<LemmyContext>,
oauth_provider: &OAuthProvider,
access_token: &str,
) -> LemmyResult<serde_json::Value> {
// Request the user info from the OAUTH provider
let response = context
.client()
.get(oauth_provider.userinfo_endpoint.as_str())
.header("Accept", "application/json")
.bearer_auth(access_token)
.send()
.await;
let response = response.map_err(|_| LemmyErrorType::OauthLoginFailed)?;
if !response.status().is_success() {
Err(LemmyErrorType::OauthLoginFailed)?;
}
// Extract the OAUTH user_id claim from the returned user_info
let user_info = response
.json::<serde_json::Value>()
.await
.map_err(|_| LemmyErrorType::OauthLoginFailed)?;
Ok(user_info)
}
fn read_user_info(user_info: &serde_json::Value, key: &str) -> LemmyResult<String> {
if let Some(value) = user_info.get(key) {
let result = serde_json::from_value::<String>(value.clone())
.map_err(|_| LemmyErrorType::OauthLoginFailed)?;
return Ok(result);
}
Err(LemmyErrorType::OauthLoginFailed)?
}

View file

@ -8,7 +8,11 @@ use lemmy_api_common::{
utils::purge_user_account, utils::purge_user_account,
SuccessResponse, SuccessResponse,
}; };
use lemmy_db_schema::source::{login_token::LoginToken, person::Person}; use lemmy_db_schema::source::{
login_token::LoginToken,
oauth_account::OAuthAccount,
person::Person,
};
use lemmy_db_views::structs::LocalUserView; use lemmy_db_views::structs::LocalUserView;
use lemmy_utils::error::{LemmyErrorType, LemmyResult}; use lemmy_utils::error::{LemmyErrorType, LemmyResult};
@ -19,11 +23,12 @@ pub async fn delete_account(
local_user_view: LocalUserView, local_user_view: LocalUserView,
) -> LemmyResult<Json<SuccessResponse>> { ) -> LemmyResult<Json<SuccessResponse>> {
// Verify the password // Verify the password
let valid: bool = verify( let valid: bool = local_user_view
&data.password, .local_user
&local_user_view.local_user.password_encrypted, .password_encrypted
) .as_ref()
.unwrap_or(false); .and_then(|password_encrypted| verify(&data.password, password_encrypted).ok())
.unwrap_or(false);
if !valid { if !valid {
Err(LemmyErrorType::IncorrectLogin)? Err(LemmyErrorType::IncorrectLogin)?
} }
@ -31,6 +36,7 @@ pub async fn delete_account(
if data.delete_content { if data.delete_content {
purge_user_account(local_user_view.person.id, &context).await?; purge_user_account(local_user_view.person.id, &context).await?;
} else { } else {
OAuthAccount::delete_user_accounts(&mut context.pool(), local_user_view.local_user.id).await?;
Person::delete_account(&mut context.pool(), local_user_view.person.id).await?; Person::delete_account(&mut context.pool(), local_user_view.person.id).await?;
} }

View file

@ -35,9 +35,11 @@ impl LocalUser {
) -> Result<LocalUser, Error> { ) -> Result<LocalUser, Error> {
let conn = &mut get_conn(pool).await?; let conn = &mut get_conn(pool).await?;
let mut form_with_encrypted_password = form.clone(); let mut form_with_encrypted_password = form.clone();
let password_hash =
hash(&form.password_encrypted, DEFAULT_COST).expect("Couldn't hash password"); if let Some(password_encrypted) = &form.password_encrypted {
form_with_encrypted_password.password_encrypted = password_hash; let password_hash = hash(password_encrypted, DEFAULT_COST).expect("Couldn't hash password");
form_with_encrypted_password.password_encrypted = Some(password_hash);
}
let local_user_ = insert_into(local_user::table) let local_user_ = insert_into(local_user::table)
.values(form_with_encrypted_password) .values(form_with_encrypted_password)
@ -346,7 +348,7 @@ impl LocalUserOptionHelper for Option<&LocalUser> {
impl LocalUserInsertForm { impl LocalUserInsertForm {
pub fn test_form(person_id: PersonId) -> Self { pub fn test_form(person_id: PersonId) -> Self {
Self::new(person_id, String::new()) Self::new(person_id, Some(String::new()))
} }
pub fn test_form_admin(person_id: PersonId) -> Self { pub fn test_form_admin(person_id: PersonId) -> Self {

View file

@ -22,6 +22,8 @@ pub mod local_user;
pub mod local_user_vote_display_mode; pub mod local_user_vote_display_mode;
pub mod login_token; pub mod login_token;
pub mod moderator; pub mod moderator;
pub mod oauth_account;
pub mod oauth_provider;
pub mod password_reset_request; pub mod password_reset_request;
pub mod person; pub mod person;
pub mod person_block; pub mod person_block;

View file

@ -0,0 +1,59 @@
use crate::{
newtypes::{LocalUserId, OAuthProviderId},
schema::{oauth_account, oauth_account::dsl::local_user_id},
source::oauth_account::{OAuthAccount, OAuthAccountInsertForm},
utils::{get_conn, DbPool},
};
use diesel::{
dsl::{exists, insert_into},
result::Error,
select,
ExpressionMethods,
QueryDsl,
};
use diesel_async::RunQueryDsl;
impl OAuthAccount {
pub async fn read(
pool: &mut DbPool<'_>,
for_oauth_provider_id: OAuthProviderId,
for_local_user_id: LocalUserId,
) -> Result<bool, Error> {
let conn = &mut get_conn(pool).await?;
select(exists(
oauth_account::table.find((for_oauth_provider_id, for_local_user_id)),
))
.get_result(conn)
.await
}
pub async fn create(pool: &mut DbPool<'_>, form: &OAuthAccountInsertForm) -> Result<Self, Error> {
let conn = &mut get_conn(pool).await?;
insert_into(oauth_account::table)
.values(form)
.get_result::<Self>(conn)
.await
}
pub async fn delete(
pool: &mut DbPool<'_>,
for_oauth_provider_id: OAuthProviderId,
for_local_user_id: LocalUserId,
) -> Result<usize, Error> {
let conn = &mut get_conn(pool).await?;
diesel::delete(oauth_account::table.find((for_oauth_provider_id, for_local_user_id)))
.execute(conn)
.await
}
pub async fn delete_user_accounts(
pool: &mut DbPool<'_>,
for_local_user_id: LocalUserId,
) -> Result<usize, Error> {
let conn = &mut get_conn(pool).await?;
diesel::delete(oauth_account::table.filter(local_user_id.eq(for_local_user_id)))
.execute(conn)
.await
}
}

View file

@ -0,0 +1,71 @@
use crate::{
newtypes::OAuthProviderId,
schema::oauth_provider,
source::oauth_provider::{
OAuthProvider,
OAuthProviderInsertForm,
OAuthProviderUpdateForm,
PublicOAuthProvider,
},
traits::Crud,
utils::{get_conn, DbPool},
};
use diesel::{dsl::insert_into, result::Error, QueryDsl};
use diesel_async::RunQueryDsl;
#[async_trait]
impl Crud for OAuthProvider {
type InsertForm = OAuthProviderInsertForm;
type UpdateForm = OAuthProviderUpdateForm;
type IdType = OAuthProviderId;
async fn create(pool: &mut DbPool<'_>, form: &Self::InsertForm) -> Result<Self, Error> {
let conn = &mut get_conn(pool).await?;
insert_into(oauth_provider::table)
.values(form)
.get_result::<Self>(conn)
.await
}
async fn update(
pool: &mut DbPool<'_>,
oauth_provider_id: OAuthProviderId,
form: &Self::UpdateForm,
) -> Result<Self, Error> {
let conn = &mut get_conn(pool).await?;
diesel::update(oauth_provider::table.find(oauth_provider_id))
.set(form)
.get_result::<Self>(conn)
.await
}
}
impl OAuthProvider {
pub async fn get_all(pool: &mut DbPool<'_>) -> Result<Vec<Self>, Error> {
let conn = &mut get_conn(pool).await?;
let oauth_providers = oauth_provider::table
.order(oauth_provider::id)
.select(oauth_provider::all_columns)
.load::<OAuthProvider>(conn)
.await?;
Ok(oauth_providers)
}
pub fn convert_providers_to_public(
oauth_providers: Vec<OAuthProvider>,
) -> Vec<PublicOAuthProvider> {
let mut result = Vec::<PublicOAuthProvider>::new();
for oauth_provider in &oauth_providers {
if oauth_provider.enabled {
result.push(PublicOAuthProvider(oauth_provider.clone()));
}
}
result
}
pub async fn get_all_public(pool: &mut DbPool<'_>) -> Result<Vec<PublicOAuthProvider>, Error> {
let oauth_providers = OAuthProvider::get_all(pool).await?;
Ok(Self::convert_providers_to_public(oauth_providers))
}
}

View file

@ -121,6 +121,18 @@ impl Person {
.load::<CommunityId>(conn) .load::<CommunityId>(conn)
.await .await
} }
pub async fn is_username_taken(pool: &mut DbPool<'_>, username: &str) -> Result<bool, Error> {
use diesel::dsl::{exists, select};
let conn = &mut get_conn(pool).await?;
select(exists(
person::table
.filter(lower(person::name).eq(username.to_lowercase()))
.filter(person::local.eq(true)),
))
.get_result(conn)
.await
}
} }
impl PersonInsertForm { impl PersonInsertForm {

View file

@ -154,6 +154,12 @@ pub struct CustomEmojiId(i32);
/// The registration application id. /// The registration application id.
pub struct RegistrationApplicationId(i32); pub struct RegistrationApplicationId(i32);
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Serialize, Deserialize, Default)]
#[cfg_attr(feature = "full", derive(DieselNewType, TS))]
#[cfg_attr(feature = "full", ts(export))]
/// The oauth provider id.
pub struct OAuthProviderId(pub i32);
#[cfg(feature = "full")] #[cfg(feature = "full")]
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
#[serde(remote = "Ltree")] #[serde(remote = "Ltree")]

View file

@ -392,6 +392,7 @@ diesel::table! {
federation_signed_fetch -> Bool, federation_signed_fetch -> Bool,
default_post_listing_mode -> PostListingModeEnum, default_post_listing_mode -> PostListingModeEnum,
default_sort_type -> SortTypeEnum, default_sort_type -> SortTypeEnum,
oauth_registration -> Bool,
} }
} }
@ -435,7 +436,7 @@ diesel::table! {
local_user (id) { local_user (id) {
id -> Int4, id -> Int4,
person_id -> Int4, person_id -> Int4,
password_encrypted -> Text, password_encrypted -> Nullable<Text>,
email -> Nullable<Text>, email -> Nullable<Text>,
show_nsfw -> Bool, show_nsfw -> Bool,
theme -> Text, theme -> Text,
@ -611,6 +612,36 @@ diesel::table! {
} }
} }
diesel::table! {
oauth_account (oauth_provider_id, local_user_id) {
local_user_id -> Int4,
oauth_provider_id -> Int4,
oauth_user_id -> Text,
published -> Timestamptz,
updated -> Nullable<Timestamptz>,
}
}
diesel::table! {
oauth_provider (id) {
id -> Int4,
display_name -> Text,
issuer -> Text,
authorization_endpoint -> Text,
token_endpoint -> Text,
userinfo_endpoint -> Text,
id_claim -> Text,
client_id -> Text,
client_secret -> Text,
scopes -> Text,
auto_verify_email -> Bool,
account_linking_enabled -> Bool,
enabled -> Bool,
published -> Timestamptz,
updated -> Nullable<Timestamptz>,
}
}
diesel::table! { diesel::table! {
password_reset_request (id) { password_reset_request (id) {
id -> Int4, id -> Int4,
@ -1003,6 +1034,8 @@ diesel::joinable!(mod_remove_community -> person (mod_person_id));
diesel::joinable!(mod_remove_post -> person (mod_person_id)); diesel::joinable!(mod_remove_post -> person (mod_person_id));
diesel::joinable!(mod_remove_post -> post (post_id)); diesel::joinable!(mod_remove_post -> post (post_id));
diesel::joinable!(mod_transfer_community -> community (community_id)); diesel::joinable!(mod_transfer_community -> community (community_id));
diesel::joinable!(oauth_account -> local_user (local_user_id));
diesel::joinable!(oauth_account -> oauth_provider (oauth_provider_id));
diesel::joinable!(password_reset_request -> local_user (local_user_id)); diesel::joinable!(password_reset_request -> local_user (local_user_id));
diesel::joinable!(person -> instance (instance_id)); diesel::joinable!(person -> instance (instance_id));
diesel::joinable!(person_aggregates -> person (person_id)); diesel::joinable!(person_aggregates -> person (person_id));
@ -1084,6 +1117,8 @@ diesel::allow_tables_to_appear_in_same_query!(
mod_remove_community, mod_remove_community,
mod_remove_post, mod_remove_post,
mod_transfer_community, mod_transfer_community,
oauth_account,
oauth_provider,
password_reset_request, password_reset_request,
person, person,
person_aggregates, person_aggregates,

View file

@ -68,6 +68,8 @@ pub struct LocalSite {
pub default_post_listing_mode: PostListingMode, pub default_post_listing_mode: PostListingMode,
/// Default value for [LocalUser.post_listing_mode] /// Default value for [LocalUser.post_listing_mode]
pub default_sort_type: SortType, pub default_sort_type: SortType,
/// Whether or not external auth methods can auto-register users.
pub oauth_registration: bool,
} }
#[derive(Clone, TypedBuilder)] #[derive(Clone, TypedBuilder)]
@ -94,6 +96,7 @@ pub struct LocalSiteInsertForm {
pub captcha_enabled: Option<bool>, pub captcha_enabled: Option<bool>,
pub captcha_difficulty: Option<String>, pub captcha_difficulty: Option<String>,
pub registration_mode: Option<RegistrationMode>, pub registration_mode: Option<RegistrationMode>,
pub oauth_registration: Option<bool>,
pub reports_email_admins: Option<bool>, pub reports_email_admins: Option<bool>,
pub federation_signed_fetch: Option<bool>, pub federation_signed_fetch: Option<bool>,
pub default_post_listing_mode: Option<PostListingMode>, pub default_post_listing_mode: Option<PostListingMode>,
@ -121,6 +124,7 @@ pub struct LocalSiteUpdateForm {
pub captcha_enabled: Option<bool>, pub captcha_enabled: Option<bool>,
pub captcha_difficulty: Option<String>, pub captcha_difficulty: Option<String>,
pub registration_mode: Option<RegistrationMode>, pub registration_mode: Option<RegistrationMode>,
pub oauth_registration: Option<bool>,
pub reports_email_admins: Option<bool>, pub reports_email_admins: Option<bool>,
pub updated: Option<Option<DateTime<Utc>>>, pub updated: Option<Option<DateTime<Utc>>>,
pub federation_signed_fetch: Option<bool>, pub federation_signed_fetch: Option<bool>,

View file

@ -24,7 +24,7 @@ pub struct LocalUser {
/// The person_id for the local user. /// The person_id for the local user.
pub person_id: PersonId, pub person_id: PersonId,
#[serde(skip)] #[serde(skip)]
pub password_encrypted: SensitiveString, pub password_encrypted: Option<SensitiveString>,
pub email: Option<SensitiveString>, pub email: Option<SensitiveString>,
/// Whether to show NSFW content. /// Whether to show NSFW content.
pub show_nsfw: bool, pub show_nsfw: bool,
@ -70,7 +70,7 @@ pub struct LocalUser {
#[cfg_attr(feature = "full", diesel(table_name = local_user))] #[cfg_attr(feature = "full", diesel(table_name = local_user))]
pub struct LocalUserInsertForm { pub struct LocalUserInsertForm {
pub person_id: PersonId, pub person_id: PersonId,
pub password_encrypted: String, pub password_encrypted: Option<String>,
#[new(default)] #[new(default)]
pub email: Option<String>, pub email: Option<String>,
#[new(default)] #[new(default)]

View file

@ -27,6 +27,8 @@ pub mod local_user;
pub mod local_user_vote_display_mode; pub mod local_user_vote_display_mode;
pub mod login_token; pub mod login_token;
pub mod moderator; pub mod moderator;
pub mod oauth_account;
pub mod oauth_provider;
pub mod password_reset_request; pub mod password_reset_request;
pub mod person; pub mod person;
pub mod person_block; pub mod person_block;

View file

@ -0,0 +1,32 @@
use crate::newtypes::{LocalUserId, OAuthProviderId};
#[cfg(feature = "full")]
use crate::schema::oauth_account;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
#[cfg(feature = "full")]
use ts_rs::TS;
#[skip_serializing_none]
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "full", derive(Queryable, Selectable, TS))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_account))]
#[cfg_attr(feature = "full", diesel(check_for_backend(diesel::pg::Pg)))]
#[cfg_attr(feature = "full", ts(export))]
/// An auth account method.
pub struct OAuthAccount {
pub local_user_id: LocalUserId,
pub oauth_provider_id: OAuthProviderId,
pub oauth_user_id: String,
pub published: DateTime<Utc>,
pub updated: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, derive_new::new)]
#[cfg_attr(feature = "full", derive(Insertable, AsChangeset))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_account))]
pub struct OAuthAccountInsertForm {
pub local_user_id: LocalUserId,
pub oauth_provider_id: OAuthProviderId,
pub oauth_user_id: String,
}

View file

@ -0,0 +1,131 @@
#[cfg(feature = "full")]
use crate::schema::oauth_provider;
use crate::{
newtypes::{DbUrl, OAuthProviderId},
sensitive::SensitiveString,
};
use chrono::{DateTime, Utc};
use serde::{
ser::{SerializeStruct, Serializer},
Deserialize,
Serialize,
};
use serde_with::skip_serializing_none;
#[cfg(feature = "full")]
use ts_rs::TS;
#[skip_serializing_none]
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "full", derive(Queryable, Selectable, Identifiable, TS))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_provider))]
#[cfg_attr(feature = "full", diesel(check_for_backend(diesel::pg::Pg)))]
#[cfg_attr(feature = "full", ts(export))]
/// oauth provider with client_secret - should never be sent to the client
pub struct OAuthProvider {
pub id: OAuthProviderId,
/// The OAuth 2.0 provider name displayed to the user on the Login page
pub display_name: String,
/// The issuer url of the OAUTH provider.
#[cfg_attr(feature = "full", ts(type = "string"))]
pub issuer: DbUrl,
/// The authorization endpoint is used to interact with the resource owner and obtain an
/// authorization grant. This is usually provided by the OAUTH provider.
#[cfg_attr(feature = "full", ts(type = "string"))]
pub authorization_endpoint: DbUrl,
/// The token endpoint is used by the client to obtain an access token by presenting its
/// authorization grant or refresh token. This is usually provided by the OAUTH provider.
#[cfg_attr(feature = "full", ts(type = "string"))]
pub token_endpoint: DbUrl,
/// The UserInfo Endpoint is an OAuth 2.0 Protected Resource that returns Claims about the
/// authenticated End-User. This is defined in the OIDC specification.
#[cfg_attr(feature = "full", ts(type = "string"))]
pub userinfo_endpoint: DbUrl,
/// The OAuth 2.0 claim containing the unique user ID returned by the provider. Usually this
/// should be set to "sub".
pub id_claim: String,
/// The client_id is provided by the OAuth 2.0 provider and is a unique identifier to this
/// service
pub client_id: String,
/// The client_secret is provided by the OAuth 2.0 provider and is used to authenticate this
/// service with the provider
#[serde(skip)]
pub client_secret: SensitiveString,
/// Lists the scopes requested from users. Users will have to grant access to the requested scope
/// at sign up.
pub scopes: String,
/// Automatically sets email as verified on registration
pub auto_verify_email: bool,
/// Allows linking an OAUTH account to an existing user account by matching emails
pub account_linking_enabled: bool,
/// switch to enable or disable an oauth provider
pub enabled: bool,
pub published: DateTime<Utc>,
pub updated: Option<DateTime<Utc>>,
}
#[derive(Clone, PartialEq, Eq, Debug, Deserialize)]
#[serde(transparent)]
#[cfg_attr(feature = "full", derive(TS))]
#[cfg_attr(feature = "full", ts(export))]
// A subset of OAuthProvider used for public requests, for example to display the OAUTH buttons on
// the login page
pub struct PublicOAuthProvider(pub OAuthProvider);
impl Serialize for PublicOAuthProvider {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_struct("PublicOAuthProvider", 5)?;
state.serialize_field("id", &self.0.id)?;
state.serialize_field("display_name", &self.0.display_name)?;
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
state.serialize_field("client_id", &self.0.client_id)?;
state.serialize_field("scopes", &self.0.scopes)?;
state.end()
}
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "full", derive(Insertable, AsChangeset, TS))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_provider))]
#[cfg_attr(feature = "full", ts(export))]
pub struct OAuthProviderInsertForm {
pub display_name: String,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub issuer: DbUrl,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub authorization_endpoint: DbUrl,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub token_endpoint: DbUrl,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub userinfo_endpoint: DbUrl,
pub id_claim: String,
pub client_id: String,
pub client_secret: String,
pub scopes: String,
pub auto_verify_email: bool,
pub account_linking_enabled: bool,
pub enabled: bool,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "full", derive(Insertable, AsChangeset, TS))]
#[cfg_attr(feature = "full", diesel(table_name = oauth_provider))]
#[cfg_attr(feature = "full", ts(export))]
pub struct OAuthProviderUpdateForm {
pub display_name: Option<String>,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub authorization_endpoint: Option<DbUrl>,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub token_endpoint: Option<DbUrl>,
#[cfg_attr(feature = "full", ts(type = "string"))]
pub userinfo_endpoint: Option<DbUrl>,
pub id_claim: Option<String>,
pub client_secret: Option<String>,
pub scopes: Option<String>,
pub auto_verify_email: Option<bool>,
pub account_linking_enabled: Option<bool>,
pub enabled: Option<bool>,
pub updated: Option<Option<DateTime<Utc>>>,
}

View file

@ -288,7 +288,7 @@ pub fn is_email_regex(test: &str) -> bool {
EMAIL_REGEX.is_match(test) EMAIL_REGEX.is_match(test)
} }
/// Takes an API text input, and converts it to an optional diesel DB update. /// Takes an API optional text input, and converts it to an optional diesel DB update.
pub fn diesel_string_update(opt: Option<&str>) -> Option<Option<String>> { pub fn diesel_string_update(opt: Option<&str>) -> Option<Option<String>> {
match opt { match opt {
// An empty string is an erase // An empty string is an erase
@ -298,6 +298,17 @@ pub fn diesel_string_update(opt: Option<&str>) -> Option<Option<String>> {
} }
} }
/// Takes an API optional text input, and converts it to an optional diesel DB update (for non
/// nullable properties).
pub fn diesel_required_string_update(opt: Option<&str>) -> Option<String> {
match opt {
// An empty string is no change
Some("") => None,
Some(str) => Some(str.into()),
None => None,
}
}
/// Takes an optional API URL-type input, and converts it to an optional diesel DB update. /// Takes an optional API URL-type input, and converts it to an optional diesel DB update.
/// Also cleans the url params. /// Also cleans the url params.
pub fn diesel_url_update(opt: Option<&str>) -> LemmyResult<Option<Option<DbUrl>>> { pub fn diesel_url_update(opt: Option<&str>) -> LemmyResult<Option<Option<DbUrl>>> {
@ -311,6 +322,19 @@ pub fn diesel_url_update(opt: Option<&str>) -> LemmyResult<Option<Option<DbUrl>>
} }
} }
/// Takes an optional API URL-type input, and converts it to an optional diesel DB update (for non
/// nullable properties). Also cleans the url params.
pub fn diesel_required_url_update(opt: Option<&str>) -> LemmyResult<Option<DbUrl>> {
match opt {
// An empty string is no change
Some("") => Ok(None),
Some(str_url) => Url::parse(str_url)
.map(|u| Some(clean_url(&u).into()))
.with_lemmy_type(LemmyErrorType::InvalidUrl),
None => Ok(None),
}
}
/// Takes an optional API URL-type input, and converts it to an optional diesel DB create. /// Takes an optional API URL-type input, and converts it to an optional diesel DB create.
/// Also cleans the url params. /// Also cleans the url params.
pub fn diesel_url_create(opt: Option<&str>) -> LemmyResult<Option<DbUrl>> { pub fn diesel_url_create(opt: Option<&str>) -> LemmyResult<Option<DbUrl>> {

View file

@ -3,8 +3,8 @@ use actix_web::{dev::Payload, FromRequest, HttpMessage, HttpRequest};
use diesel::{result::Error, BoolExpressionMethods, ExpressionMethods, JoinOnDsl, QueryDsl}; use diesel::{result::Error, BoolExpressionMethods, ExpressionMethods, JoinOnDsl, QueryDsl};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use lemmy_db_schema::{ use lemmy_db_schema::{
newtypes::{LocalUserId, PersonId}, newtypes::{LocalUserId, OAuthProviderId, PersonId},
schema::{local_user, local_user_vote_display_mode, person, person_aggregates}, schema::{local_user, local_user_vote_display_mode, oauth_account, person, person_aggregates},
utils::{ utils::{
functions::{coalesce, lower}, functions::{coalesce, lower},
DbConn, DbConn,
@ -23,6 +23,7 @@ enum ReadBy<'a> {
Name(&'a str), Name(&'a str),
NameOrEmail(&'a str), NameOrEmail(&'a str),
Email(&'a str), Email(&'a str),
OAuthID(OAuthProviderId, &'a str),
} }
enum ListMode { enum ListMode {
@ -58,12 +59,21 @@ fn queries<'a>(
), ),
_ => query, _ => query,
}; };
query let query = query
.inner_join(local_user_vote_display_mode::table) .inner_join(local_user_vote_display_mode::table)
.inner_join(person_aggregates::table.on(person::id.eq(person_aggregates::person_id))) .inner_join(person_aggregates::table.on(person::id.eq(person_aggregates::person_id)));
.select(selection)
.first(&mut conn) if let ReadBy::OAuthID(oauth_provider_id, oauth_user_id) = search {
.await query
.inner_join(oauth_account::table)
.filter(oauth_account::oauth_provider_id.eq(oauth_provider_id))
.filter(oauth_account::oauth_user_id.eq(oauth_user_id))
.select(selection)
.first(&mut conn)
.await
} else {
query.select(selection).first(&mut conn).await
}
}; };
let list = move |mut conn: DbConn<'a>, mode: ListMode| async move { let list = move |mut conn: DbConn<'a>, mode: ListMode| async move {
@ -120,6 +130,16 @@ impl LocalUserView {
queries().read(pool, ReadBy::Email(from_email)).await queries().read(pool, ReadBy::Email(from_email)).await
} }
pub async fn find_by_oauth_id(
pool: &mut DbPool<'_>,
oauth_provider_id: OAuthProviderId,
oauth_user_id: &str,
) -> Result<Option<Self>, Error> {
queries()
.read(pool, ReadBy::OAuthID(oauth_provider_id, oauth_user_id))
.await
}
pub async fn list_admins_with_emails(pool: &mut DbPool<'_>) -> Result<Vec<Self>, Error> { pub async fn list_admins_with_emails(pool: &mut DbPool<'_>) -> Result<Vec<Self>, Error> {
queries().list(pool, ListMode::AdminsWithEmails).await queries().list(pool, ListMode::AdminsWithEmails).await
} }

View file

@ -5,8 +5,7 @@ use actix_web::{
Method, Method,
StatusCode, StatusCode,
}, },
web, web::{self, Query},
web::Query,
HttpRequest, HttpRequest,
HttpResponse, HttpResponse,
}; };

View file

@ -55,6 +55,7 @@ pub enum LemmyErrorType {
CouldntFindCommentReply, CouldntFindCommentReply,
CouldntFindPrivateMessage, CouldntFindPrivateMessage,
CouldntFindActivity, CouldntFindActivity,
CouldntFindOauthProvider,
PersonIsBlocked, PersonIsBlocked,
CommunityIsBlocked, CommunityIsBlocked,
InstanceIsBlocked, InstanceIsBlocked,
@ -83,7 +84,9 @@ pub enum LemmyErrorType {
InvalidDefaultPostListingType, InvalidDefaultPostListingType,
RegistrationClosed, RegistrationClosed,
RegistrationApplicationAnswerRequired, RegistrationApplicationAnswerRequired,
RegistrationUsernameRequired,
EmailAlreadyExists, EmailAlreadyExists,
UsernameAlreadyExists,
FederationForbiddenByStrictAllowList, FederationForbiddenByStrictAllowList,
PersonIsBannedFromCommunity, PersonIsBannedFromCommunity,
ObjectIsNotPublic, ObjectIsNotPublic,
@ -178,6 +181,10 @@ pub enum LemmyErrorType {
CantBlockLocalInstance, CantBlockLocalInstance,
UrlWithoutDomain, UrlWithoutDomain,
InboxTimeout, InboxTimeout,
OauthAuthorizationInvalid,
OauthLoginFailed,
OauthRegistrationClosed,
CouldntDeleteOauthProvider,
Unknown(String), Unknown(String),
CantDeleteSite, CantDeleteSite,
UrlLengthOverflow, UrlLengthOverflow,

View file

@ -0,0 +1,10 @@
DROP TABLE oauth_account;
DROP TABLE oauth_provider;
ALTER TABLE local_site
DROP COLUMN oauth_registration;
ALTER TABLE local_user
ALTER COLUMN password_encrypted SET NOT NULL;

View file

@ -0,0 +1,34 @@
ALTER TABLE local_user
ALTER COLUMN password_encrypted DROP NOT NULL;
CREATE TABLE oauth_provider (
id serial PRIMARY KEY,
display_name text NOT NULL,
issuer text NOT NULL,
authorization_endpoint text NOT NULL,
token_endpoint text NOT NULL,
userinfo_endpoint text NOT NULL,
id_claim text NOT NULL,
client_id text NOT NULL UNIQUE,
client_secret text NOT NULL,
scopes text NOT NULL,
auto_verify_email boolean DEFAULT TRUE NOT NULL,
account_linking_enabled boolean DEFAULT FALSE NOT NULL,
enabled boolean DEFAULT FALSE NOT NULL,
published timestamp with time zone DEFAULT now() NOT NULL,
updated timestamp with time zone
);
ALTER TABLE local_site
ADD COLUMN oauth_registration boolean DEFAULT FALSE NOT NULL;
CREATE TABLE oauth_account (
local_user_id int REFERENCES local_user ON UPDATE CASCADE ON DELETE CASCADE NOT NULL,
oauth_provider_id int REFERENCES oauth_provider ON UPDATE CASCADE ON DELETE RESTRICT NOT NULL,
oauth_user_id text NOT NULL,
published timestamp with time zone DEFAULT now() NOT NULL,
updated timestamp with time zone,
UNIQUE (oauth_provider_id, oauth_user_id),
PRIMARY KEY (oauth_provider_id, local_user_id)
);

View file

@ -109,6 +109,11 @@ use lemmy_api_crud::{
delete::delete_custom_emoji, delete::delete_custom_emoji,
update::update_custom_emoji, update::update_custom_emoji,
}, },
oauth_provider::{
create::create_oauth_provider,
delete::delete_oauth_provider,
update::update_oauth_provider,
},
post::{ post::{
create::create_post, create::create_post,
delete::delete_post, delete::delete_post,
@ -123,7 +128,10 @@ use lemmy_api_crud::{
update::update_private_message, update::update_private_message,
}, },
site::{create::create_site, read::get_site, update::update_site}, site::{create::create_site, read::get_site, update::update_site},
user::{create::register, delete::delete_account}, user::{
create::{authenticate_with_oauth, register},
delete::delete_account,
},
}; };
use lemmy_apub::api::{ use lemmy_apub::api::{
list_comments::list_comments, list_comments::list_comments,
@ -381,6 +389,18 @@ pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimitCell) {
.route("", web::post().to(create_custom_emoji)) .route("", web::post().to(create_custom_emoji))
.route("", web::put().to(update_custom_emoji)) .route("", web::put().to(update_custom_emoji))
.route("/delete", web::post().to(delete_custom_emoji)), .route("/delete", web::post().to(delete_custom_emoji)),
)
.service(
web::scope("/oauth_provider")
.wrap(rate_limit.message())
.route("", web::post().to(create_oauth_provider))
.route("", web::put().to(update_oauth_provider))
.route("/delete", web::post().to(delete_oauth_provider)),
)
.service(
web::scope("/oauth")
.wrap(rate_limit.register())
.route("/authenticate", web::post().to(authenticate_with_oauth)),
), ),
); );
cfg.service( cfg.service(

View file

@ -471,7 +471,7 @@ async fn initialize_local_site_2022_10_10(
let local_user_form = LocalUserInsertForm { let local_user_form = LocalUserInsertForm {
email: setup.admin_email.clone(), email: setup.admin_email.clone(),
admin: Some(true), admin: Some(true),
..LocalUserInsertForm::new(person_inserted.id, setup.admin_password.clone()) ..LocalUserInsertForm::new(person_inserted.id, Some(setup.admin_password.clone()))
}; };
LocalUser::create(pool, &local_user_form, vec![]).await?; LocalUser::create(pool, &local_user_form, vec![]).await?;
}; };