diff --git a/crates/nu-command/src/filters/par_each.rs b/crates/nu-command/src/filters/par_each.rs index 7fc8abfc1c..6385bbdffe 100644 --- a/crates/nu-command/src/filters/par_each.rs +++ b/crates/nu-command/src/filters/par_each.rs @@ -30,6 +30,12 @@ impl Command for ParEach { ), (Type::Table(vec![]), Type::List(Box::new(Type::Any))), ]) + .named( + "threads", + SyntaxShape::Int, + "the number of threads to use", + Some('t'), + ) .required( "closure", SyntaxShape::Closure(Some(vec![SyntaxShape::Any, SyntaxShape::Int])), @@ -85,8 +91,27 @@ impl Command for ParEach { call: &Call, input: PipelineData, ) -> Result { - let capture_block: Closure = call.req(engine_state, stack, 0)?; + fn create_pool(num_threads: usize) -> Result { + match rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + { + Err(e) => Err(e).map_err(|e| { + ShellError::GenericError( + "Error creating thread pool".into(), + e.to_string(), + Some(Span::unknown()), + None, + Vec::new(), + ) + }), + Ok(pool) => Ok(pool), + } + } + let capture_block: Closure = call.req(engine_state, stack, 0)?; + let threads: Option = call.get_flag(engine_state, stack, "threads")?; + let max_threads = threads.unwrap_or(0); let metadata = input.metadata(); let ctrlc = engine_state.ctrlc.clone(); let block_id = capture_block.block_id; @@ -96,156 +121,165 @@ impl Command for ParEach { match input { PipelineData::Empty => Ok(PipelineData::Empty), - PipelineData::Value(Value::Range { val, .. }, ..) => Ok(val - .into_range_iter(ctrlc.clone())? - .par_bridge() - .map(move |x| { - let block = engine_state.get_block(block_id); + PipelineData::Value(Value::Range { val, .. }, ..) => Ok(create_pool(max_threads)? + .install(|| { + val.into_range_iter(ctrlc.clone()) + .expect("unable to create a range iterator") + .par_bridge() + .map(move |x| { + let block = engine_state.get_block(block_id); - let mut stack = stack.clone(); + let mut stack = stack.clone(); - if let Some(var) = block.signature.get_positional(0) { - if let Some(var_id) = &var.var_id { - stack.add_var(*var_id, x.clone()); + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x.clone()); + } + } + + let val_span = x.span(); + match eval_block_with_early_return( + engine_state, + &mut stack, + block, + x.into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + Ok(v) => v, + Err(error) => Value::Error { + error: Box::new(chain_error_with_input(error, val_span)), + } + .into_pipeline_data(), + } + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data(ctrlc) + })), + PipelineData::Value(Value::List { vals: val, .. }, ..) => Ok(create_pool(max_threads)? + .install(|| { + val.par_iter() + .map(move |x| { + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x.clone()); + } + } + + let val_span = x.span(); + match eval_block_with_early_return( + engine_state, + &mut stack, + block, + x.clone().into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + Ok(v) => v, + Err(error) => Value::Error { + error: Box::new(chain_error_with_input(error, val_span)), + } + .into_pipeline_data(), + } + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data(ctrlc) + })), + PipelineData::ListStream(stream, ..) => Ok(create_pool(max_threads)?.install(|| { + stream + .par_bridge() + .map(move |x| { + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x.clone()); + } } - } - let val_span = x.span(); - match eval_block_with_early_return( - engine_state, - &mut stack, - block, - x.into_pipeline_data(), - redirect_stdout, - redirect_stderr, - ) { - Ok(v) => v, - Err(error) => Value::Error { - error: Box::new(chain_error_with_input(error, val_span)), + let val_span = x.span(); + match eval_block_with_early_return( + engine_state, + &mut stack, + block, + x.into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + Ok(v) => v, + Err(error) => Value::Error { + error: Box::new(chain_error_with_input(error, val_span)), + } + .into_pipeline_data(), } - .into_pipeline_data(), - } - }) - .collect::>() - .into_iter() - .flatten() - .into_pipeline_data(ctrlc)), - PipelineData::Value(Value::List { vals: val, .. }, ..) => Ok(val - .into_iter() - .par_bridge() - .map(move |x| { - let block = engine_state.get_block(block_id); - - let mut stack = stack.clone(); - - if let Some(var) = block.signature.get_positional(0) { - if let Some(var_id) = &var.var_id { - stack.add_var(*var_id, x.clone()); - } - } - - let val_span = x.span(); - match eval_block_with_early_return( - engine_state, - &mut stack, - block, - x.into_pipeline_data(), - redirect_stdout, - redirect_stderr, - ) { - Ok(v) => v, - Err(error) => Value::Error { - error: Box::new(chain_error_with_input(error, val_span)), - } - .into_pipeline_data(), - } - }) - .collect::>() - .into_iter() - .flatten() - .into_pipeline_data(ctrlc)), - PipelineData::ListStream(stream, ..) => Ok(stream - .par_bridge() - .map(move |x| { - let block = engine_state.get_block(block_id); - - let mut stack = stack.clone(); - - if let Some(var) = block.signature.get_positional(0) { - if let Some(var_id) = &var.var_id { - stack.add_var(*var_id, x.clone()); - } - } - - let val_span = x.span(); - match eval_block_with_early_return( - engine_state, - &mut stack, - block, - x.into_pipeline_data(), - redirect_stdout, - redirect_stderr, - ) { - Ok(v) => v, - Err(error) => Value::Error { - error: Box::new(chain_error_with_input(error, val_span)), - } - .into_pipeline_data(), - } - }) - .collect::>() - .into_iter() - .flatten() - .into_pipeline_data(ctrlc)), + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data(ctrlc) + })), PipelineData::ExternalStream { stdout: None, .. } => Ok(PipelineData::empty()), PipelineData::ExternalStream { stdout: Some(stream), .. - } => Ok(stream - .par_bridge() - .map(move |x| { - let x = match x { - Ok(x) => x, - Err(err) => { - return Value::Error { - error: Box::new(err), + } => Ok(create_pool(max_threads)?.install(|| { + stream + .par_bridge() + .map(move |x| { + let x = match x { + Ok(x) => x, + Err(err) => { + return Value::Error { + error: Box::new(err), + } + .into_pipeline_data() + } + }; + + let block = engine_state.get_block(block_id); + + let mut stack = stack.clone(); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, x.clone()); } - .into_pipeline_data() } - }; - let block = engine_state.get_block(block_id); - - let mut stack = stack.clone(); - - if let Some(var) = block.signature.get_positional(0) { - if let Some(var_id) = &var.var_id { - stack.add_var(*var_id, x.clone()); + match eval_block_with_early_return( + engine_state, + &mut stack, + block, + x.into_pipeline_data(), + redirect_stdout, + redirect_stderr, + ) { + Ok(v) => v, + Err(error) => Value::Error { + error: Box::new(error), + } + .into_pipeline_data(), } - } - - match eval_block_with_early_return( - engine_state, - &mut stack, - block, - x.into_pipeline_data(), - redirect_stdout, - redirect_stderr, - ) { - Ok(v) => v, - Err(error) => Value::Error { - error: Box::new(error), - } - .into_pipeline_data(), - } - }) - .collect::>() - .into_iter() - .flatten() - .into_pipeline_data(ctrlc)), + }) + .collect::>() + .into_iter() + .flatten() + .into_pipeline_data(ctrlc) + })), // This match allows non-iterables to be accepted, // which is currently considered undesirable (Nov 2022). PipelineData::Value(x, ..) => { + eprint!("value"); let block = engine_state.get_block(block_id); if let Some(var) = block.signature.get_positional(0) {