diff --git a/crates/nu-command/src/filters/group_by.rs b/crates/nu-command/src/filters/group_by.rs index 74a19c38e3..9396de4781 100644 --- a/crates/nu-command/src/filters/group_by.rs +++ b/crates/nu-command/src/filters/group_by.rs @@ -1,6 +1,6 @@ use indexmap::IndexMap; use nu_engine::{command_prelude::*, ClosureEval}; -use nu_protocol::engine::Closure; +use nu_protocol::{engine::Closure, IntoValue}; #[derive(Clone)] pub struct GroupBy; @@ -22,7 +22,7 @@ impl Command for GroupBy { "Return a table with \"groups\" and \"items\" columns", None, ) - .optional( + .rest( "grouper", SyntaxShape::OneOf(vec![ SyntaxShape::CellPath, @@ -135,7 +135,89 @@ impl Command for GroupBy { Value::test_string("false"), ]), })), - } + }, + Example { + description: "Group items by multiple columns' values", + example: r#"[ + [name, lang, year]; + [andres, rb, "2019"], + [jt, rs, "2019"], + [storm, rs, "2021"] + ] + | group-by lang year"#, + result: Some(Value::test_record(record! { + "rb" => Value::test_record(record! { + "2019" => Value::test_list( + vec![Value::test_record(record! { + "name" => Value::test_string("andres"), + "lang" => Value::test_string("rb"), + "year" => Value::test_string("2019"), + })], + ), + }), + "rs" => Value::test_record(record! { + "2019" => Value::test_list( + vec![Value::test_record(record! { + "name" => Value::test_string("jt"), + "lang" => Value::test_string("rs"), + "year" => Value::test_string("2019"), + })], + ), + "2021" => Value::test_list( + vec![Value::test_record(record! { + "name" => Value::test_string("storm"), + "lang" => Value::test_string("rs"), + "year" => Value::test_string("2021"), + })], + ), + }), + })) + }, + Example { + description: "Group items by multiple columns' values", + example: r#"[ + [name, lang, year]; + [andres, rb, "2019"], + [jt, rs, "2019"], + [storm, rs, "2021"] + ] + | group-by lang year --to-table"#, + result: Some(Value::test_list(vec![ + Value::test_record(record! { + "lang" => Value::test_string("rb"), + "year" => Value::test_string("2019"), + "items" => Value::test_list(vec![ + Value::test_record(record! { + "name" => Value::test_string("andres"), + "lang" => Value::test_string("rb"), + "year" => Value::test_string("2019"), + }) + ]), + }), + Value::test_record(record! { + "lang" => Value::test_string("rs"), + "year" => Value::test_string("2019"), + "items" => Value::test_list(vec![ + Value::test_record(record! { + "name" => Value::test_string("jt"), + "lang" => Value::test_string("rs"), + "year" => Value::test_string("2019"), + }) + ]), + }), + Value::test_record(record! { + "lang" => Value::test_string("rs"), + "year" => Value::test_string("2021"), + "items" => Value::test_list(vec![ + Value::test_record(record! { + "name" => Value::test_string("storm"), + "lang" => Value::test_string("rs"), + "year" => Value::test_string("2021"), + }) + ]), + }), + ])) + }, ] } } @@ -147,7 +229,7 @@ pub fn group_by( input: PipelineData, ) -> Result { let head = call.head; - let grouper: Option = call.opt(engine_state, stack, 0)?; + let groupers: Vec = call.rest(engine_state, stack, 0)?; let to_table = call.has_flag(engine_state, stack, "to-table")?; let config = engine_state.get_config(); @@ -156,29 +238,22 @@ pub fn group_by( return Ok(Value::record(Record::new(), head).into_pipeline_data()); } - let groups = match grouper { - Some(grouper) => { - let span = grouper.span(); - match grouper { - Value::CellPath { val, .. } => group_cell_path(val, values, config)?, - Value::Closure { val, .. } => { - group_closure(values, span, *val, engine_state, stack)? - } - _ => { - return Err(ShellError::TypeMismatch { - err_message: "unsupported grouper type".to_string(), - span, - }) - } - } + let mut groupers = groupers.into_iter(); + + let grouped = if let Some(grouper) = groupers.next() { + let mut groups = Grouped::new(&grouper, values, config, engine_state, stack)?; + for grouper in groupers { + groups.subgroup(&grouper, config, engine_state, stack)?; } - None => group_no_grouper(values, config)?, + groups + } else { + Grouped::empty(values, config) }; let value = if to_table { - groups_to_table(groups, head) + grouped.into_table(head) } else { - groups_to_record(groups, head) + grouped.into_record(head) }; Ok(value.into_pipeline_data()) @@ -207,20 +282,6 @@ fn group_cell_path( Ok(groups) } -fn group_no_grouper( - values: Vec, - config: &nu_protocol::Config, -) -> Result>, ShellError> { - let mut groups = IndexMap::<_, Vec<_>>::new(); - - for value in values.into_iter() { - let key = value.to_abbreviated_string(config); - groups.entry(key).or_default().push(value); - } - - Ok(groups) -} - fn group_closure( values: Vec, span: Span, @@ -244,32 +305,137 @@ fn group_closure( Ok(groups) } -fn groups_to_record(groups: IndexMap>, span: Span) -> Value { - Value::record( - groups - .into_iter() - .map(|(k, v)| (k, Value::list(v, span))) - .collect(), - span, - ) +struct Grouped { + grouper: Option, + groups: Tree, } -fn groups_to_table(groups: IndexMap>, span: Span) -> Value { - Value::list( - groups - .into_iter() - .map(|(group, items)| { - Value::record( - record! { - "group" => Value::string(group, span), - "items" => Value::list(items, span), - }, +enum Tree { + Leaf(IndexMap>), + Branch(IndexMap), +} + +impl Grouped { + fn empty(values: Vec, config: &nu_protocol::Config) -> Self { + let mut groups = IndexMap::<_, Vec<_>>::new(); + + for value in values.into_iter() { + let key = value.to_abbreviated_string(config); + groups.entry(key).or_default().push(value); + } + + Self { + grouper: Some("group".into()), + groups: Tree::Leaf(groups), + } + } + + fn new( + grouper: &Value, + values: Vec, + config: &nu_protocol::Config, + engine_state: &EngineState, + stack: &mut Stack, + ) -> Result { + let span = grouper.span(); + let groups = match grouper { + Value::CellPath { val, .. } => group_cell_path(val.clone(), values, config)?, + Value::Closure { val, .. } => { + group_closure(values, span, Closure::clone(val), engine_state, stack)? + } + _ => { + return Err(ShellError::TypeMismatch { + err_message: "unsupported grouper type".to_string(), span, - ) - }) - .collect(), - span, - ) + }) + } + }; + let grouper = grouper.as_cell_path().ok().map(CellPath::to_column_name); + Ok(Self { + grouper, + groups: Tree::Leaf(groups), + }) + } + + fn subgroup( + &mut self, + grouper: &Value, + config: &nu_protocol::Config, + engine_state: &EngineState, + stack: &mut Stack, + ) -> Result<(), ShellError> { + let groups = match &mut self.groups { + Tree::Leaf(groups) => std::mem::take(groups) + .into_iter() + .map(|(key, values)| -> Result<_, ShellError> { + let leaf = Self::new(grouper, values, config, engine_state, stack)?; + Ok((key, leaf)) + }) + .collect::, ShellError>>()?, + Tree::Branch(nested_groups) => { + let mut nested_groups = std::mem::take(nested_groups); + for v in nested_groups.values_mut() { + v.subgroup(grouper, config, engine_state, stack)?; + } + nested_groups + } + }; + self.groups = Tree::Branch(groups); + Ok(()) + } + + fn into_table(self, head: Span) -> Value { + self._into_table(head, 0) + .into_iter() + .map(|row| row.into_iter().rev().collect::().into_value(head)) + .collect::>() + .into_value(head) + } + + fn _into_table(self, head: Span, index: usize) -> Vec { + let grouper = self.grouper.unwrap_or_else(|| format!("group{index}")); + match self.groups { + Tree::Leaf(leaf) => leaf + .into_iter() + .map(|(group, values)| { + [ + ("items".to_string(), values.into_value(head)), + (grouper.clone(), group.into_value(head)), + ] + .into_iter() + .collect() + }) + .collect::>(), + Tree::Branch(branch) => branch + .into_iter() + .flat_map(|(group, items)| { + let mut inner = items._into_table(head, index + 1); + for row in &mut inner { + row.insert(grouper.clone(), group.clone().into_value(head)); + } + inner + }) + .collect(), + } + } + + fn into_record(self, head: Span) -> Value { + match self.groups { + Tree::Leaf(leaf) => Value::record( + leaf.into_iter() + .map(|(k, v)| (k, v.into_value(head))) + .collect(), + head, + ), + Tree::Branch(branch) => { + let values = branch + .into_iter() + .map(|(k, v)| (k, v.into_record(head))) + .collect(); + Value::record(values, head) + } + } + } } #[cfg(test)] diff --git a/crates/nu-std/testing.nu b/crates/nu-std/testing.nu index 003e1d1ebc..7052983a21 100644 --- a/crates/nu-std/testing.nu +++ b/crates/nu-std/testing.nu @@ -79,7 +79,7 @@ def create-test-record [] nothing -> record