add a threads parameter to par_each (#8679)

# Description

This PR allows you to control the amount of threads that `par-each` uses
via a `--threads(-t)` parameter. When no threads parameter is specified,
`par-each` uses the default, which is the same number of available CPUs
on your system.


![image](https://user-images.githubusercontent.com/343840/228935152-eca5b06b-4e8d-41be-82c4-ecd49cdf1fe1.png)

closes #4407

# User-Facing Changes

New parameter

# Tests + Formatting

Don't forget to add tests that cover your changes.

Make sure you've run and fixed any issues with these commands:

- `cargo fmt --all -- --check` to check standard code formatting (`cargo
fmt --all` applies these changes)
- `cargo clippy --workspace -- -D warnings -D clippy::unwrap_used -A
clippy::needless_collect` to check that you're using the standard code
style
- `cargo test --workspace` to check that all tests pass
- `cargo run -- crates/nu-utils/standard_library/tests.nu` to run the
tests for the standard library

> **Note**
> from `nushell` you can also use the `toolkit` as follows
> ```bash
> use toolkit.nu # or use an `env_change` hook to activate it
automatically
> toolkit check pr
> ```

# After Submitting

If your PR had any user-facing changes, update [the
documentation](https://github.com/nushell/nushell.github.io) after the
PR is merged, if necessary. This will help us keep the docs up to date.
This commit is contained in:
Darren Schroeder 2023-03-30 16:39:40 -05:00 committed by GitHub
parent 0e496f900d
commit 09276db2a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -30,6 +30,12 @@ impl Command for ParEach {
),
(Type::Table(vec![]), Type::List(Box::new(Type::Any))),
])
.named(
"threads",
SyntaxShape::Int,
"the number of threads to use",
Some('t'),
)
.required(
"closure",
SyntaxShape::Closure(Some(vec![SyntaxShape::Any, SyntaxShape::Int])),
@ -85,8 +91,27 @@ impl Command for ParEach {
call: &Call,
input: PipelineData,
) -> Result<PipelineData, ShellError> {
let capture_block: Closure = call.req(engine_state, stack, 0)?;
fn create_pool(num_threads: usize) -> Result<rayon::ThreadPool, ShellError> {
match rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build()
{
Err(e) => Err(e).map_err(|e| {
ShellError::GenericError(
"Error creating thread pool".into(),
e.to_string(),
Some(Span::unknown()),
None,
Vec::new(),
)
}),
Ok(pool) => Ok(pool),
}
}
let capture_block: Closure = call.req(engine_state, stack, 0)?;
let threads: Option<usize> = call.get_flag(engine_state, stack, "threads")?;
let max_threads = threads.unwrap_or(0);
let metadata = input.metadata();
let ctrlc = engine_state.ctrlc.clone();
let block_id = capture_block.block_id;
@ -96,8 +121,10 @@ impl Command for ParEach {
match input {
PipelineData::Empty => Ok(PipelineData::Empty),
PipelineData::Value(Value::Range { val, .. }, ..) => Ok(val
.into_range_iter(ctrlc.clone())?
PipelineData::Value(Value::Range { val, .. }, ..) => Ok(create_pool(max_threads)?
.install(|| {
val.into_range_iter(ctrlc.clone())
.expect("unable to create a range iterator")
.par_bridge()
.map(move |x| {
let block = engine_state.get_block(block_id);
@ -129,9 +156,45 @@ impl Command for ParEach {
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)),
PipelineData::Value(Value::List { vals: val, .. }, ..) => Ok(val
.into_pipeline_data(ctrlc)
})),
PipelineData::Value(Value::List { vals: val, .. }, ..) => Ok(create_pool(max_threads)?
.install(|| {
val.par_iter()
.map(move |x| {
let block = engine_state.get_block(block_id);
let mut stack = stack.clone();
if let Some(var) = block.signature.get_positional(0) {
if let Some(var_id) = &var.var_id {
stack.add_var(*var_id, x.clone());
}
}
let val_span = x.span();
match eval_block_with_early_return(
engine_state,
&mut stack,
block,
x.clone().into_pipeline_data(),
redirect_stdout,
redirect_stderr,
) {
Ok(v) => v,
Err(error) => Value::Error {
error: Box::new(chain_error_with_input(error, val_span)),
}
.into_pipeline_data(),
}
})
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)
})),
PipelineData::ListStream(stream, ..) => Ok(create_pool(max_threads)?.install(|| {
stream
.par_bridge()
.map(move |x| {
let block = engine_state.get_block(block_id);
@ -163,45 +226,14 @@ impl Command for ParEach {
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)),
PipelineData::ListStream(stream, ..) => Ok(stream
.par_bridge()
.map(move |x| {
let block = engine_state.get_block(block_id);
let mut stack = stack.clone();
if let Some(var) = block.signature.get_positional(0) {
if let Some(var_id) = &var.var_id {
stack.add_var(*var_id, x.clone());
}
}
let val_span = x.span();
match eval_block_with_early_return(
engine_state,
&mut stack,
block,
x.into_pipeline_data(),
redirect_stdout,
redirect_stderr,
) {
Ok(v) => v,
Err(error) => Value::Error {
error: Box::new(chain_error_with_input(error, val_span)),
}
.into_pipeline_data(),
}
})
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)),
.into_pipeline_data(ctrlc)
})),
PipelineData::ExternalStream { stdout: None, .. } => Ok(PipelineData::empty()),
PipelineData::ExternalStream {
stdout: Some(stream),
..
} => Ok(stream
} => Ok(create_pool(max_threads)?.install(|| {
stream
.par_bridge()
.map(move |x| {
let x = match x {
@ -242,10 +274,12 @@ impl Command for ParEach {
.collect::<Vec<_>>()
.into_iter()
.flatten()
.into_pipeline_data(ctrlc)),
.into_pipeline_data(ctrlc)
})),
// This match allows non-iterables to be accepted,
// which is currently considered undesirable (Nov 2022).
PipelineData::Value(x, ..) => {
eprint!("value");
let block = engine_state.get_block(block_id);
if let Some(var) = block.signature.get_positional(0) {