diff --git a/Cargo.lock b/Cargo.lock index 5f1fd86f6a..0fb71cc101 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4488,9 +4488,9 @@ dependencies = [ [[package]] name = "polars" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "008543a5972e0cc94e08cfbf1ec7c5904dc09256f31ffc2b04ab2f24032be69c" +checksum = "080a75a211a2b0feb01c85efbdb3906593150629679f2b1147ae80c98cde87ea" dependencies = [ "polars-core", "polars-io", @@ -4499,9 +4499,9 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0b1c85210b633bc690e01a2b88ca8528af9a50d4a75d0e313ddec89ad2ed4a" +checksum = "5ff08a43b4d19710717e35d91b036bb636fdf7daa794c9ccc38e7ba95aca1135" dependencies = [ "arrow", "num 0.4.0", @@ -4510,9 +4510,9 @@ dependencies = [ [[package]] name = "polars-core" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1547583d662e51c3cf1871296874acacb53edb7e074e6f00fd537cd81237a6f" +checksum = "4929a06fa295e6f702a7d516a84f5cf65cfb72c675daf3cf99031a0d1ef6fe6f" dependencies = [ "ahash", "anyhow", @@ -4538,9 +4538,9 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1048c30f9c85e12a2f2f65a2e4527c6cc0fc9ec1e7ff4ad98b2759a7e12ba699" +checksum = "422d2405b409defef53de50d2eaf3ab77080c95505d191722c6236d9f3a022a7" dependencies = [ "ahash", "anyhow", @@ -4563,9 +4563,9 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad98453f7bc530b5531c802aad6d4ea480d99a90a3a0318135aab9b501f562d0" +checksum = "50a6302df2cf8f261823fefc032b5fc7c6f8d2eb5c55c9bf305d472d1f500376" dependencies = [ "ahash", "itertools", diff --git 
a/crates/nu-command/Cargo.toml b/crates/nu-command/Cargo.toml
index beec12721a..ec36d115d9 100644
--- a/crates/nu-command/Cargo.toml
+++ b/crates/nu-command/Cargo.toml
@@ -101,7 +101,7 @@ zip = { version="0.5.9", optional=true }
 digest = "0.9.0"
 
 [dependencies.polars]
-version = "0.15"
+version = "0.15.1"
 optional = true
 features = ["parquet", "json", "random", "pivot", "strings", "is_in", "temporal"]
 
diff --git a/crates/nu-command/src/commands/dataframe/mod.rs b/crates/nu-command/src/commands/dataframe/mod.rs
index 07d53f271a..227f22bb0a 100644
--- a/crates/nu-command/src/commands/dataframe/mod.rs
+++ b/crates/nu-command/src/commands/dataframe/mod.rs
@@ -75,6 +75,7 @@ pub use series::DataFrameArgTrue;
 pub use series::DataFrameArgUnique;
 pub use series::DataFrameConcatenate;
 pub use series::DataFrameContains;
+pub use series::DataFrameCumulative;
 pub use series::DataFrameGetDay;
 pub use series::DataFrameGetHour;
 pub use series::DataFrameGetMinute;
@@ -95,6 +96,7 @@ pub use series::DataFrameNUnique;
 pub use series::DataFrameNot;
 pub use series::DataFrameReplace;
 pub use series::DataFrameReplaceAll;
+pub use series::DataFrameRolling;
 pub use series::DataFrameSeriesRename;
 pub use series::DataFrameSet;
 pub use series::DataFrameSetWithIdx;
diff --git a/crates/nu-command/src/commands/dataframe/series/cumulative.rs b/crates/nu-command/src/commands/dataframe/series/cumulative.rs
new file mode 100644
index 0000000000..7846b41bfe
--- /dev/null
+++ b/crates/nu-command/src/commands/dataframe/series/cumulative.rs
@@ -0,0 +1,137 @@
+use crate::prelude::*;
+use nu_engine::WholeStreamCommand;
+use nu_errors::ShellError;
+use nu_protocol::{
+    dataframe::{Column, NuDataFrame},
+    Signature, SyntaxShape, UntaggedValue,
+};
+use nu_source::Tagged;
+use polars::prelude::DataType;
+
+/// Cumulative operations accepted by `dataframe cum`.
+enum CumType {
+    Min,
+    Max,
+    Sum,
+}
+
+impl CumType {
+    /// Parses the user-supplied operation name (`min`, `max` or `sum`).
+    ///
+    /// Any other name yields a `ShellError` labeled at `span`.
+    fn from_str(cum_type: &str, span: &Span) -> Result<Self, ShellError> {
+        match cum_type {
+            "min" => Ok(Self::Min),
+            "max" => Ok(Self::Max),
+            "sum" => Ok(Self::Sum),
+            _ => Err(ShellError::labeled_error_with_secondary(
+                "Wrong operation",
+                "Operation not valid for cumulative",
+                span,
+                "Perhaps you want to use: max, min, sum",
+                span,
+            )),
+        }
+    }
+
+    /// Suffix appended to the series name to build the output column name.
+    fn to_str(&self) -> &'static str {
+        match self {
+            CumType::Min => "cum_min",
+            CumType::Max => "cum_max",
+            CumType::Sum => "cum_sum",
+        }
+    }
+}
+
+/// Implementation of the `dataframe cum` command.
+pub struct DataFrame;
+
+impl WholeStreamCommand for DataFrame {
+    fn name(&self) -> &str {
+        "dataframe cum"
+    }
+
+    fn usage(&self) -> &str {
+        "[Series] Cumulative calculation for a series"
+    }
+
+    fn signature(&self) -> Signature {
+        Signature::build("dataframe cum")
+            .required("type", SyntaxShape::String, "cumulative operation")
+            .switch("reverse", "Reverse cumulative calculation", Some('r'))
+    }
+
+    fn run(&self, args: CommandArgs) -> Result<OutputStream, ShellError> {
+        command(args)
+    }
+
+    fn examples(&self) -> Vec<Example> {
+        vec![Example {
+            description: "Cumulative sum for a series",
+            example: "[1 2 3 4 5] | dataframe to-df | dataframe cum sum",
+            result: Some(vec![NuDataFrame::try_from_columns(
+                vec![Column::new(
+                    "0_cum_sum".to_string(),
+                    vec![
+                        UntaggedValue::int(1).into(),
+                        UntaggedValue::int(3).into(),
+                        UntaggedValue::int(6).into(),
+                        UntaggedValue::int(10).into(),
+                        UntaggedValue::int(15).into(),
+                    ],
+                )],
+                &Span::default(),
+            )
+            .expect("simple df for test should not fail")
+            .into_value(Tag::default())]),
+        }]
+    }
+}
+
+/// Runs `dataframe cum`: applies the requested cumulative operation to the
+/// input series and emits a one-column dataframe named `<series>_<op>`.
+///
+/// Errors on an unknown operation name or a series of object dtype.
+fn command(mut args: CommandArgs) -> Result<OutputStream, ShellError> {
+    let tag = args.call_info.name_tag.clone();
+    let cum_type: Tagged<String> = args.req(0)?;
+    let reverse = args.has_flag("reverse");
+
+    let (df, df_tag) = NuDataFrame::try_from_stream(&mut args.input, &tag.span)?;
+    let series = df.as_series(&df_tag.span)?;
+
+    if let DataType::Object(_) = series.dtype() {
+        return Err(ShellError::labeled_error(
+            "Found object series",
+            "Series of type object cannot be used for cumulative operation",
+            &df_tag.span,
+        ));
+    }
+
+    let cum_type = CumType::from_str(cum_type.item.as_str(), &cum_type.tag.span)?;
+    let mut res = match cum_type {
+        CumType::Max => series.cum_max(reverse),
+        CumType::Min => series.cum_min(reverse),
+        CumType::Sum => series.cum_sum(reverse),
+    };
+
+    let name = format!("{}_{}", series.name(), cum_type.to_str());
+    res.rename(&name);
+
+    let df = NuDataFrame::try_from_series(vec![res], &tag.span)?;
+    Ok(OutputStream::one(df.into_value(df_tag)))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::DataFrame;
+    use super::ShellError;
+
+    #[test]
+    fn examples_work_as_expected() -> Result<(), ShellError> {
+        use crate::examples::test_dataframe as test_examples;
+
+        test_examples(DataFrame {})
+    }
+}
diff --git a/crates/nu-command/src/commands/dataframe/series/mod.rs b/crates/nu-command/src/commands/dataframe/series/mod.rs
index 519fa1d5b9..a51930f4cf 100644
--- a/crates/nu-command/src/commands/dataframe/series/mod.rs
+++ b/crates/nu-command/src/commands/dataframe/series/mod.rs
@@ -7,6 +7,7 @@ pub mod arg_true;
 pub mod arg_unique;
 pub mod concatenate;
 pub mod contains;
+pub mod cumulative;
 pub mod get_day;
 pub mod get_hour;
 pub mod get_minute;
@@ -28,6 +29,7 @@ pub mod not;
 pub mod rename;
 pub mod replace;
 pub mod replace_all;
+pub mod rolling;
 pub mod set;
 pub mod set_with_idx;
 pub mod shift;
@@ -48,6 +50,7 @@ pub use arg_true::DataFrame as DataFrameArgTrue;
 pub use arg_unique::DataFrame as DataFrameArgUnique;
 pub use concatenate::DataFrame as DataFrameConcatenate;
 pub use contains::DataFrame as DataFrameContains;
+pub use cumulative::DataFrame as DataFrameCumulative;
 pub use get_day::DataFrame as DataFrameGetDay;
 pub use get_hour::DataFrame as DataFrameGetHour;
 pub use get_minute::DataFrame as DataFrameGetMinute;
@@ -69,6 +72,7 @@ pub use not::DataFrame as DataFrameNot;
 pub use rename::DataFrame as DataFrameSeriesRename;
 pub use replace::DataFrame as DataFrameReplace;
 pub use replace_all::DataFrame as DataFrameReplaceAll;
+pub use rolling::DataFrame as DataFrameRolling;
 pub use set::DataFrame as DataFrameSet;
 pub use set_with_idx::DataFrame as DataFrameSetWithIdx;
 pub use
shift::DataFrame as DataFrameShift;
diff --git a/crates/nu-command/src/commands/dataframe/series/rolling.rs b/crates/nu-command/src/commands/dataframe/series/rolling.rs
new file mode 100644
index 0000000000..7a60ddd53e
--- /dev/null
+++ b/crates/nu-command/src/commands/dataframe/series/rolling.rs
@@ -0,0 +1,183 @@
+use crate::{commands::dataframe::utils::parse_polars_error, prelude::*};
+use nu_engine::WholeStreamCommand;
+use nu_errors::ShellError;
+use nu_protocol::{
+    dataframe::{Column, NuDataFrame},
+    Signature, SyntaxShape, UntaggedValue,
+};
+use nu_source::Tagged;
+use polars::prelude::DataType;
+use std::convert::TryFrom;
+
+/// Rolling-window operations accepted by `dataframe rolling`.
+enum RollType {
+    Min,
+    Max,
+    Sum,
+    Mean,
+}
+
+impl RollType {
+    /// Parses the user-supplied operation name (`min`, `max`, `sum` or `mean`).
+    ///
+    /// Any other name yields a `ShellError` labeled at `span`.
+    fn from_str(roll_type: &str, span: &Span) -> Result<Self, ShellError> {
+        match roll_type {
+            "min" => Ok(Self::Min),
+            "max" => Ok(Self::Max),
+            "sum" => Ok(Self::Sum),
+            "mean" => Ok(Self::Mean),
+            _ => Err(ShellError::labeled_error_with_secondary(
+                "Wrong operation",
+                "Operation not valid for rolling",
+                span,
+                "Perhaps you want to use: max, min, sum, mean",
+                span,
+            )),
+        }
+    }
+
+    /// Suffix appended to the series name to build the output column name.
+    fn to_str(&self) -> &'static str {
+        match self {
+            RollType::Min => "rolling_min",
+            RollType::Max => "rolling_max",
+            RollType::Sum => "rolling_sum",
+            RollType::Mean => "rolling_mean",
+        }
+    }
+}
+
+/// Implementation of the `dataframe rolling` command.
+pub struct DataFrame;
+
+impl WholeStreamCommand for DataFrame {
+    fn name(&self) -> &str {
+        "dataframe rolling"
+    }
+
+    fn usage(&self) -> &str {
+        "[Series] Rolling calculation for a series"
+    }
+
+    fn signature(&self) -> Signature {
+        Signature::build("dataframe rolling")
+            .required("type", SyntaxShape::String, "rolling operation")
+            .required("window", SyntaxShape::Int, "Window size for rolling")
+            .switch("ignore_nulls", "Ignore nulls in column", Some('i'))
+    }
+
+    fn run(&self, args: CommandArgs) -> Result<OutputStream, ShellError> {
+        command(args)
+    }
+
+    fn examples(&self) -> Vec<Example> {
+        vec![
+            Example {
+                description: "Rolling sum for a series",
+                example:
+                    "[1 2 3 4 5] | dataframe to-df | dataframe rolling sum 2 | dataframe drop-nulls",
+                result: Some(vec![NuDataFrame::try_from_columns(
+                    vec![Column::new(
+                        "0_rolling_sum".to_string(),
+                        vec![
+                            UntaggedValue::int(3).into(),
+                            UntaggedValue::int(5).into(),
+                            UntaggedValue::int(7).into(),
+                            UntaggedValue::int(9).into(),
+                        ],
+                    )],
+                    &Span::default(),
+                )
+                .expect("simple df for test should not fail")
+                .into_value(Tag::default())]),
+            },
+            Example {
+                description: "Rolling max for a series",
+                example:
+                    "[1 2 3 4 5] | dataframe to-df | dataframe rolling max 2 | dataframe drop-nulls",
+                result: Some(vec![NuDataFrame::try_from_columns(
+                    vec![Column::new(
+                        "0_rolling_max".to_string(),
+                        vec![
+                            UntaggedValue::int(2).into(),
+                            UntaggedValue::int(3).into(),
+                            UntaggedValue::int(4).into(),
+                            UntaggedValue::int(5).into(),
+                        ],
+                    )],
+                    &Span::default(),
+                )
+                .expect("simple df for test should not fail")
+                .into_value(Tag::default())]),
+            },
+        ]
+    }
+}
+
+/// Runs `dataframe rolling`: applies the requested rolling-window operation to
+/// the input series and emits a one-column dataframe named `<series>_<op>`.
+///
+/// Errors on an unknown operation name, a window size that is not a positive
+/// integer fitting in `u32`, a series of object dtype, or a polars failure.
+fn command(mut args: CommandArgs) -> Result<OutputStream, ShellError> {
+    let tag = args.call_info.name_tag.clone();
+    let roll_type: Tagged<String> = args.req(0)?;
+    let window_size: Tagged<i64> = args.req(1)?;
+    let ignore_nulls = args.has_flag("ignore_nulls");
+
+    // A plain `as u32` cast would silently wrap negative or oversized values
+    // (e.g. -1 becomes 4294967295), so validate the window explicitly.
+    let window = match u32::try_from(window_size.item) {
+        Ok(size) if size > 0 => size,
+        _ => {
+            return Err(ShellError::labeled_error(
+                "Invalid window size",
+                "Window size must be a positive integer",
+                &window_size.tag.span,
+            ))
+        }
+    };
+
+    let (df, df_tag) = NuDataFrame::try_from_stream(&mut args.input, &tag.span)?;
+    let series = df.as_series(&df_tag.span)?;
+
+    if let DataType::Object(_) = series.dtype() {
+        return Err(ShellError::labeled_error(
+            "Found object series",
+            "Series of type object cannot be used for rolling operation",
+            &df_tag.span,
+        ));
+    }
+
+    let roll_type = RollType::from_str(roll_type.item.as_str(), &roll_type.tag.span)?;
+    // The same value is passed as the first and last argument; in the polars
+    // rolling API the last one is presumably `min_periods` (TODO confirm).
+    let res = match roll_type {
+        RollType::Max => series.rolling_max(window, None, ignore_nulls, window),
+        RollType::Min => series.rolling_min(window, None, ignore_nulls, window),
+        RollType::Sum => series.rolling_sum(window, None, ignore_nulls, window),
+        RollType::Mean => series.rolling_mean(window, None, ignore_nulls, window),
+    };
+
+    let mut res = res.map_err(|e| parse_polars_error::<&str>(&e, &df_tag.span, None))?;
+
+    let name = format!("{}_{}", series.name(), roll_type.to_str());
+    res.rename(&name);
+
+    let df = NuDataFrame::try_from_series(vec![res], &tag.span)?;
+    Ok(OutputStream::one(df.into_value(df_tag)))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::DataFrame;
+    use super::ShellError;
+
+    #[test]
+    fn examples_work_as_expected() -> Result<(), ShellError> {
+        use crate::examples::test_dataframe as test_examples;
+
+        test_examples(DataFrame {})
+    }
+}
diff --git a/crates/nu-command/src/commands/mod.rs b/crates/nu-command/src/commands/mod.rs
index 7551832178..b4d3261588 100644
--- a/crates/nu-command/src/commands/mod.rs
+++ b/crates/nu-command/src/commands/mod.rs
@@ -27,19 +27,20 @@ pub use core_commands::*;
 pub use dataframe::{
     DataFrame, DataFrameAggregate, DataFrameAllFalse, DataFrameAllTrue, DataFrameAppend,
     DataFrameArgMax, DataFrameArgMin, DataFrameArgSort, DataFrameArgTrue, DataFrameArgUnique,
-    DataFrameColumn, DataFrameConcatenate, DataFrameContains, DataFrameDTypes, DataFrameDescribe,
-    DataFrameDrop, DataFrameDropDuplicates, DataFrameDropNulls, DataFrameDummies, DataFrameFilter,
-    DataFrameFirst, DataFrameGet, DataFrameGetDay, DataFrameGetHour, DataFrameGetMinute,
-    DataFrameGetMonth, DataFrameGetNanoSecond, DataFrameGetOrdinal, DataFrameGetSecond,
-    DataFrameGetWeek, DataFrameGetWeekDay, DataFrameGetYear, DataFrameGroupBy,
-    DataFrameIsDuplicated, DataFrameIsIn, DataFrameIsNotNull, DataFrameIsNull, DataFrameIsUnique,
-    DataFrameJoin, DataFrameLast, DataFrameList, DataFrameMelt, DataFrameNNull, DataFrameNUnique,
-    DataFrameNot, DataFrameOpen, DataFramePivot, DataFrameReplace, DataFrameReplaceAll,
-    DataFrameSample, DataFrameSelect, DataFrameSeriesRename, DataFrameSet, DataFrameSetWithIdx,
-    DataFrameShape, DataFrameShift, DataFrameShow, DataFrameSlice, DataFrameSort,
-    DataFrameStrFTime, DataFrameStringLengths,
DataFrameStringSlice, DataFrameTake, DataFrameToCsv,
-    DataFrameToDF, DataFrameToLowercase, DataFrameToParquet, DataFrameToUppercase, DataFrameUnique,
-    DataFrameValueCounts, DataFrameWhere, DataFrameWithColumn,
+    DataFrameColumn, DataFrameConcatenate, DataFrameContains, DataFrameCumulative, DataFrameDTypes,
+    DataFrameDescribe, DataFrameDrop, DataFrameDropDuplicates, DataFrameDropNulls,
+    DataFrameDummies, DataFrameFilter, DataFrameFirst, DataFrameGet, DataFrameGetDay,
+    DataFrameGetHour, DataFrameGetMinute, DataFrameGetMonth, DataFrameGetNanoSecond,
+    DataFrameGetOrdinal, DataFrameGetSecond, DataFrameGetWeek, DataFrameGetWeekDay,
+    DataFrameGetYear, DataFrameGroupBy, DataFrameIsDuplicated, DataFrameIsIn, DataFrameIsNotNull,
+    DataFrameIsNull, DataFrameIsUnique, DataFrameJoin, DataFrameLast, DataFrameList, DataFrameMelt,
+    DataFrameNNull, DataFrameNUnique, DataFrameNot, DataFrameOpen, DataFramePivot,
+    DataFrameReplace, DataFrameReplaceAll, DataFrameRolling, DataFrameSample, DataFrameSelect,
+    DataFrameSeriesRename, DataFrameSet, DataFrameSetWithIdx, DataFrameShape, DataFrameShift,
+    DataFrameShow, DataFrameSlice, DataFrameSort, DataFrameStrFTime, DataFrameStringLengths,
+    DataFrameStringSlice, DataFrameTake, DataFrameToCsv, DataFrameToDF, DataFrameToLowercase,
+    DataFrameToParquet, DataFrameToUppercase, DataFrameUnique, DataFrameValueCounts,
+    DataFrameWhere, DataFrameWithColumn,
 };
 pub use env::*;
 pub use filesystem::*;
diff --git a/crates/nu-command/src/default_context.rs b/crates/nu-command/src/default_context.rs
index b6c7ff38c2..dcf1c5398c 100644
--- a/crates/nu-command/src/default_context.rs
+++ b/crates/nu-command/src/default_context.rs
@@ -345,6 +345,8 @@ pub fn create_default_context(interactive: bool) -> Result Result<(), Shel
             whole_stream_command(DataFrameIsNull),
             whole_stream_command(DataFrameGroupBy),
             whole_stream_command(DataFrameWithColumn),
+            whole_stream_command(DataFrameCumulative),
+            whole_stream_command(DataFrameRolling),
             // Base commands for context
             whole_stream_command(Math),
             whole_stream_command(MathMode {}),
diff --git a/crates/nu-protocol/Cargo.toml b/crates/nu-protocol/Cargo.toml
index 70b24f04f6..8f8ce4db01 100644
--- a/crates/nu-protocol/Cargo.toml
+++ b/crates/nu-protocol/Cargo.toml
@@ -32,7 +32,7 @@ serde_yaml = "0.8.16"
 toml = "0.5.8"
 
 [dependencies.polars]
-version = "0.15"
+version = "0.15.1"
 optional = true
 features = ["default", "serde", "rows", "strings", "checked_arithmetic", "object", "dtype-duration-ns"]
 