mirror of
https://github.com/nushell/nushell
synced 2024-12-27 13:33:16 +00:00
Series arithmetic (#3602)
* operations with series * contains operations with series * Checked division and masked operations
This commit is contained in:
parent
1d7c909080
commit
c4163c3621
13 changed files with 812 additions and 90 deletions
25
Cargo.lock
generated
25
Cargo.lock
generated
|
@ -207,7 +207,7 @@ checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "arrow"
|
name = "arrow"
|
||||||
version = "5.0.0-SNAPSHOT"
|
version = "5.0.0-SNAPSHOT"
|
||||||
source = "git+https://github.com/apache/arrow-rs?rev=f26ffb3091ae355d246edc4a6fcc2c8e5b9bc570#f26ffb3091ae355d246edc4a6fcc2c8e5b9bc570"
|
source = "git+https://github.com/apache/arrow-rs?rev=0f55b828883b3b3afda43ae404b130d374e6f1a1#0f55b828883b3b3afda43ae404b130d374e6f1a1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"chrono",
|
"chrono",
|
||||||
"csv",
|
"csv",
|
||||||
|
@ -3504,6 +3504,7 @@ dependencies = [
|
||||||
"num-bigint 0.3.2",
|
"num-bigint 0.3.2",
|
||||||
"num-format",
|
"num-format",
|
||||||
"num-traits 0.2.14",
|
"num-traits 0.2.14",
|
||||||
|
"polars",
|
||||||
"query_interface",
|
"query_interface",
|
||||||
"serde 1.0.126",
|
"serde 1.0.126",
|
||||||
"sha2 0.9.5",
|
"sha2 0.9.5",
|
||||||
|
@ -4362,7 +4363,7 @@ dependencies = [
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parquet"
|
name = "parquet"
|
||||||
version = "5.0.0-SNAPSHOT"
|
version = "5.0.0-SNAPSHOT"
|
||||||
source = "git+https://github.com/apache/arrow-rs?rev=f26ffb3091ae355d246edc4a6fcc2c8e5b9bc570#f26ffb3091ae355d246edc4a6fcc2c8e5b9bc570"
|
source = "git+https://github.com/apache/arrow-rs?rev=0f55b828883b3b3afda43ae404b130d374e6f1a1#0f55b828883b3b3afda43ae404b130d374e6f1a1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"base64 0.13.0",
|
"base64 0.13.0",
|
||||||
|
@ -4600,8 +4601,8 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "polars"
|
name = "polars"
|
||||||
version = "0.14.0"
|
version = "0.14.1"
|
||||||
source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd"
|
source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"polars-core",
|
"polars-core",
|
||||||
"polars-io",
|
"polars-io",
|
||||||
|
@ -4610,8 +4611,8 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "polars-arrow"
|
name = "polars-arrow"
|
||||||
version = "0.14.0"
|
version = "0.14.1"
|
||||||
source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd"
|
source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arrow",
|
"arrow",
|
||||||
"num 0.4.0",
|
"num 0.4.0",
|
||||||
|
@ -4620,8 +4621,8 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "polars-core"
|
name = "polars-core"
|
||||||
version = "0.14.0"
|
version = "0.14.1"
|
||||||
source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd"
|
source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ahash",
|
"ahash",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
@ -4646,8 +4647,8 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "polars-io"
|
name = "polars-io"
|
||||||
version = "0.14.0"
|
version = "0.14.1"
|
||||||
source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd"
|
source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ahash",
|
"ahash",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
@ -4669,8 +4670,8 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "polars-lazy"
|
name = "polars-lazy"
|
||||||
version = "0.14.0"
|
version = "0.14.1"
|
||||||
source = "git+https://github.com/pola-rs/polars?rev=a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd#a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd"
|
source = "git+https://github.com/pola-rs/polars?rev=9e1506cca9fb646fc55f949ab6345290c3d198a7#9e1506cca9fb646fc55f949ab6345290c3d198a7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ahash",
|
"ahash",
|
||||||
"itertools",
|
"itertools",
|
||||||
|
|
|
@ -156,6 +156,7 @@ table-pager = ["nu-command/table-pager"]
|
||||||
|
|
||||||
#dataframe feature for nushell
|
#dataframe feature for nushell
|
||||||
dataframe = [
|
dataframe = [
|
||||||
|
"nu-engine/dataframe",
|
||||||
"nu-protocol/dataframe",
|
"nu-protocol/dataframe",
|
||||||
"nu-command/dataframe",
|
"nu-command/dataframe",
|
||||||
"nu-value-ext/dataframe",
|
"nu-value-ext/dataframe",
|
||||||
|
|
|
@ -101,10 +101,10 @@ zip = { version = "0.5.9", optional = true }
|
||||||
|
|
||||||
[dependencies.polars]
|
[dependencies.polars]
|
||||||
git = "https://github.com/pola-rs/polars"
|
git = "https://github.com/pola-rs/polars"
|
||||||
rev = "a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd"
|
rev = "9e1506cca9fb646fc55f949ab6345290c3d198a7"
|
||||||
version = "0.14.0"
|
version = "0.14.1"
|
||||||
optional = true
|
optional = true
|
||||||
features = ["parquet", "json", "random", "pivot"]
|
features = ["parquet", "json", "random", "pivot", "strings"]
|
||||||
|
|
||||||
[target.'cfg(unix)'.dependencies]
|
[target.'cfg(unix)'.dependencies]
|
||||||
umask = "1.0.0"
|
umask = "1.0.0"
|
||||||
|
|
|
@ -12,7 +12,7 @@ pub struct DataFrame;
|
||||||
|
|
||||||
impl WholeStreamCommand for DataFrame {
|
impl WholeStreamCommand for DataFrame {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
"pls groupby"
|
"pls group-by"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn usage(&self) -> &str {
|
fn usage(&self) -> &str {
|
||||||
|
@ -20,7 +20,7 @@ impl WholeStreamCommand for DataFrame {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn signature(&self) -> Signature {
|
fn signature(&self) -> Signature {
|
||||||
Signature::build("pls groupby").required(
|
Signature::build("pls group-by").required(
|
||||||
"by columns",
|
"by columns",
|
||||||
SyntaxShape::Table,
|
SyntaxShape::Table,
|
||||||
"groupby columns",
|
"groupby columns",
|
||||||
|
@ -34,7 +34,7 @@ impl WholeStreamCommand for DataFrame {
|
||||||
fn examples(&self) -> Vec<Example> {
|
fn examples(&self) -> Vec<Example> {
|
||||||
vec![Example {
|
vec![Example {
|
||||||
description: "Grouping by column a",
|
description: "Grouping by column a",
|
||||||
example: "[[a b]; [one 1] [one 2]] | pls to-df | pls groupby [a]",
|
example: "[[a b]; [one 1] [one 2]] | pls to-df | pls group-by [a]",
|
||||||
result: None,
|
result: None,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,11 +4,11 @@ use nu_errors::ShellError;
|
||||||
use nu_protocol::{
|
use nu_protocol::{
|
||||||
dataframe::NuDataFrame,
|
dataframe::NuDataFrame,
|
||||||
hir::{CapturedBlock, ClassifiedCommand, Expression, Literal, Operator, SpannedExpression},
|
hir::{CapturedBlock, ClassifiedCommand, Expression, Literal, Operator, SpannedExpression},
|
||||||
Primitive, Signature, SyntaxShape, UnspannedPathMember, UntaggedValue,
|
Primitive, Signature, SyntaxShape, UnspannedPathMember, UntaggedValue, Value,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::utils::parse_polars_error;
|
use super::utils::parse_polars_error;
|
||||||
use polars::prelude::{ChunkCompare, Series};
|
use polars::prelude::{ChunkCompare, DataType, Series};
|
||||||
|
|
||||||
pub struct DataFrame;
|
pub struct DataFrame;
|
||||||
|
|
||||||
|
@ -91,22 +91,8 @@ fn command(args: CommandArgs) -> Result<OutputStream, ShellError> {
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
let rhs = evaluate_baseline_expr(&expression.right, &args.args.context)?;
|
let rhs = evaluate_baseline_expr(&expression.right, &args.args.context)?;
|
||||||
let right_condition = match &rhs.value {
|
|
||||||
UntaggedValue::Primitive(primitive) => Ok(primitive),
|
|
||||||
_ => Err(ShellError::labeled_error(
|
|
||||||
"Incorrect argument",
|
|
||||||
"Expected primitive values",
|
|
||||||
&rhs.tag.span,
|
|
||||||
)),
|
|
||||||
}?;
|
|
||||||
|
|
||||||
filter_dataframe(
|
filter_dataframe(args, &col_name, &col_name_span, &rhs, &expression.op)
|
||||||
args,
|
|
||||||
&col_name,
|
|
||||||
&col_name_span,
|
|
||||||
&right_condition,
|
|
||||||
&expression.op,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! comparison_arm {
|
macro_rules! comparison_arm {
|
||||||
|
@ -145,16 +131,25 @@ fn filter_dataframe(
|
||||||
mut args: EvaluatedCommandArgs,
|
mut args: EvaluatedCommandArgs,
|
||||||
col_name: &str,
|
col_name: &str,
|
||||||
col_name_span: &Span,
|
col_name_span: &Span,
|
||||||
right_condition: &Primitive,
|
rhs: &Value,
|
||||||
operator: &SpannedExpression,
|
operator: &SpannedExpression,
|
||||||
) -> Result<OutputStream, ShellError> {
|
) -> Result<OutputStream, ShellError> {
|
||||||
|
let right_condition = match &rhs.value {
|
||||||
|
UntaggedValue::Primitive(primitive) => Ok(primitive),
|
||||||
|
_ => Err(ShellError::labeled_error(
|
||||||
|
"Incorrect argument",
|
||||||
|
"Expected primitive values",
|
||||||
|
&rhs.tag.span,
|
||||||
|
)),
|
||||||
|
}?;
|
||||||
|
|
||||||
let span = args.call_info.name_tag.span;
|
let span = args.call_info.name_tag.span;
|
||||||
let df = NuDataFrame::try_from_stream(&mut args.input, &span)?;
|
let df = NuDataFrame::try_from_stream(&mut args.input, &span)?;
|
||||||
|
|
||||||
let col = df
|
let col = df
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.column(col_name)
|
.column(col_name)
|
||||||
.map_err(|e| parse_polars_error::<&str>(&e, &col_name_span, None))?;
|
.map_err(|e| parse_polars_error::<&str>(&e, col_name_span, None))?;
|
||||||
|
|
||||||
let op = match &operator.expr {
|
let op = match &operator.expr {
|
||||||
Expression::Literal(Literal::Operator(op)) => Ok(op),
|
Expression::Literal(Literal::Operator(op)) => Ok(op),
|
||||||
|
@ -176,6 +171,33 @@ fn filter_dataframe(
|
||||||
Operator::GreaterThanOrEqual => {
|
Operator::GreaterThanOrEqual => {
|
||||||
comparison_arm!(Series::gt_eq, col, right_condition, operator.span)
|
comparison_arm!(Series::gt_eq, col, right_condition, operator.span)
|
||||||
}
|
}
|
||||||
|
Operator::Contains => match col.dtype() {
|
||||||
|
DataType::Utf8 => match right_condition {
|
||||||
|
Primitive::String(pat) => {
|
||||||
|
let casted = col.utf8().map_err(|e| {
|
||||||
|
parse_polars_error::<&str>(&e, &args.call_info.name_tag.span, None)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
casted.contains(pat).map_err(|e| {
|
||||||
|
parse_polars_error::<&str>(&e, &args.call_info.name_tag.span, None)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => Err(ShellError::labeled_error_with_secondary(
|
||||||
|
"Incorrect argument",
|
||||||
|
"Can't perform contains with this value",
|
||||||
|
&rhs.tag.span,
|
||||||
|
"Contains only works with strings",
|
||||||
|
&rhs.tag.span,
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
_ => Err(ShellError::labeled_error_with_secondary(
|
||||||
|
"Incorrect datatype",
|
||||||
|
format!("The selected column is of type '{}'", col.dtype()),
|
||||||
|
col_name_span,
|
||||||
|
"Perhaps you want to select a column of 'str' type",
|
||||||
|
col_name_span,
|
||||||
|
)),
|
||||||
|
},
|
||||||
_ => Err(ShellError::labeled_error(
|
_ => Err(ShellError::labeled_error(
|
||||||
"Incorrect operator",
|
"Incorrect operator",
|
||||||
"Not implemented operator for dataframes filter",
|
"Not implemented operator for dataframes filter",
|
||||||
|
|
|
@ -37,10 +37,17 @@ nu-test-support = { version = "0.32.1", path = "../nu-test-support" }
|
||||||
nu-value-ext = { version = "0.32.1", path = "../nu-value-ext" }
|
nu-value-ext = { version = "0.32.1", path = "../nu-value-ext" }
|
||||||
nu-ansi-term = { version = "0.32.1", path = "../nu-ansi-term" }
|
nu-ansi-term = { version = "0.32.1", path = "../nu-ansi-term" }
|
||||||
|
|
||||||
|
[dependencies.polars]
|
||||||
|
git = "https://github.com/pola-rs/polars"
|
||||||
|
rev = "9e1506cca9fb646fc55f949ab6345290c3d198a7"
|
||||||
|
version = "0.14.1"
|
||||||
|
optional = true
|
||||||
|
features = ["strings", "checked_arithmetic"]
|
||||||
|
|
||||||
[target.'cfg(unix)'.dependencies]
|
[target.'cfg(unix)'.dependencies]
|
||||||
users = "0.11.0"
|
users = "0.11.0"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
directories = ["directories-next"]
|
directories = ["directories-next"]
|
||||||
dirs = ["dirs-next"]
|
dirs = ["dirs-next"]
|
||||||
dataframe = ["nu-protocol/dataframe"]
|
dataframe = ["nu-protocol/dataframe", "polars"]
|
||||||
|
|
717
crates/nu-data/src/dataframe.rs
Normal file
717
crates/nu-data/src/dataframe.rs
Normal file
|
@ -0,0 +1,717 @@
|
||||||
|
use bigdecimal::BigDecimal;
|
||||||
|
use nu_errors::ShellError;
|
||||||
|
use nu_protocol::hir::Operator;
|
||||||
|
use nu_protocol::{
|
||||||
|
dataframe::{NuSeries, PolarsData},
|
||||||
|
Primitive, ShellTypeName, UntaggedValue, Value,
|
||||||
|
};
|
||||||
|
use nu_source::Span;
|
||||||
|
use num_traits::ToPrimitive;
|
||||||
|
|
||||||
|
use num_bigint::BigInt;
|
||||||
|
use polars::prelude::{
|
||||||
|
BooleanType, ChunkCompare, ChunkedArray, DataType, Float64Type, Int64Type, IntoSeries,
|
||||||
|
NumOpsDispatchChecked, Series,
|
||||||
|
};
|
||||||
|
use std::ops::{Add, BitAnd, BitOr, Div, Mul, Sub};
|
||||||
|
|
||||||
|
pub fn compute_between_series(
|
||||||
|
operator: Operator,
|
||||||
|
left: &Value,
|
||||||
|
right: &Value,
|
||||||
|
) -> Result<UntaggedValue, (&'static str, &'static str)> {
|
||||||
|
if let (
|
||||||
|
UntaggedValue::DataFrame(PolarsData::Series(lhs)),
|
||||||
|
UntaggedValue::DataFrame(PolarsData::Series(rhs)),
|
||||||
|
) = (&left.value, &right.value)
|
||||||
|
{
|
||||||
|
if lhs.as_ref().dtype() != rhs.as_ref().dtype() {
|
||||||
|
return Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Mixed datatypes",
|
||||||
|
"this datatype does not match the right hand side datatype",
|
||||||
|
&left.tag.span,
|
||||||
|
format!(
|
||||||
|
"Perhaps you want to change this datatype to '{}'",
|
||||||
|
lhs.as_ref().dtype()
|
||||||
|
),
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if lhs.as_ref().len() != rhs.as_ref().len() {
|
||||||
|
return Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Different length",
|
||||||
|
"this column length does not match the right hand column length",
|
||||||
|
&left.tag.span,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
match operator {
|
||||||
|
Operator::Plus => {
|
||||||
|
let mut res = lhs.as_ref() + rhs.as_ref();
|
||||||
|
let name = format!("sum_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::Minus => {
|
||||||
|
let mut res = lhs.as_ref() - rhs.as_ref();
|
||||||
|
let name = format!("sub_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::Multiply => {
|
||||||
|
let mut res = lhs.as_ref() * rhs.as_ref();
|
||||||
|
let name = format!("mul_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::Divide => {
|
||||||
|
let res = lhs.as_ref().checked_div(rhs.as_ref());
|
||||||
|
match res {
|
||||||
|
Ok(mut res) => {
|
||||||
|
let name = format!("div_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Err(e) => Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Division error",
|
||||||
|
format!("{}", e),
|
||||||
|
&left.tag.span,
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Operator::Equal => {
|
||||||
|
let mut res = Series::eq(lhs.as_ref(), rhs.as_ref()).into_series();
|
||||||
|
let name = format!("eq_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::NotEqual => {
|
||||||
|
let mut res = Series::neq(lhs.as_ref(), rhs.as_ref()).into_series();
|
||||||
|
let name = format!("neq_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::LessThan => {
|
||||||
|
let mut res = Series::lt(lhs.as_ref(), rhs.as_ref()).into_series();
|
||||||
|
let name = format!("lt_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::LessThanOrEqual => {
|
||||||
|
let mut res = Series::lt_eq(lhs.as_ref(), rhs.as_ref()).into_series();
|
||||||
|
let name = format!("lte_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::GreaterThan => {
|
||||||
|
let mut res = Series::gt(lhs.as_ref(), rhs.as_ref()).into_series();
|
||||||
|
let name = format!("gt_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::GreaterThanOrEqual => {
|
||||||
|
let mut res = Series::gt_eq(lhs.as_ref(), rhs.as_ref()).into_series();
|
||||||
|
let name = format!("gte_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
Operator::And => match lhs.as_ref().dtype() {
|
||||||
|
DataType::Boolean => {
|
||||||
|
let lhs_cast = lhs.as_ref().bool();
|
||||||
|
let rhs_cast = rhs.as_ref().bool();
|
||||||
|
|
||||||
|
match (lhs_cast, rhs_cast) {
|
||||||
|
(Ok(l), Ok(r)) => {
|
||||||
|
let mut res = l.bitand(r).into_series();
|
||||||
|
let name =
|
||||||
|
format!("and_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Casting error",
|
||||||
|
"unable to cast to boolean",
|
||||||
|
&left.tag.span,
|
||||||
|
"unable to cast to boolean",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Incorrect datatype",
|
||||||
|
"And operation can only be done with boolean values",
|
||||||
|
&left.tag.span,
|
||||||
|
))),
|
||||||
|
},
|
||||||
|
Operator::Or => match lhs.as_ref().dtype() {
|
||||||
|
DataType::Boolean => {
|
||||||
|
let lhs_cast = lhs.as_ref().bool();
|
||||||
|
let rhs_cast = rhs.as_ref().bool();
|
||||||
|
|
||||||
|
match (lhs_cast, rhs_cast) {
|
||||||
|
(Ok(l), Ok(r)) => {
|
||||||
|
let mut res = l.bitor(r).into_series();
|
||||||
|
let name =
|
||||||
|
format!("or_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
||||||
|
res.rename(name.as_ref());
|
||||||
|
Ok(NuSeries::series_to_untagged(res))
|
||||||
|
}
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Casting error",
|
||||||
|
"unable to cast to boolean",
|
||||||
|
&left.tag.span,
|
||||||
|
"unable to cast to boolean",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Incorrect datatype",
|
||||||
|
"And operation can only be done with boolean values",
|
||||||
|
&left.tag.span,
|
||||||
|
))),
|
||||||
|
},
|
||||||
|
_ => Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Incorrect datatype",
|
||||||
|
"unable to use this datatype for this operation",
|
||||||
|
&left.tag.span,
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err((left.type_name(), right.type_name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn compute_series_single_value(
|
||||||
|
operator: Operator,
|
||||||
|
left: &Value,
|
||||||
|
right: &Value,
|
||||||
|
) -> Result<UntaggedValue, (&'static str, &'static str)> {
|
||||||
|
if let (UntaggedValue::DataFrame(PolarsData::Series(lhs)), UntaggedValue::Primitive(_)) =
|
||||||
|
(&left.value, &right.value)
|
||||||
|
{
|
||||||
|
match operator {
|
||||||
|
Operator::Plus => match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Int64Type>>::add,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Int64Type>>::add,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Float64Type>>::add,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to sum this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only int, bigInt or decimal values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
Operator::Minus => match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Int64Type>>::sub,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Int64Type>>::sub,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Float64Type>>::sub,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to subtract this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only int, bigInt or decimal values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
Operator::Multiply => match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Int64Type>>::mul,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Int64Type>>::mul,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Float64Type>>::mul,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to multiply this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only int, bigInt or decimal values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
Operator::Divide => match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => {
|
||||||
|
if *val == 0 {
|
||||||
|
Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Division by zero",
|
||||||
|
"Zero value found",
|
||||||
|
&right.tag.span,
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
Ok(compute_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Int64Type>>::div,
|
||||||
|
&left.tag.span,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => {
|
||||||
|
if val.eq(&0.into()) {
|
||||||
|
Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Division by zero",
|
||||||
|
"Zero value found",
|
||||||
|
&right.tag.span,
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
Ok(compute_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Int64Type>>::div,
|
||||||
|
&left.tag.span,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => {
|
||||||
|
if val.eq(&0.into()) {
|
||||||
|
Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Division by zero",
|
||||||
|
"Zero value found",
|
||||||
|
&right.tag.span,
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
Ok(compute_series_decimal(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
<&ChunkedArray<Float64Type>>::div,
|
||||||
|
&left.tag.span,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to divide this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only primary values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
Operator::Equal => {
|
||||||
|
match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::eq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::eq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(
|
||||||
|
compare_series_decimal(lhs.as_ref(), val, ChunkedArray::eq, &left.tag.span),
|
||||||
|
),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to compare this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only primary values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Operator::NotEqual => match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::neq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::neq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compare_series_decimal(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::neq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to compare this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only primary values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
Operator::LessThan => {
|
||||||
|
match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::lt,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::lt,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(
|
||||||
|
compare_series_decimal(lhs.as_ref(), val, ChunkedArray::lt, &left.tag.span),
|
||||||
|
),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to compare this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only primary values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Operator::LessThanOrEqual => match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::lt_eq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::lt_eq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compare_series_decimal(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::lt_eq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to compare this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only primary values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
Operator::GreaterThan => {
|
||||||
|
match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::gt,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::gt,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(
|
||||||
|
compare_series_decimal(lhs.as_ref(), val, ChunkedArray::gt, &left.tag.span),
|
||||||
|
),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to compare this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only primary values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Operator::GreaterThanOrEqual => match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compare_series_i64(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::gt_eq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::gt_eq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compare_series_decimal(
|
||||||
|
lhs.as_ref(),
|
||||||
|
val,
|
||||||
|
ChunkedArray::gt_eq,
|
||||||
|
&left.tag.span,
|
||||||
|
)),
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to compare this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only primary values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
Operator::Contains => match &right.value {
|
||||||
|
UntaggedValue::Primitive(Primitive::String(val)) => {
|
||||||
|
Ok(contains_series_pat(lhs.as_ref(), val, &left.tag.span))
|
||||||
|
}
|
||||||
|
_ => Ok(UntaggedValue::Error(
|
||||||
|
ShellError::labeled_error_with_secondary(
|
||||||
|
"Operation unavailable",
|
||||||
|
"unable to perform this value to the series",
|
||||||
|
&right.tag.span,
|
||||||
|
"Only primary values are allowed",
|
||||||
|
&right.tag.span,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
_ => Ok(UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Incorrect datatype",
|
||||||
|
"unable to use this value for this operation",
|
||||||
|
&left.tag.span,
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err((left.type_name(), right.type_name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_series_i64<'r, F>(series: &'r Series, val: &i64, f: F, span: &Span) -> UntaggedValue
|
||||||
|
where
|
||||||
|
F: Fn(&'r ChunkedArray<Int64Type>, i64) -> ChunkedArray<Int64Type>,
|
||||||
|
{
|
||||||
|
let casted = series.i64();
|
||||||
|
match casted {
|
||||||
|
Ok(casted) => {
|
||||||
|
let res = f(casted, *val);
|
||||||
|
let res = res.into_series();
|
||||||
|
NuSeries::series_to_untagged(res)
|
||||||
|
}
|
||||||
|
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Casting error",
|
||||||
|
format!("{}", e),
|
||||||
|
span,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_series_bigint<'r, F>(
|
||||||
|
series: &'r Series,
|
||||||
|
val: &BigInt,
|
||||||
|
f: F,
|
||||||
|
span: &Span,
|
||||||
|
) -> UntaggedValue
|
||||||
|
where
|
||||||
|
F: Fn(&'r ChunkedArray<Int64Type>, i64) -> ChunkedArray<Int64Type>,
|
||||||
|
{
|
||||||
|
let casted = series.i64();
|
||||||
|
match casted {
|
||||||
|
Ok(casted) => {
|
||||||
|
let res = f(
|
||||||
|
casted,
|
||||||
|
val.to_i64()
|
||||||
|
.expect("Internal error: protocol did not use compatible decimal"),
|
||||||
|
);
|
||||||
|
let res = res.into_series();
|
||||||
|
NuSeries::series_to_untagged(res)
|
||||||
|
}
|
||||||
|
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Casting error",
|
||||||
|
format!("{}", e),
|
||||||
|
span,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_series_decimal<'r, F>(
|
||||||
|
series: &'r Series,
|
||||||
|
val: &BigDecimal,
|
||||||
|
f: F,
|
||||||
|
span: &Span,
|
||||||
|
) -> UntaggedValue
|
||||||
|
where
|
||||||
|
F: Fn(&'r ChunkedArray<Float64Type>, f64) -> ChunkedArray<Float64Type>,
|
||||||
|
{
|
||||||
|
let casted = series.f64();
|
||||||
|
match casted {
|
||||||
|
Ok(casted) => {
|
||||||
|
let res = f(
|
||||||
|
casted,
|
||||||
|
val.to_f64()
|
||||||
|
.expect("Internal error: protocol did not use compatible decimal"),
|
||||||
|
);
|
||||||
|
let res = res.into_series();
|
||||||
|
NuSeries::series_to_untagged(res)
|
||||||
|
}
|
||||||
|
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Casting error",
|
||||||
|
format!("{}", e),
|
||||||
|
span,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compare_series_i64<'r, F>(series: &'r Series, val: &i64, f: F, span: &Span) -> UntaggedValue
|
||||||
|
where
|
||||||
|
F: Fn(&'r ChunkedArray<Int64Type>, i64) -> ChunkedArray<BooleanType>,
|
||||||
|
{
|
||||||
|
let casted = series.i64();
|
||||||
|
match casted {
|
||||||
|
Ok(casted) => {
|
||||||
|
let res = f(casted, *val);
|
||||||
|
let res = res.into_series();
|
||||||
|
NuSeries::series_to_untagged(res)
|
||||||
|
}
|
||||||
|
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Casting error",
|
||||||
|
format!("{}", e),
|
||||||
|
span,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compare_series_bigint<'r, F>(
|
||||||
|
series: &'r Series,
|
||||||
|
val: &BigInt,
|
||||||
|
f: F,
|
||||||
|
span: &Span,
|
||||||
|
) -> UntaggedValue
|
||||||
|
where
|
||||||
|
F: Fn(&'r ChunkedArray<Int64Type>, i64) -> ChunkedArray<BooleanType>,
|
||||||
|
{
|
||||||
|
let casted = series.i64();
|
||||||
|
match casted {
|
||||||
|
Ok(casted) => {
|
||||||
|
let res = f(
|
||||||
|
casted,
|
||||||
|
val.to_i64()
|
||||||
|
.expect("Internal error: protocol did not use compatible decimal"),
|
||||||
|
);
|
||||||
|
let res = res.into_series();
|
||||||
|
NuSeries::series_to_untagged(res)
|
||||||
|
}
|
||||||
|
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Casting error",
|
||||||
|
format!("{}", e),
|
||||||
|
span,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compare_series_decimal<'r, F>(
|
||||||
|
series: &'r Series,
|
||||||
|
val: &BigDecimal,
|
||||||
|
f: F,
|
||||||
|
span: &Span,
|
||||||
|
) -> UntaggedValue
|
||||||
|
where
|
||||||
|
F: Fn(&'r ChunkedArray<Float64Type>, i64) -> ChunkedArray<BooleanType>,
|
||||||
|
{
|
||||||
|
let casted = series.f64();
|
||||||
|
match casted {
|
||||||
|
Ok(casted) => {
|
||||||
|
let res = f(
|
||||||
|
casted,
|
||||||
|
val.to_i64()
|
||||||
|
.expect("Internal error: protocol did not use compatible decimal"),
|
||||||
|
);
|
||||||
|
let res = res.into_series();
|
||||||
|
NuSeries::series_to_untagged(res)
|
||||||
|
}
|
||||||
|
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Casting error",
|
||||||
|
format!("{}", e),
|
||||||
|
span,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn contains_series_pat(series: &Series, pat: &str, span: &Span) -> UntaggedValue {
|
||||||
|
let casted = series.utf8();
|
||||||
|
match casted {
|
||||||
|
Ok(casted) => {
|
||||||
|
let res = casted.contains(pat);
|
||||||
|
|
||||||
|
match res {
|
||||||
|
Ok(res) => {
|
||||||
|
let res = res.into_series();
|
||||||
|
NuSeries::series_to_untagged(res)
|
||||||
|
}
|
||||||
|
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Search error",
|
||||||
|
format!("{}", e),
|
||||||
|
span,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
|
||||||
|
"Casting error",
|
||||||
|
format!("{}", e),
|
||||||
|
span,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,4 +7,7 @@ pub mod types;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
pub mod value;
|
pub mod value;
|
||||||
|
|
||||||
|
#[cfg(feature = "dataframe")]
|
||||||
|
pub mod dataframe;
|
||||||
|
|
||||||
pub use dict::TaggedListBuilder;
|
pub use dict::TaggedListBuilder;
|
||||||
|
|
|
@ -12,9 +12,6 @@ use num_bigint::BigInt;
|
||||||
use num_traits::{ToPrimitive, Zero};
|
use num_traits::{ToPrimitive, Zero};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
#[cfg(feature = "dataframe")]
|
|
||||||
use nu_protocol::dataframe::{NuSeries, PolarsData};
|
|
||||||
|
|
||||||
pub struct Date;
|
pub struct Date;
|
||||||
|
|
||||||
impl Date {
|
impl Date {
|
||||||
|
@ -494,51 +491,6 @@ pub fn compute_values(
|
||||||
}
|
}
|
||||||
_ => Err((left.type_name(), right.type_name())),
|
_ => Err((left.type_name(), right.type_name())),
|
||||||
},
|
},
|
||||||
#[cfg(feature = "dataframe")]
|
|
||||||
(
|
|
||||||
UntaggedValue::DataFrame(PolarsData::Series(lhs)),
|
|
||||||
UntaggedValue::DataFrame(PolarsData::Series(rhs)),
|
|
||||||
) => {
|
|
||||||
if lhs.as_ref().dtype() == rhs.as_ref().dtype() {
|
|
||||||
let result = match operator {
|
|
||||||
Operator::Plus => {
|
|
||||||
let mut res = lhs.as_ref() + rhs.as_ref();
|
|
||||||
let name = format!("sum_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
|
||||||
let res = res.rename(name.as_ref());
|
|
||||||
Ok(res.clone())
|
|
||||||
}
|
|
||||||
Operator::Minus => {
|
|
||||||
let mut res = lhs.as_ref() - rhs.as_ref();
|
|
||||||
let name = format!("sub_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
|
||||||
let res = res.rename(name.as_ref());
|
|
||||||
Ok(res.clone())
|
|
||||||
}
|
|
||||||
Operator::Multiply => {
|
|
||||||
let mut res = lhs.as_ref() * rhs.as_ref();
|
|
||||||
let name = format!("mul_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
|
||||||
let res = res.rename(name.as_ref());
|
|
||||||
Ok(res.clone())
|
|
||||||
}
|
|
||||||
Operator::Divide => {
|
|
||||||
let mut res = lhs.as_ref() / rhs.as_ref();
|
|
||||||
let name = format!("div_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
|
||||||
let res = res.rename(name.as_ref());
|
|
||||||
Ok(res.clone())
|
|
||||||
}
|
|
||||||
Operator::Modulo => {
|
|
||||||
let mut res = lhs.as_ref() % rhs.as_ref();
|
|
||||||
let name = format!("mod_{}_{}", lhs.as_ref().name(), rhs.as_ref().name());
|
|
||||||
let res = res.rename(name.as_ref());
|
|
||||||
Ok(res.clone())
|
|
||||||
}
|
|
||||||
_ => Err((left.type_name(), right.type_name())),
|
|
||||||
}?;
|
|
||||||
|
|
||||||
Ok(NuSeries::series_to_untagged(result))
|
|
||||||
} else {
|
|
||||||
Err((left.type_name(), right.type_name()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => Err((left.type_name(), right.type_name())),
|
_ => Err((left.type_name(), right.type_name())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,3 +65,4 @@ hamcrest2 = "0.3.0"
|
||||||
rustyline-support = []
|
rustyline-support = []
|
||||||
dirs = ["dirs-next"]
|
dirs = ["dirs-next"]
|
||||||
trash-support = ["trash"]
|
trash-support = ["trash"]
|
||||||
|
dataframe = ["nu-protocol/dataframe"]
|
||||||
|
|
|
@ -4,11 +4,29 @@ use nu_protocol::hir::Operator;
|
||||||
use nu_protocol::{Primitive, ShellTypeName, UntaggedValue, Value};
|
use nu_protocol::{Primitive, ShellTypeName, UntaggedValue, Value};
|
||||||
use std::ops::Not;
|
use std::ops::Not;
|
||||||
|
|
||||||
|
#[cfg(feature = "dataframe")]
|
||||||
|
use nu_data::dataframe::{compute_between_series, compute_series_single_value};
|
||||||
|
#[cfg(feature = "dataframe")]
|
||||||
|
use nu_protocol::dataframe::PolarsData;
|
||||||
|
|
||||||
pub fn apply_operator(
|
pub fn apply_operator(
|
||||||
op: Operator,
|
op: Operator,
|
||||||
left: &Value,
|
left: &Value,
|
||||||
right: &Value,
|
right: &Value,
|
||||||
) -> Result<UntaggedValue, (&'static str, &'static str)> {
|
) -> Result<UntaggedValue, (&'static str, &'static str)> {
|
||||||
|
#[cfg(feature = "dataframe")]
|
||||||
|
if let (
|
||||||
|
UntaggedValue::DataFrame(PolarsData::Series(_)),
|
||||||
|
UntaggedValue::DataFrame(PolarsData::Series(_)),
|
||||||
|
) = (&left.value, &right.value)
|
||||||
|
{
|
||||||
|
return compute_between_series(op, left, right);
|
||||||
|
} else if let (UntaggedValue::DataFrame(PolarsData::Series(_)), UntaggedValue::Primitive(_)) =
|
||||||
|
(&left.value, &right.value)
|
||||||
|
{
|
||||||
|
return compute_series_single_value(op, left, right);
|
||||||
|
}
|
||||||
|
|
||||||
match op {
|
match op {
|
||||||
Operator::Equal
|
Operator::Equal
|
||||||
| Operator::NotEqual
|
| Operator::NotEqual
|
||||||
|
|
|
@ -32,10 +32,10 @@ toml = "0.5.8"
|
||||||
|
|
||||||
[dependencies.polars]
|
[dependencies.polars]
|
||||||
git = "https://github.com/pola-rs/polars"
|
git = "https://github.com/pola-rs/polars"
|
||||||
rev = "a5f17b0a6e3e05ff6be789aa24a7cae54fd400dd"
|
rev = "9e1506cca9fb646fc55f949ab6345290c3d198a7"
|
||||||
version = "0.14.0"
|
version = "0.14.1"
|
||||||
optional = true
|
optional = true
|
||||||
features = ["serde"]
|
features = ["serde", "rows"]
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
dataframe = ["polars"]
|
dataframe = ["polars"]
|
||||||
|
|
|
@ -185,7 +185,7 @@ impl NuDataFrame {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_rows(&self, from_row: usize, to_row: usize) -> Result<Vec<Value>, ShellError> {
|
pub fn to_rows(&self, from_row: usize, to_row: usize) -> Result<Vec<Value>, ShellError> {
|
||||||
let df = &self.as_ref();
|
let df = self.as_ref();
|
||||||
let column_names = df.get_column_names();
|
let column_names = df.get_column_names();
|
||||||
|
|
||||||
let mut values: Vec<Value> = Vec::new();
|
let mut values: Vec<Value> = Vec::new();
|
||||||
|
|
Loading…
Reference in a new issue