diff --git a/crates/nu_plugin_shuffle/src/nu/mod.rs b/crates/nu_plugin_shuffle/src/nu/mod.rs index e9eb975cc0..287fb0ac7c 100644 --- a/crates/nu_plugin_shuffle/src/nu/mod.rs +++ b/crates/nu_plugin_shuffle/src/nu/mod.rs @@ -1,6 +1,8 @@ use nu_errors::ShellError; use nu_plugin::Plugin; -use nu_protocol::{ReturnValue, Signature, Value}; +use nu_protocol::{ + CallInfo, ReturnSuccess, ReturnValue, Signature, SyntaxShape, UntaggedValue, Value, +}; use rand::seq::SliceRandom; use rand::thread_rng; @@ -8,21 +10,42 @@ use rand::thread_rng; #[derive(Default)] pub struct Shuffle { values: Vec, + limit: Option, } impl Shuffle { pub fn new() -> Self { Self::default() } + + pub fn setup(&mut self, call_info: CallInfo) -> ReturnValue { + self.limit = if let Some(value) = call_info.args.get("num") { + Some(value.as_u64()?) + } else { + None + }; + ReturnSuccess::value(UntaggedValue::nothing().into_untagged_value()) + } } impl Plugin for Shuffle { fn config(&mut self) -> Result { Ok(Signature::build("shuffle") .desc("Shuffle input randomly") + .named( + "num", + SyntaxShape::Int, + "Limit output to `num` number of values", + Some('n'), + ) .filter()) } + fn begin_filter(&mut self, call_info: CallInfo) -> Result, ShellError> { + self.setup(call_info)?; + Ok(vec![]) + } + fn filter(&mut self, input: Value) -> Result, ShellError> { self.values.push(input.into()); Ok(vec![]) @@ -30,7 +53,12 @@ impl Plugin for Shuffle { fn end_filter(&mut self) -> Result, ShellError> { let mut rng = thread_rng(); - self.values.shuffle(&mut rng); - Ok(self.values.clone()) + if let Some(n) = self.limit { + let (shuffled, _) = self.values.partial_shuffle(&mut rng, n as usize); + Ok(shuffled.to_vec()) + } else { + self.values.shuffle(&mut rng); + Ok(self.values.clone()) + } } }