diff --git a/crates/nu-cmd-extra/src/extra/bits/mod.rs b/crates/nu-cmd-extra/src/extra/bits/mod.rs index 6d1200a6bf..145d66d777 100644 --- a/crates/nu-cmd-extra/src/extra/bits/mod.rs +++ b/crates/nu-cmd-extra/src/extra/bits/mod.rs @@ -44,6 +44,25 @@ enum InputNumType { SignedEight, } +impl InputNumType { + fn num_bits(self) -> u32 { + match self { + InputNumType::One => 8, + InputNumType::Two => 16, + InputNumType::Four => 32, + InputNumType::Eight => 64, + InputNumType::SignedOne => 8, + InputNumType::SignedTwo => 16, + InputNumType::SignedFour => 32, + InputNumType::SignedEight => 64, + } + } + + fn is_permitted_bit_shift(self, bits: u32) -> bool { + bits < self.num_bits() + } +} + fn get_number_bytes( number_bytes: Option>, head: Span, diff --git a/crates/nu-cmd-extra/src/extra/bits/shift_left.rs b/crates/nu-cmd-extra/src/extra/bits/shift_left.rs index 6a67a45e0e..e3b484a7d4 100644 --- a/crates/nu-cmd-extra/src/extra/bits/shift_left.rs +++ b/crates/nu-cmd-extra/src/extra/bits/shift_left.rs @@ -7,7 +7,7 @@ use std::iter; struct Arguments { signed: bool, - bits: usize, + bits: Spanned, number_size: NumberBytes, } @@ -71,7 +71,9 @@ impl Command for BitsShl { input: PipelineData, ) -> Result { let head = call.head; - let bits: usize = call.req(engine_state, stack, 0)?; + // This restricts to a positive shift value (our underlying operations do not + // permit them) + let bits: Spanned = call.req(engine_state, stack, 0)?; let signed = call.has_flag(engine_state, stack, "signed")?; let number_bytes: Option> = call.get_flag(engine_state, stack, "number-bytes")?; @@ -131,14 +133,29 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value { number_size, bits, } = *args; + let bits_span = bits.span; + let bits = bits.item; match input { Value::Int { val, .. } => { use InputNumType::*; let val = *val; - let bits = bits as u64; + let bits = bits as u32; let input_num_type = get_input_num_type(val, signed, number_size); + if !input_num_type.is_permitted_bit_shift(bits) { + return Value::error( + ShellError::IncorrectValue { + msg: format!( + "Trying to shift by more than the available bits (permitted < {})", + input_num_type.num_bits() + ), + val_span: bits_span, + call_span: span, + }, + span, + ); + } let int = match input_num_type { One => ((val as u8) << bits) as i64, Two => ((val as u16) << bits) as i64, @@ -147,12 +164,14 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value { let Ok(i) = i64::try_from((val as u64) << bits) else { return Value::error( ShellError::GenericError { - error: "result out of range for specified number".into(), + error: "result out of range for int".into(), msg: format!( "shifting left by {bits} is out of range for the value {val}" ), span: Some(span), - help: None, + help: Some( + "Ensure the result fits in a 64-bit signed integer.".into(), + ), inner: vec![], }, span, @@ -172,19 +191,26 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value { let byte_shift = bits / 8; let bit_shift = bits % 8; - use itertools::Position::*; - let bytes = val - .iter() - .copied() - .skip(byte_shift) - .circular_tuple_windows::<(u8, u8)>() - .with_position() - .map(|(pos, (lhs, rhs))| match pos { - Last | Only => lhs << bit_shift, - _ => (lhs << bit_shift) | (rhs >> bit_shift), - }) - .chain(iter::repeat(0).take(byte_shift)) - .collect::>(); + // This is purely for symmetry with the int case and the fact that the + // shift right implementation in its current form panicked with an overflow + if bits > val.len() * 8 { + return Value::error( + ShellError::IncorrectValue { + msg: format!( + "Trying to shift by more than the available bits ({})", + val.len() * 8 + ), + val_span: bits_span, + call_span: span, + }, + span, + ); + } + let bytes = if bit_shift == 0 { + shift_bytes_left(val, byte_shift) + } else { + shift_bytes_and_bits_left(val, byte_shift, bit_shift) + }; Value::binary(bytes, span) } @@ -202,6 +228,31 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value { } } +fn shift_bytes_left(data: &[u8], byte_shift: usize) -> Vec { + let len = data.len(); + let mut output = vec![0; len]; + output[..len - byte_shift].copy_from_slice(&data[byte_shift..]); + output +} + +fn shift_bytes_and_bits_left(data: &[u8], byte_shift: usize, bit_shift: usize) -> Vec { + use itertools::Position::*; + debug_assert!((1..8).contains(&bit_shift), + "Bit shifts of 0 can't be handled by this impl and everything else should be part of the byteshift" + ); + data.iter() + .copied() + .skip(byte_shift) + .circular_tuple_windows::<(u8, u8)>() + .with_position() + .map(|(pos, (lhs, rhs))| match pos { + Last | Only => lhs << bit_shift, + _ => (lhs << bit_shift) | (rhs >> (8 - bit_shift)), + }) + .chain(iter::repeat(0).take(byte_shift)) + .collect::>() +} + #[cfg(test)] mod test { use super::*; diff --git a/crates/nu-cmd-extra/src/extra/bits/shift_right.rs b/crates/nu-cmd-extra/src/extra/bits/shift_right.rs index e45e10ac94..9bb5b1563a 100644 --- a/crates/nu-cmd-extra/src/extra/bits/shift_right.rs +++ b/crates/nu-cmd-extra/src/extra/bits/shift_right.rs @@ -1,13 +1,10 @@ use super::{get_input_num_type, get_number_bytes, InputNumType, NumberBytes}; -use itertools::Itertools; use nu_cmd_base::input_handler::{operate, CmdArgument}; use nu_engine::command_prelude::*; -use std::iter; - struct Arguments { signed: bool, - bits: usize, + bits: Spanned, number_size: NumberBytes, } @@ -71,7 +68,9 @@ impl Command for BitsShr { input: PipelineData, ) -> Result { let head = call.head; - let bits: usize = call.req(engine_state, stack, 0)?; + // This restricts to a positive shift value (our underlying operations do not + // permit them) + let bits: Spanned = call.req(engine_state, stack, 0)?; let signed = call.has_flag(engine_state, stack, "signed")?; let number_bytes: Option> = call.get_flag(engine_state, stack, "number-bytes")?; @@ -121,6 +120,8 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value { number_size, bits, } = *args; + let bits_span = bits.span; + let bits = bits.item; match input { Value::Int { val, .. } => { @@ -129,6 +130,19 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value { let bits = bits as u32; let input_num_type = get_input_num_type(val, signed, number_size); + if !input_num_type.is_permitted_bit_shift(bits) { + return Value::error( + ShellError::IncorrectValue { + msg: format!( + "Trying to shift by more than the available bits (permitted < {})", + input_num_type.num_bits() + ), + val_span: bits_span, + call_span: span, + }, + span, + ); + } let int = match input_num_type { One => ((val as u8) >> bits) as i64, Two => ((val as u16) >> bits) as i64, @@ -147,21 +161,27 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value { let bit_shift = bits % 8; let len = val.len(); - use itertools::Position::*; - let bytes = iter::repeat(0) - .take(byte_shift) - .chain( - val.iter() - .copied() - .circular_tuple_windows::<(u8, u8)>() - .with_position() - .map(|(pos, (lhs, rhs))| match pos { - First | Only => lhs >> bit_shift, - _ => (lhs >> bit_shift) | (rhs << bit_shift), - }) - .take(len - byte_shift), - ) - .collect::>(); + // This check is done for symmetry with the int case and the previous + // implementation would overflow byte indices leading to unexpected output + // lengths + if bits > len * 8 { + return Value::error( + ShellError::IncorrectValue { + msg: format!( + "Trying to shift by more than the available bits ({})", + len * 8 + ), + val_span: bits_span, + call_span: span, + }, + span, + ); + } + let bytes = if bit_shift == 0 { + shift_bytes_right(val, byte_shift) + } else { + shift_bytes_and_bits_right(val, byte_shift, bit_shift) + }; Value::binary(bytes, span) } @@ -178,6 +198,35 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value { ), } } +fn shift_bytes_right(data: &[u8], byte_shift: usize) -> Vec { + let len = data.len(); + let mut output = vec![0; len]; + output[byte_shift..].copy_from_slice(&data[..len - byte_shift]); + output +} + +fn shift_bytes_and_bits_right(data: &[u8], byte_shift: usize, bit_shift: usize) -> Vec { + debug_assert!( + bit_shift > 0 && bit_shift < 8, + "bit_shift should be in the range (0, 8)" + ); + let len = data.len(); + let mut output = vec![0; len]; + + for i in byte_shift..len { + let shifted_bits = data[i - byte_shift] >> bit_shift; + let carried_bits = if i > byte_shift { + data[i - byte_shift - 1] << (8 - bit_shift) + } else { + 0 + }; + let shifted_byte = shifted_bits | carried_bits; + + output[i] = shifted_byte; + } + + output +} #[cfg(test)] mod test { diff --git a/crates/nu-protocol/src/value/mod.rs b/crates/nu-protocol/src/value/mod.rs index edd51d70a6..3620be698c 100644 --- a/crates/nu-protocol/src/value/mod.rs +++ b/crates/nu-protocol/src/value/mod.rs @@ -3317,7 +3317,18 @@ impl Value { pub fn bit_shl(&self, op: Span, rhs: &Value, span: Span) -> Result { match (self, rhs) { (Value::Int { val: lhs, .. }, Value::Int { val: rhs, .. }) => { - Ok(Value::int(*lhs << rhs, span)) + // Currently we disallow negative operands like Rust's `Shl` + // Cheap guarding with TryInto + if let Some(val) = (*rhs).try_into().ok().and_then(|rhs| lhs.checked_shl(rhs)) { + Ok(Value::int(val, span)) + } else { + Err(ShellError::OperatorOverflow { + msg: "right operand to bit-shl exceeds available bits in underlying data" + .into(), + span, + help: format!("Limit operand to 0 <= rhs < {}", i64::BITS), + }) + } } (Value::Custom { val: lhs, .. }, rhs) => { lhs.operation(span, Operator::Bits(Bits::ShiftLeft), op, rhs) @@ -3335,7 +3346,18 @@ impl Value { pub fn bit_shr(&self, op: Span, rhs: &Value, span: Span) -> Result { match (self, rhs) { (Value::Int { val: lhs, .. }, Value::Int { val: rhs, .. }) => { - Ok(Value::int(*lhs >> rhs, span)) + // Currently we disallow negative operands like Rust's `Shr` + // Cheap guarding with TryInto + if let Some(val) = (*rhs).try_into().ok().and_then(|rhs| lhs.checked_shr(rhs)) { + Ok(Value::int(val, span)) + } else { + Err(ShellError::OperatorOverflow { + msg: "right operand to bit-shr exceeds available bits in underlying data" + .into(), + span, + help: format!("Limit operand to 0 <= rhs < {}", i64::BITS), + }) + } } (Value::Custom { val: lhs, .. }, rhs) => { lhs.operation(span, Operator::Bits(Bits::ShiftRight), op, rhs) diff --git a/tests/repl/test_bits.rs b/tests/repl/test_bits.rs index 410f48d83b..8848ea9f98 100644 --- a/tests/repl/test_bits.rs +++ b/tests/repl/test_bits.rs @@ -1,4 +1,4 @@ -use crate::repl::tests::{run_test, TestResult}; +use crate::repl::tests::{fail_test, run_test, TestResult}; #[test] fn bits_and() -> TestResult { @@ -56,6 +56,33 @@ fn bits_shift_left() -> TestResult { run_test("2 | bits shl 3", "16") } +#[test] +fn bits_shift_left_negative_operand() -> TestResult { + fail_test("8 | bits shl -2", "positive value") +} + +#[test] +fn bits_shift_left_exceeding1() -> TestResult { + // We have no type accepting more than 64 bits so guaranteed fail + fail_test("8 | bits shl 65", "more than the available bits") +} + +#[test] +fn bits_shift_left_exceeding2() -> TestResult { + // Explicitly specifying 2 bytes, but 16 is already the max + fail_test( + "8 | bits shl --number-bytes 2 16", + "more than the available bits", + ) +} + +#[test] +fn bits_shift_left_exceeding3() -> TestResult { + // This is purely down to the current autodetect feature limiting to the smallest integer + // type thus assuming a u8 + fail_test("8 | bits shl 9", "more than the available bits") +} + #[test] fn bits_shift_left_negative() -> TestResult { run_test("-3 | bits shl 5", "-96") @@ -69,11 +96,79 @@ fn bits_shift_left_list() -> TestResult { ) } +#[test] +fn bits_shift_left_binary1() -> TestResult { + run_test( + "0x[01 30 80] | bits shl 3 | into bits", + "00001001 10000100 00000000", + ) +} + +#[test] +fn bits_shift_left_binary2() -> TestResult { + // Whole byte case + run_test( + "0x[01 30 80] | bits shl 8 | into bits", + "00110000 10000000 00000000", + ) +} + +#[test] +fn bits_shift_left_binary3() -> TestResult { + // Compared to the int case this is made inclusive of the bit count + run_test( + "0x[01 30 80] | bits shl 24 | into bits", + "00000000 00000000 00000000", + ) +} + +#[test] +fn bits_shift_left_binary4() -> TestResult { + // Shifting by both bytes and bits + run_test( + "0x[01 30 80] | bits shl 15 | into bits", + "01000000 00000000 00000000", + ) +} + +#[test] +fn bits_shift_left_binary_exceeding() -> TestResult { + // Compared to the int case this is made inclusive of the bit count + fail_test("0x[01 30] | bits shl 17 | into bits", "") +} + #[test] fn bits_shift_right() -> TestResult { run_test("8 | bits shr 2", "2") } +#[test] +fn bits_shift_right_negative_operand() -> TestResult { + fail_test("8 | bits shr -2", "positive value") +} + +#[test] +fn bits_shift_right_exceeding1() -> TestResult { + // We have no type accepting more than 64 bits so guaranteed fail + fail_test("8 | bits shr 65", "more than the available bits") +} + +#[test] +fn bits_shift_right_exceeding2() -> TestResult { + // Explicitly specifying 2 bytes, but 16 is already the max + fail_test( + "8 | bits shr --number-bytes 2 16", + "more than the available bits", + ) +} + +#[test] +fn bits_shift_right_exceeding3() -> TestResult { + // This is purely down to the current autodetect feature limiting to the smallest integer + // type thus assuming a u8 + fail_test("8 | bits shr 9", "more than the available bits") +} + #[test] fn bits_shift_right_negative() -> TestResult { run_test("-32 | bits shr 2", "-8") @@ -87,6 +182,47 @@ fn bits_shift_right_list() -> TestResult { ) } +#[test] +fn bits_shift_right_binary1() -> TestResult { + run_test( + "0x[01 30 80] | bits shr 3 | into bits", + "00000000 00100110 00010000", + ) +} + +#[test] +fn bits_shift_right_binary2() -> TestResult { + // Whole byte case + run_test( + "0x[01 30 80] | bits shr 8 | into bits", + "00000000 00000001 00110000", + ) +} + +#[test] +fn bits_shift_right_binary3() -> TestResult { + // Compared to the int case this is made inclusive of the bit count + run_test( + "0x[01 30 80] | bits shr 24 | into bits", + "00000000 00000000 00000000", + ) +} + +#[test] +fn bits_shift_right_binary4() -> TestResult { + // Shifting by both bytes and bits + run_test( + "0x[01 30 80] | bits shr 15 | into bits", + "00000000 00000000 00000010", + ) +} + +#[test] +fn bits_shift_right_binary_exceeding() -> TestResult { + // Compared to the int case this is made inclusive of the bit count + fail_test("0x[01 30] | bits shr 17 | into bits", "available bits (16)") +} + #[test] fn bits_rotate_left() -> TestResult { run_test("2 | bits rol 3", "16") diff --git a/tests/repl/test_math.rs b/tests/repl/test_math.rs index af44f8a857..7ad23e1968 100644 --- a/tests/repl/test_math.rs +++ b/tests/repl/test_math.rs @@ -35,6 +35,31 @@ fn bit_shl() -> TestResult { run_test("5 bit-shl 1", "10") } +#[test] +fn bit_shr_overflow() -> TestResult { + fail_test("16 bit-shr 10000", "exceeds available bits") +} + +#[test] +fn bit_shl_overflow() -> TestResult { + fail_test("5 bit-shl 10000000", "exceeds available bits") +} + +#[test] +fn bit_shl_neg_operand() -> TestResult { + // This would overflow the `u32` in the right hand side to 2 + fail_test( + "9 bit-shl -9_223_372_036_854_775_806", + "exceeds available bits", + ) +} + +#[test] +fn bit_shr_neg_operand() -> TestResult { + // This would overflow the `u32` in the right hand side + fail_test("9 bit-shr -2", "exceeds available bits") +} + #[test] fn bit_shl_add() -> TestResult { run_test("2 bit-shl 1 + 2", "16")