From 09276db2a5dcc7157b63f330cc56d668dd51735d Mon Sep 17 00:00:00 2001 From: Darren Schroeder <343840+fdncred@users.noreply.github.com> Date: Thu, 30 Mar 2023 16:39:40 -0500 Subject: [PATCH] add a `threads` parameter to `par_each` (#8679) # Description This PR allows you to control the amount of threads that `par-each` uses via a `--threads(-t)` parameter. When no threads parameter is specified, `par-each` uses the default, which is the same number of available CPUs on your system. ![image](https://user-images.githubusercontent.com/343840/228935152-eca5b06b-4e8d-41be-82c4-ecd49cdf1fe1.png) closes #4407 # User-Facing Changes New parameter # Tests + Formatting Don't forget to add tests that cover your changes. Make sure you've run and fixed any issues with these commands: - `cargo fmt --all -- --check` to check standard code formatting (`cargo fmt --all` applies these changes) - `cargo clippy --workspace -- -D warnings -D clippy::unwrap_used -A clippy::needless_collect` to check that you're using the standard code style - `cargo test --workspace` to check that all tests pass - `cargo run -- crates/nu-utils/standard_library/tests.nu` to run the tests for the standard library > **Note** > from `nushell` you can also use the `toolkit` as follows > ```bash > use toolkit.nu # or use an `env_change` hook to activate it automatically > toolkit check pr > ``` # After Submitting If your PR had any user-facing changes, update [the documentation](https://github.com/nushell/nushell.github.io) after the PR is merged, if necessary. This will help us keep the docs up to date. --- crates/nu-command/src/filters/par_each.rs | 304 ++++++++++++---------- 1 file changed, 169 insertions(+), 135 deletions(-) diff --git a/crates/nu-command/src/filters/par_each.rs b/crates/nu-command/src/filters/par_each.rs index 7fc8abfc1c..6385bbdffe 100644 --- a/crates/nu-command/src/filters/par_each.rs +++ b/crates/nu-command/src/filters/par_each.rs @@ -30,6 +30,12 @@ impl Command for ParEach { ), (Type::Table(vec![]), Type::List(Box::new(Type::Any))), ]) + .named( + "threads", + SyntaxShape::Int, + "the number of threads to use", + Some('t'), + ) .required( "closure", SyntaxShape::Closure(Some(vec![SyntaxShape::Any, SyntaxShape::Int])), @@ -85,8 +91,27 @@ impl Command for ParEach { call: &Call, input: PipelineData, ) -> Result { - let capture_block: Closure = call.req(engine_state, stack, 0)?; + fn create_pool(num_threads: usize) -> Result { + match rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + { + Err(e) => Err(e).map_err(|e| { + ShellError::GenericError( + "Error creating thread pool".into(), + e.to_string(), + Some(Span::unknown()), + None, + Vec::new(), + ) + }), + Ok(pool) => Ok(pool), + } + } + let capture_block: Closure = call.req(engine_state, stack, 0)?; + let threads: Option = call.get_flag(engine_state, stack, "threads")?; + let max_threads = threads.unwrap_or(0); let metadata = input.metadata(); let ctrlc = engine_state.ctrlc.clone(); let block_id = capture_block.block_id; @@ -96,156 +121,165 @@ impl Command for ParEach { match input { PipelineData::Empty => Ok(PipelineData::Empty), - PipelineData::Value(Value::Range { val, .. }, ..) => Ok(val - .into_range_iter(ctrlc.clone())? - .par_bridge() - .map(move |x| { - let block = engine_state.get_block(block_id); + PipelineData::Value(Value::Range { val, .. }, ..) => Ok(create_pool(max_threads)? + .install(|| { + val.into_range_iter(ctrlc.clone()) + .expect("unable to create a range iterator") + .par_bridge() + .map(move |x| { + let block = engine_state.get_block(block_id); - let mut stack = stack.clone(); + let mut stack = stack.clone(); - if let Some(var) = block.signature.get_positional(0) { - if let Some(var_id) = &var.var_id { - stack.add_var(*var_id, x.clone()); + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x.clone()); + } + } + + let val_span = x.span(); + match eval_block_with_early_return( + engine_state, + &mut stack, + block, + x.into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + Ok(v) => v, + Err(error) => Value::Error { + error: Box::new(chain_error_with_input(error, val_span)), + } + .into_pipeline_data(), + } + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data(ctrlc) + })), + PipelineData::Value(Value::List { vals: val, .. }, ..) => Ok(create_pool(max_threads)? + .install(|| { + val.par_iter() + .map(move |x| { + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x.clone()); + } + } + + let val_span = x.span(); + match eval_block_with_early_return( + engine_state, + &mut stack, + block, + x.clone().into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + Ok(v) => v, + Err(error) => Value::Error { + error: Box::new(chain_error_with_input(error, val_span)), + } + .into_pipeline_data(), + } + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data(ctrlc) + })), + PipelineData::ListStream(stream, ..) => Ok(create_pool(max_threads)?.install(|| { + stream + .par_bridge() + .map(move |x| { + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x.clone()); + } } - } - let val_span = x.span(); - match eval_block_with_early_return( - engine_state, - &mut stack, - block, - x.into_pipeline_data(), - redirect_stdout, - redirect_stderr, - ) { - Ok(v) => v, - Err(error) => Value::Error { - error: Box::new(chain_error_with_input(error, val_span)), + let val_span = x.span(); + match eval_block_with_early_return( + engine_state, + &mut stack, + block, + x.into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + Ok(v) => v, + Err(error) => Value::Error { + error: Box::new(chain_error_with_input(error, val_span)), + } + .into_pipeline_data(), } - .into_pipeline_data(), - } - }) - .collect::>() - .into_iter() - .flatten() - .into_pipeline_data(ctrlc)), - PipelineData::Value(Value::List { vals: val, .. }, ..) => Ok(val - .into_iter() - .par_bridge() - .map(move |x| { - let block = engine_state.get_block(block_id); - - let mut stack = stack.clone(); - - if let Some(var) = block.signature.get_positional(0) { - if let Some(var_id) = &var.var_id { - stack.add_var(*var_id, x.clone()); - } - } - - let val_span = x.span(); - match eval_block_with_early_return( - engine_state, - &mut stack, - block, - x.into_pipeline_data(), - redirect_stdout, - redirect_stderr, - ) { - Ok(v) => v, - Err(error) => Value::Error { - error: Box::new(chain_error_with_input(error, val_span)), - } - .into_pipeline_data(), - } - }) - .collect::>() - .into_iter() - .flatten() - .into_pipeline_data(ctrlc)), - PipelineData::ListStream(stream, ..) => Ok(stream - .par_bridge() - .map(move |x| { - let block = engine_state.get_block(block_id); - - let mut stack = stack.clone(); - - if let Some(var) = block.signature.get_positional(0) { - if let Some(var_id) = &var.var_id { - stack.add_var(*var_id, x.clone()); - } - } - - let val_span = x.span(); - match eval_block_with_early_return( - engine_state, - &mut stack, - block, - x.into_pipeline_data(), - redirect_stdout, - redirect_stderr, - ) { - Ok(v) => v, - Err(error) => Value::Error { - error: Box::new(chain_error_with_input(error, val_span)), - } - .into_pipeline_data(), - } - }) - .collect::>() - .into_iter() - .flatten() - .into_pipeline_data(ctrlc)), + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data(ctrlc) + })), PipelineData::ExternalStream { stdout: None, .. } => Ok(PipelineData::empty()), PipelineData::ExternalStream { stdout: Some(stream), .. - } => Ok(stream - .par_bridge() - .map(move |x| { - let x = match x { - Ok(x) => x, - Err(err) => { - return Value::Error { - error: Box::new(err), + } => Ok(create_pool(max_threads)?.install(|| { + stream + .par_bridge() + .map(move |x| { + let x = match x { + Ok(x) => x, + Err(err) => { + return Value::Error { + error: Box::new(err), + } + .into_pipeline_data() + } + }; + + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x.clone()); } - .into_pipeline_data() } - }; - let block = engine_state.get_block(block_id); - - let mut stack = stack.clone(); - - if let Some(var) = block.signature.get_positional(0) { - if let Some(var_id) = &var.var_id { - stack.add_var(*var_id, x.clone()); + match eval_block_with_early_return( + engine_state, + &mut stack, + block, + x.into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + Ok(v) => v, + Err(error) => Value::Error { + error: Box::new(error), + } + .into_pipeline_data(), } - } - - match eval_block_with_early_return( - engine_state, - &mut stack, - block, - x.into_pipeline_data(), - redirect_stdout, - redirect_stderr, - ) { - Ok(v) => v, - Err(error) => Value::Error { - error: Box::new(error), - } - .into_pipeline_data(), - } - }) - .collect::>() - .into_iter() - .flatten() - .into_pipeline_data(ctrlc)), + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data(ctrlc) + })), // This match allows non-iterables to be accepted, // which is currently considered undesirable (Nov 2022). PipelineData::Value(x, ..) => { + eprint!("value"); let block = engine_state.get_block(block_id); if let Some(var) = block.signature.get_positional(0) {