feat: Add try_from attribute for FromRow (#1081)

This commit is contained in:
zz 2022-09-07 12:04:11 +08:00 committed by GitHub
parent 18a76fbdbf
commit ddffaa7dde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 153 additions and 15 deletions

View file

@ -149,6 +149,32 @@ use crate::row::Row;
/// }
/// }
/// ```
///
/// #### `try_from`
///
/// When your struct contains a field whose type is not matched with the database type,
/// if the field type has an implementation [`TryFrom`] for the database type,
/// you can use the `try_from` attribute to convert the database type to the field type.
/// For example:
///
/// ```rust,ignore
/// #[derive(sqlx::FromRow)]
/// struct User {
/// id: i32,
/// name: String,
/// #[sqlx(try_from = "i64")]
/// bigIntInMySql: u64
/// }
/// ```
///
/// Given a query such as:
///
/// ```sql
/// SELECT id, name, bigIntInMySql FROM users;
/// ```
///
/// In MySql, `BigInt` type matches `i64`, but you can convert it to `u64` by `try_from`.
///
pub trait FromRow<'r, R: Row>: Sized {
fn from_row(row: &'r R) -> Result<Self, Error>;
}

View file

@ -71,6 +71,7 @@ pub struct SqlxChildAttributes {
pub rename: Option<String>,
pub default: bool,
pub flatten: bool,
pub try_from: Option<Ident>,
}
pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
@ -178,6 +179,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContai
pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttributes> {
let mut rename = None;
let mut default = false;
let mut try_from = None;
let mut flatten = false;
for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) {
@ -194,6 +196,11 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
lit: Lit::Str(val),
..
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(val),
..
}) if path.is_ident("try_from") => try_set!(try_from, val.parse()?, value),
Meta::Path(path) if path.is_ident("default") => default = true,
Meta::Path(path) if path.is_ident("flatten") => flatten = true,
u => fail!(u, "unexpected attribute"),
@ -208,6 +215,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
rename,
default,
flatten,
try_from,
})
}

View file

@ -72,22 +72,45 @@ fn expand_derive_from_row_struct(
let attributes = parse_child_attributes(&field.attrs).unwrap();
let ty = &field.ty;
let expr: Expr = if attributes.flatten {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
} else {
predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));
let expr: Expr = match (attributes.flatten, attributes.try_from) {
(true, None) => {
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
}
(false, None) => {
predicates
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));
let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s))
}
(true,Some(try_from)) => {
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
(false,Some(try_from)) => {
predicates
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));
let id_s = attributes
.rename
.or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned()))
.map(|s| match container_attributes.rename_all {
Some(pattern) => rename_all(&s, pattern),
None => s,
})
.unwrap();
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
}
};
if attributes.default {

View file

@ -354,4 +354,85 @@ async fn test_column_override_exact_enum() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn test_try_from_attr_for_native_type() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "i64")]
id: u64,
}
let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;
let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, id.0 as u64);
Ok(())
}
#[sqlx_macros::test]
async fn test_try_from_attr_for_custom_type() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "i64")]
id: Id,
}
#[derive(Debug, PartialEq)]
struct Id(i64);
impl std::convert::TryFrom<i64> for Id {
type Error = std::io::Error;
fn try_from(value: i64) -> Result<Self, Self::Error> {
Ok(Id(value))
}
}
let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;
let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Id(id.0));
Ok(())
}
#[sqlx_macros::test]
async fn test_try_from_attr_with_flatten() -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Record {
#[sqlx(try_from = "Id", flatten)]
id: u64,
}
#[derive(Debug, PartialEq, sqlx::FromRow)]
struct Id {
id: i64,
}
impl std::convert::TryFrom<Id> for u64 {
type Error = std::io::Error;
fn try_from(value: Id) -> Result<Self, Self::Error> {
Ok(value.id as u64)
}
}
let mut conn = new::<MySql>().await?;
let (mut conn, id) = with_test_row(&mut conn).await?;
let record = sqlx::query_as::<_, Record>("select id from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, id.0 as u64);
Ok(())
}
// we don't emit bind parameter type-checks for MySQL so testing the overrides is redundant