Rolling and cumulative commands (#3960)

* rolling and cumulative operations

* update polars to 0.15.1

* change reference in function
This commit is contained in:
Fernando Herrera 2021-08-24 15:10:29 +01:00 committed by GitHub
parent 0abe753003
commit 02b2c55146
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 340 additions and 27 deletions

20
Cargo.lock generated
View file

@ -4488,9 +4488,9 @@ dependencies = [
[[package]] [[package]]
name = "polars" name = "polars"
version = "0.15.0" version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "008543a5972e0cc94e08cfbf1ec7c5904dc09256f31ffc2b04ab2f24032be69c" checksum = "080a75a211a2b0feb01c85efbdb3906593150629679f2b1147ae80c98cde87ea"
dependencies = [ dependencies = [
"polars-core", "polars-core",
"polars-io", "polars-io",
@ -4499,9 +4499,9 @@ dependencies = [
[[package]] [[package]]
name = "polars-arrow" name = "polars-arrow"
version = "0.15.0" version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b0b1c85210b633bc690e01a2b88ca8528af9a50d4a75d0e313ddec89ad2ed4a" checksum = "5ff08a43b4d19710717e35d91b036bb636fdf7daa794c9ccc38e7ba95aca1135"
dependencies = [ dependencies = [
"arrow", "arrow",
"num 0.4.0", "num 0.4.0",
@ -4510,9 +4510,9 @@ dependencies = [
[[package]] [[package]]
name = "polars-core" name = "polars-core"
version = "0.15.0" version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1547583d662e51c3cf1871296874acacb53edb7e074e6f00fd537cd81237a6f" checksum = "4929a06fa295e6f702a7d516a84f5cf65cfb72c675daf3cf99031a0d1ef6fe6f"
dependencies = [ dependencies = [
"ahash", "ahash",
"anyhow", "anyhow",
@ -4538,9 +4538,9 @@ dependencies = [
[[package]] [[package]]
name = "polars-io" name = "polars-io"
version = "0.15.0" version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1048c30f9c85e12a2f2f65a2e4527c6cc0fc9ec1e7ff4ad98b2759a7e12ba699" checksum = "422d2405b409defef53de50d2eaf3ab77080c95505d191722c6236d9f3a022a7"
dependencies = [ dependencies = [
"ahash", "ahash",
"anyhow", "anyhow",
@ -4563,9 +4563,9 @@ dependencies = [
[[package]] [[package]]
name = "polars-lazy" name = "polars-lazy"
version = "0.15.0" version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad98453f7bc530b5531c802aad6d4ea480d99a90a3a0318135aab9b501f562d0" checksum = "50a6302df2cf8f261823fefc032b5fc7c6f8d2eb5c55c9bf305d472d1f500376"
dependencies = [ dependencies = [
"ahash", "ahash",
"itertools", "itertools",

View file

@ -101,7 +101,7 @@ zip = { version="0.5.9", optional=true }
digest = "0.9.0" digest = "0.9.0"
[dependencies.polars] [dependencies.polars]
version = "0.15" version = "0.15.1"
optional = true optional = true
features = ["parquet", "json", "random", "pivot", "strings", "is_in", "temporal"] features = ["parquet", "json", "random", "pivot", "strings", "is_in", "temporal"]

View file

@ -75,6 +75,7 @@ pub use series::DataFrameArgTrue;
pub use series::DataFrameArgUnique; pub use series::DataFrameArgUnique;
pub use series::DataFrameConcatenate; pub use series::DataFrameConcatenate;
pub use series::DataFrameContains; pub use series::DataFrameContains;
pub use series::DataFrameCumulative;
pub use series::DataFrameGetDay; pub use series::DataFrameGetDay;
pub use series::DataFrameGetHour; pub use series::DataFrameGetHour;
pub use series::DataFrameGetMinute; pub use series::DataFrameGetMinute;
@ -95,6 +96,7 @@ pub use series::DataFrameNUnique;
pub use series::DataFrameNot; pub use series::DataFrameNot;
pub use series::DataFrameReplace; pub use series::DataFrameReplace;
pub use series::DataFrameReplaceAll; pub use series::DataFrameReplaceAll;
pub use series::DataFrameRolling;
pub use series::DataFrameSeriesRename; pub use series::DataFrameSeriesRename;
pub use series::DataFrameSet; pub use series::DataFrameSet;
pub use series::DataFrameSetWithIdx; pub use series::DataFrameSetWithIdx;

View file

@ -0,0 +1,127 @@
use crate::prelude::*;
use nu_engine::WholeStreamCommand;
use nu_errors::ShellError;
use nu_protocol::{
dataframe::{Column, NuDataFrame},
Signature, SyntaxShape, UntaggedValue,
};
use nu_source::Tagged;
use polars::prelude::DataType;
/// Cumulative operation requested by the user for the `dataframe cum` command.
///
/// Values are produced by `CumType::from_str` from the command's `type` argument.
enum CumType {
/// Selected by the string "min".
Min,
/// Selected by the string "max".
Max,
/// Selected by the string "sum".
Sum,
}
impl CumType {
    /// Parses the user-supplied operation name into a [`CumType`].
    ///
    /// Accepts exactly "min", "max" and "sum". Any other string yields a
    /// labeled `ShellError` whose primary and secondary labels both point at
    /// `span`.
    fn from_str(roll_type: &str, span: &Span) -> Result<Self, ShellError> {
        let parsed = match roll_type {
            "min" => Self::Min,
            "max" => Self::Max,
            "sum" => Self::Sum,
            _ => {
                return Err(ShellError::labeled_error_with_secondary(
                    "Wrong operation",
                    "Operation not valid for cumulative",
                    span,
                    "Perhaps you want to use: max, min, sum",
                    span,
                ));
            }
        };

        Ok(parsed)
    }

    /// Returns the suffix appended to the series name in the command output.
    fn to_str(&self) -> &'static str {
        match *self {
            Self::Min => "cum_min",
            Self::Max => "cum_max",
            Self::Sum => "cum_sum",
        }
    }
}
/// Unit type implementing the `dataframe cum` command.
pub struct DataFrame;
impl WholeStreamCommand for DataFrame {
    /// Name under which the command is invoked.
    fn name(&self) -> &str {
        "dataframe cum"
    }

    /// One-line description of the command.
    fn usage(&self) -> &str {
        "[Series] Cumulative calculation for a series"
    }

    /// Declares one required string argument (the cumulative operation) and
    /// the `--reverse` / `-r` switch.
    fn signature(&self) -> Signature {
        Signature::build("dataframe cum")
            // The description used to read "rolling operation", copied from the
            // rolling command; this command performs cumulative operations.
            .required("type", SyntaxShape::String, "cumulative operation")
            .switch("reverse", "Reverse cumulative calculation", Some('r'))
    }

    /// Delegates to the free function `command`.
    fn run(&self, args: CommandArgs) -> Result<OutputStream, ShellError> {
        command(args)
    }

    /// Documented example together with its expected output; the test module
    /// of this file runs it through `test_dataframe`.
    fn examples(&self) -> Vec<Example> {
        vec![Example {
            description: "Cumulative sum for a series",
            example: "[1 2 3 4 5] | dataframe to-df | dataframe cum sum",
            result: Some(vec![NuDataFrame::try_from_columns(
                vec![Column::new(
                    "0_cum_sum".to_string(),
                    vec![
                        UntaggedValue::int(1).into(),
                        UntaggedValue::int(3).into(),
                        UntaggedValue::int(6).into(),
                        UntaggedValue::int(10).into(),
                        UntaggedValue::int(15).into(),
                    ],
                )],
                &Span::default(),
            )
            .expect("simple df for test should not fail")
            .into_value(Tag::default())]),
        }]
    }
}
/// Runs `dataframe cum`: reads a dataframe from the input stream, treats it as
/// a single series and applies the requested cumulative operation.
///
/// The resulting series is renamed to `<original name>_<cum_min|cum_max|cum_sum>`
/// and returned as a one-column dataframe carrying the input dataframe's tag.
///
/// # Errors
///
/// Propagates errors from argument extraction, from reading the dataframe /
/// series out of the stream, and from parsing the operation name. Also fails
/// when the series has an `Object` dtype.
fn command(mut args: CommandArgs) -> Result<OutputStream, ShellError> {
    let name_tag = args.call_info.name_tag.clone();
    let requested: Tagged<String> = args.req(0)?;
    let reverse = args.has_flag("reverse");

    let (frame, frame_tag) = NuDataFrame::try_from_stream(&mut args.input, &name_tag.span)?;
    let series = frame.as_series(&frame_tag.span)?;

    if matches!(series.dtype(), DataType::Object(_)) {
        return Err(ShellError::labeled_error(
            "Found object series",
            "Series of type object cannot be used for cumulative operation",
            &frame_tag.span,
        ));
    }

    // The operation name is validated only after the input has been read, so
    // input errors take precedence over a bad operation name.
    let operation = CumType::from_str(requested.item.as_str(), &requested.tag.span)?;

    let mut result = match operation {
        CumType::Max => series.cum_max(reverse),
        CumType::Min => series.cum_min(reverse),
        CumType::Sum => series.cum_sum(reverse),
    };

    let new_name = format!("{}_{}", series.name(), operation.to_str());
    result.rename(&new_name);

    let output = NuDataFrame::try_from_series(vec![result], &name_tag.span)?;
    Ok(OutputStream::one(output.into_value(frame_tag)))
}
#[cfg(test)]
mod tests {
    use super::{DataFrame, ShellError};
    use crate::examples::test_dataframe as test_examples;

    /// Runs the examples declared by the command through the dataframe test harness.
    #[test]
    fn examples_work_as_expected() -> Result<(), ShellError> {
        test_examples(DataFrame {})
    }
}

View file

@ -7,6 +7,7 @@ pub mod arg_true;
pub mod arg_unique; pub mod arg_unique;
pub mod concatenate; pub mod concatenate;
pub mod contains; pub mod contains;
pub mod cumulative;
pub mod get_day; pub mod get_day;
pub mod get_hour; pub mod get_hour;
pub mod get_minute; pub mod get_minute;
@ -28,6 +29,7 @@ pub mod not;
pub mod rename; pub mod rename;
pub mod replace; pub mod replace;
pub mod replace_all; pub mod replace_all;
pub mod rolling;
pub mod set; pub mod set;
pub mod set_with_idx; pub mod set_with_idx;
pub mod shift; pub mod shift;
@ -48,6 +50,7 @@ pub use arg_true::DataFrame as DataFrameArgTrue;
pub use arg_unique::DataFrame as DataFrameArgUnique; pub use arg_unique::DataFrame as DataFrameArgUnique;
pub use concatenate::DataFrame as DataFrameConcatenate; pub use concatenate::DataFrame as DataFrameConcatenate;
pub use contains::DataFrame as DataFrameContains; pub use contains::DataFrame as DataFrameContains;
pub use cumulative::DataFrame as DataFrameCumulative;
pub use get_day::DataFrame as DataFrameGetDay; pub use get_day::DataFrame as DataFrameGetDay;
pub use get_hour::DataFrame as DataFrameGetHour; pub use get_hour::DataFrame as DataFrameGetHour;
pub use get_minute::DataFrame as DataFrameGetMinute; pub use get_minute::DataFrame as DataFrameGetMinute;
@ -69,6 +72,7 @@ pub use not::DataFrame as DataFrameNot;
pub use rename::DataFrame as DataFrameSeriesRename; pub use rename::DataFrame as DataFrameSeriesRename;
pub use replace::DataFrame as DataFrameReplace; pub use replace::DataFrame as DataFrameReplace;
pub use replace_all::DataFrame as DataFrameReplaceAll; pub use replace_all::DataFrame as DataFrameReplaceAll;
pub use rolling::DataFrame as DataFrameRolling;
pub use set::DataFrame as DataFrameSet; pub use set::DataFrame as DataFrameSet;
pub use set_with_idx::DataFrame as DataFrameSetWithIdx; pub use set_with_idx::DataFrame as DataFrameSetWithIdx;
pub use shift::DataFrame as DataFrameShift; pub use shift::DataFrame as DataFrameShift;

View file

@ -0,0 +1,176 @@
use crate::{commands::dataframe::utils::parse_polars_error, prelude::*};
use nu_engine::WholeStreamCommand;
use nu_errors::ShellError;
use nu_protocol::{
dataframe::{Column, NuDataFrame},
Signature, SyntaxShape, UntaggedValue,
};
use nu_source::Tagged;
use polars::prelude::DataType;
/// Rolling-window operation requested by the user for the `dataframe rolling` command.
///
/// Values are produced by `RollType::from_str` from the command's `type` argument.
enum RollType {
/// Selected by the string "min".
Min,
/// Selected by the string "max".
Max,
/// Selected by the string "sum".
Sum,
/// Selected by the string "mean".
Mean,
}
impl RollType {
    /// Parses the user-supplied operation name into a [`RollType`].
    ///
    /// Accepts exactly "min", "max", "sum" and "mean". Any other string yields
    /// a labeled `ShellError` whose primary and secondary labels both point at
    /// `span`.
    fn from_str(roll_type: &str, span: &Span) -> Result<Self, ShellError> {
        let parsed = match roll_type {
            "min" => Self::Min,
            "max" => Self::Max,
            "sum" => Self::Sum,
            "mean" => Self::Mean,
            _ => {
                return Err(ShellError::labeled_error_with_secondary(
                    "Wrong operation",
                    "Operation not valid for rolling",
                    span,
                    "Perhaps you want to use: max, min, sum, mean",
                    span,
                ));
            }
        };

        Ok(parsed)
    }

    /// Returns the suffix appended to the series name in the command output.
    fn to_str(&self) -> &'static str {
        match *self {
            Self::Min => "rolling_min",
            Self::Max => "rolling_max",
            Self::Sum => "rolling_sum",
            Self::Mean => "rolling_mean",
        }
    }
}
/// Unit type implementing the `dataframe rolling` command.
pub struct DataFrame;
impl WholeStreamCommand for DataFrame {
    /// Name under which the command is invoked.
    fn name(&self) -> &str {
        "dataframe rolling"
    }

    /// One-line description of the command.
    fn usage(&self) -> &str {
        "[Series] Rolling calculation for a series"
    }

    /// Declares two required arguments (operation name, window size) and the
    /// `--ignore_nulls` / `-i` switch.
    fn signature(&self) -> Signature {
        let base = Signature::build("dataframe rolling");
        let with_args = base
            .required("type", SyntaxShape::String, "rolling operation")
            .required("window", SyntaxShape::Int, "Window size for rolling");

        with_args.switch("ignore_nulls", "Ignore nulls in column", Some('i'))
    }

    /// Delegates to the free function `command`.
    fn run(&self, args: CommandArgs) -> Result<OutputStream, ShellError> {
        command(args)
    }

    /// Documented examples together with their expected output; the test
    /// module of this file runs them through `test_dataframe`.
    fn examples(&self) -> Vec<Example> {
        // Expected frame for the rolling-sum example.
        let sum_column = Column::new(
            "0_rolling_sum".to_string(),
            vec![
                UntaggedValue::int(3).into(),
                UntaggedValue::int(5).into(),
                UntaggedValue::int(7).into(),
                UntaggedValue::int(9).into(),
            ],
        );
        let expected_sum = NuDataFrame::try_from_columns(vec![sum_column], &Span::default())
            .expect("simple df for test should not fail")
            .into_value(Tag::default());

        // Expected frame for the rolling-max example.
        let max_column = Column::new(
            "0_rolling_max".to_string(),
            vec![
                UntaggedValue::int(2).into(),
                UntaggedValue::int(3).into(),
                UntaggedValue::int(4).into(),
                UntaggedValue::int(5).into(),
            ],
        );
        let expected_max = NuDataFrame::try_from_columns(vec![max_column], &Span::default())
            .expect("simple df for test should not fail")
            .into_value(Tag::default());

        vec![
            Example {
                description: "Rolling sum for a series",
                example:
                    "[1 2 3 4 5] | dataframe to-df | dataframe rolling sum 2 | dataframe drop-nulls",
                result: Some(vec![expected_sum]),
            },
            Example {
                description: "Rolling max for a series",
                example:
                    "[1 2 3 4 5] | dataframe to-df | dataframe rolling max 2 | dataframe drop-nulls",
                result: Some(vec![expected_max]),
            },
        ]
    }
}
/// Runs `dataframe rolling`: reads a dataframe from the input stream, treats
/// it as a single series and applies the requested rolling-window operation.
///
/// The resulting series is renamed to
/// `<original name>_<rolling_min|rolling_max|rolling_sum|rolling_mean>` and
/// returned as a one-column dataframe carrying the input dataframe's tag.
///
/// # Errors
///
/// Propagates errors from argument extraction, from reading the dataframe /
/// series out of the stream, from parsing the operation name, and from the
/// polars rolling call itself. Also fails when the series has an `Object`
/// dtype, or when the window size does not fit in a `u32` (negative or too
/// large).
fn command(mut args: CommandArgs) -> Result<OutputStream, ShellError> {
    let tag = args.call_info.name_tag.clone();
    let roll_type: Tagged<String> = args.req(0)?;
    let window_size: Tagged<i64> = args.req(1)?;
    let ignore_nulls = args.has_flag("ignore_nulls");

    let (df, df_tag) = NuDataFrame::try_from_stream(&mut args.input, &tag.span)?;
    let series = df.as_series(&df_tag.span)?;

    if let DataType::Object(_) = series.dtype() {
        return Err(ShellError::labeled_error(
            "Found object series",
            "Series of type object cannot be used for rolling operation",
            &df_tag.span,
        ));
    }

    let roll_type = RollType::from_str(roll_type.item.as_str(), &roll_type.tag.span)?;

    // A plain `as u32` cast silently wraps negative or oversized values into
    // an unrelated window size; convert with a range check and report the
    // problem at the argument's span instead. The fully qualified path avoids
    // depending on `TryFrom` being in scope.
    let window = <u32 as std::convert::TryFrom<i64>>::try_from(window_size.item).map_err(|_| {
        ShellError::labeled_error(
            "Invalid window size",
            "Window size must be a non-negative integer that fits in 32 bits",
            &window_size.tag.span,
        )
    })?;

    // The same value is passed as the first and the fourth argument
    // (presumably window size and minimum periods in the polars API — confirm).
    let res = match roll_type {
        RollType::Max => series.rolling_max(window, None, ignore_nulls, window),
        RollType::Min => series.rolling_min(window, None, ignore_nulls, window),
        RollType::Sum => series.rolling_sum(window, None, ignore_nulls, window),
        RollType::Mean => series.rolling_mean(window, None, ignore_nulls, window),
    };

    let mut res = res.map_err(|e| parse_polars_error::<&str>(&e, &df_tag.span, None))?;

    let name = format!("{}_{}", series.name(), roll_type.to_str());
    res.rename(&name);

    let df = NuDataFrame::try_from_series(vec![res], &tag.span)?;
    Ok(OutputStream::one(df.into_value(df_tag)))
}
#[cfg(test)]
mod tests {
    use super::{DataFrame, ShellError};
    use crate::examples::test_dataframe as test_examples;

    /// Runs the examples declared by the command through the dataframe test harness.
    #[test]
    fn examples_work_as_expected() -> Result<(), ShellError> {
        test_examples(DataFrame {})
    }
}

View file

@ -27,19 +27,20 @@ pub use core_commands::*;
pub use dataframe::{ pub use dataframe::{
DataFrame, DataFrameAggregate, DataFrameAllFalse, DataFrameAllTrue, DataFrameAppend, DataFrame, DataFrameAggregate, DataFrameAllFalse, DataFrameAllTrue, DataFrameAppend,
DataFrameArgMax, DataFrameArgMin, DataFrameArgSort, DataFrameArgTrue, DataFrameArgUnique, DataFrameArgMax, DataFrameArgMin, DataFrameArgSort, DataFrameArgTrue, DataFrameArgUnique,
DataFrameColumn, DataFrameConcatenate, DataFrameContains, DataFrameDTypes, DataFrameDescribe, DataFrameColumn, DataFrameConcatenate, DataFrameContains, DataFrameCumulative, DataFrameDTypes,
DataFrameDrop, DataFrameDropDuplicates, DataFrameDropNulls, DataFrameDummies, DataFrameFilter, DataFrameDescribe, DataFrameDrop, DataFrameDropDuplicates, DataFrameDropNulls,
DataFrameFirst, DataFrameGet, DataFrameGetDay, DataFrameGetHour, DataFrameGetMinute, DataFrameDummies, DataFrameFilter, DataFrameFirst, DataFrameGet, DataFrameGetDay,
DataFrameGetMonth, DataFrameGetNanoSecond, DataFrameGetOrdinal, DataFrameGetSecond, DataFrameGetHour, DataFrameGetMinute, DataFrameGetMonth, DataFrameGetNanoSecond,
DataFrameGetWeek, DataFrameGetWeekDay, DataFrameGetYear, DataFrameGroupBy, DataFrameGetOrdinal, DataFrameGetSecond, DataFrameGetWeek, DataFrameGetWeekDay,
DataFrameIsDuplicated, DataFrameIsIn, DataFrameIsNotNull, DataFrameIsNull, DataFrameIsUnique, DataFrameGetYear, DataFrameGroupBy, DataFrameIsDuplicated, DataFrameIsIn, DataFrameIsNotNull,
DataFrameJoin, DataFrameLast, DataFrameList, DataFrameMelt, DataFrameNNull, DataFrameNUnique, DataFrameIsNull, DataFrameIsUnique, DataFrameJoin, DataFrameLast, DataFrameList, DataFrameMelt,
DataFrameNot, DataFrameOpen, DataFramePivot, DataFrameReplace, DataFrameReplaceAll, DataFrameNNull, DataFrameNUnique, DataFrameNot, DataFrameOpen, DataFramePivot,
DataFrameSample, DataFrameSelect, DataFrameSeriesRename, DataFrameSet, DataFrameSetWithIdx, DataFrameReplace, DataFrameReplaceAll, DataFrameRolling, DataFrameSample, DataFrameSelect,
DataFrameShape, DataFrameShift, DataFrameShow, DataFrameSlice, DataFrameSort, DataFrameSeriesRename, DataFrameSet, DataFrameSetWithIdx, DataFrameShape, DataFrameShift,
DataFrameStrFTime, DataFrameStringLengths, DataFrameStringSlice, DataFrameTake, DataFrameToCsv, DataFrameShow, DataFrameSlice, DataFrameSort, DataFrameStrFTime, DataFrameStringLengths,
DataFrameToDF, DataFrameToLowercase, DataFrameToParquet, DataFrameToUppercase, DataFrameUnique, DataFrameStringSlice, DataFrameTake, DataFrameToCsv, DataFrameToDF, DataFrameToLowercase,
DataFrameValueCounts, DataFrameWhere, DataFrameWithColumn, DataFrameToParquet, DataFrameToUppercase, DataFrameUnique, DataFrameValueCounts,
DataFrameWhere, DataFrameWithColumn,
}; };
pub use env::*; pub use env::*;
pub use filesystem::*; pub use filesystem::*;

View file

@ -345,6 +345,8 @@ pub fn create_default_context(interactive: bool) -> Result<EvaluationContext, Bo
whole_stream_command(DataFrameGetNanoSecond), whole_stream_command(DataFrameGetNanoSecond),
whole_stream_command(DataFrameStrFTime), whole_stream_command(DataFrameStrFTime),
whole_stream_command(DataFrameDescribe), whole_stream_command(DataFrameDescribe),
whole_stream_command(DataFrameRolling),
whole_stream_command(DataFrameCumulative),
]); ]);
#[cfg(feature = "clipboard-cli")] #[cfg(feature = "clipboard-cli")]

View file

@ -16,8 +16,8 @@ use nu_source::AnchorLocation;
#[cfg(feature = "dataframe")] #[cfg(feature = "dataframe")]
use crate::commands::{ use crate::commands::{
DataFrameGroupBy, DataFrameIsNull, DataFrameShift, DataFrameToDF, DataFrameWithColumn, DataFrameDropNulls, DataFrameGroupBy, DataFrameIsNull, DataFrameShift, DataFrameToDF,
StrToDatetime, DataFrameWithColumn, StrToDatetime,
}; };
use crate::commands::{ use crate::commands::{
@ -173,6 +173,7 @@ pub fn test_dataframe(cmd: impl WholeStreamCommand + 'static) -> Result<(), Shel
whole_stream_command(DataFrameIsNull), whole_stream_command(DataFrameIsNull),
whole_stream_command(DataFrameGroupBy), whole_stream_command(DataFrameGroupBy),
whole_stream_command(DataFrameWithColumn), whole_stream_command(DataFrameWithColumn),
whole_stream_command(DataFrameDropNulls),
// Base commands for context // Base commands for context
whole_stream_command(Math), whole_stream_command(Math),
whole_stream_command(MathMode {}), whole_stream_command(MathMode {}),

View file

@ -32,7 +32,7 @@ serde_yaml = "0.8.16"
toml = "0.5.8" toml = "0.5.8"
[dependencies.polars] [dependencies.polars]
version = "0.15" version = "0.15.1"
optional = true optional = true
features = ["default", "serde", "rows", "strings", "checked_arithmetic", "object", "dtype-duration-ns"] features = ["default", "serde", "rows", "strings", "checked_arithmetic", "object", "dtype-duration-ns"]