Fix bugs and UB in bit shifting ops (#13663)

# Description
Fixes #11267

Shifting by a `shift >= num_bits` is undefined in the underlying
operation. Previously we also had an overflow on negative shifts for the
operators `bit-shl` and `bit-shr`
Furthermore I found a severe bug in the implementation of shifting of
`binary` data with the commands `bits shl` and `bits shr`, this
categorically produced incorrect results with shifts that were not
`shift % 4 == 0`. `bits shr` also was able to produce outputs with
different size to the input if the shift was exceeding the length of the
input data by more than a byte.

# User-Facing Changes
It is now an error trying to shift by more than the available bits with:
- `bit-shl` operator
- `bit-shr` operator
- command `bits shl`
- command `bits shr`

# Tests + Formatting
Added testing for all relevant cases
This commit is contained in:
Stefan Holderbach 2024-08-22 11:54:27 +02:00 committed by GitHub
parent 9261c0c55a
commit 3ab9f0b90a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 343 additions and 41 deletions

View file

@ -44,6 +44,25 @@ enum InputNumType {
SignedEight, SignedEight,
} }
impl InputNumType {
fn num_bits(self) -> u32 {
match self {
InputNumType::One => 8,
InputNumType::Two => 16,
InputNumType::Four => 32,
InputNumType::Eight => 64,
InputNumType::SignedOne => 8,
InputNumType::SignedTwo => 16,
InputNumType::SignedFour => 32,
InputNumType::SignedEight => 64,
}
}
fn is_permitted_bit_shift(self, bits: u32) -> bool {
bits < self.num_bits()
}
}
fn get_number_bytes( fn get_number_bytes(
number_bytes: Option<Spanned<usize>>, number_bytes: Option<Spanned<usize>>,
head: Span, head: Span,

View file

@ -7,7 +7,7 @@ use std::iter;
struct Arguments { struct Arguments {
signed: bool, signed: bool,
bits: usize, bits: Spanned<usize>,
number_size: NumberBytes, number_size: NumberBytes,
} }
@ -71,7 +71,9 @@ impl Command for BitsShl {
input: PipelineData, input: PipelineData,
) -> Result<PipelineData, ShellError> { ) -> Result<PipelineData, ShellError> {
let head = call.head; let head = call.head;
let bits: usize = call.req(engine_state, stack, 0)?; // This restricts to a positive shift value (our underlying operations do not
// permit them)
let bits: Spanned<usize> = call.req(engine_state, stack, 0)?;
let signed = call.has_flag(engine_state, stack, "signed")?; let signed = call.has_flag(engine_state, stack, "signed")?;
let number_bytes: Option<Spanned<usize>> = let number_bytes: Option<Spanned<usize>> =
call.get_flag(engine_state, stack, "number-bytes")?; call.get_flag(engine_state, stack, "number-bytes")?;
@ -131,14 +133,29 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value {
number_size, number_size,
bits, bits,
} = *args; } = *args;
let bits_span = bits.span;
let bits = bits.item;
match input { match input {
Value::Int { val, .. } => { Value::Int { val, .. } => {
use InputNumType::*; use InputNumType::*;
let val = *val; let val = *val;
let bits = bits as u64; let bits = bits as u32;
let input_num_type = get_input_num_type(val, signed, number_size); let input_num_type = get_input_num_type(val, signed, number_size);
if !input_num_type.is_permitted_bit_shift(bits) {
return Value::error(
ShellError::IncorrectValue {
msg: format!(
"Trying to shift by more than the available bits (permitted < {})",
input_num_type.num_bits()
),
val_span: bits_span,
call_span: span,
},
span,
);
}
let int = match input_num_type { let int = match input_num_type {
One => ((val as u8) << bits) as i64, One => ((val as u8) << bits) as i64,
Two => ((val as u16) << bits) as i64, Two => ((val as u16) << bits) as i64,
@ -147,12 +164,14 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value {
let Ok(i) = i64::try_from((val as u64) << bits) else { let Ok(i) = i64::try_from((val as u64) << bits) else {
return Value::error( return Value::error(
ShellError::GenericError { ShellError::GenericError {
error: "result out of range for specified number".into(), error: "result out of range for int".into(),
msg: format!( msg: format!(
"shifting left by {bits} is out of range for the value {val}" "shifting left by {bits} is out of range for the value {val}"
), ),
span: Some(span), span: Some(span),
help: None, help: Some(
"Ensure the result fits in a 64-bit signed integer.".into(),
),
inner: vec![], inner: vec![],
}, },
span, span,
@ -172,19 +191,26 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value {
let byte_shift = bits / 8; let byte_shift = bits / 8;
let bit_shift = bits % 8; let bit_shift = bits % 8;
use itertools::Position::*; // This is purely for symmetry with the int case and the fact that the
let bytes = val // shift right implementation in its current form panicked with an overflow
.iter() if bits > val.len() * 8 {
.copied() return Value::error(
.skip(byte_shift) ShellError::IncorrectValue {
.circular_tuple_windows::<(u8, u8)>() msg: format!(
.with_position() "Trying to shift by more than the available bits ({})",
.map(|(pos, (lhs, rhs))| match pos { val.len() * 8
Last | Only => lhs << bit_shift, ),
_ => (lhs << bit_shift) | (rhs >> bit_shift), val_span: bits_span,
}) call_span: span,
.chain(iter::repeat(0).take(byte_shift)) },
.collect::<Vec<u8>>(); span,
);
}
let bytes = if bit_shift == 0 {
shift_bytes_left(val, byte_shift)
} else {
shift_bytes_and_bits_left(val, byte_shift, bit_shift)
};
Value::binary(bytes, span) Value::binary(bytes, span)
} }
@ -202,6 +228,31 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value {
} }
} }
fn shift_bytes_left(data: &[u8], byte_shift: usize) -> Vec<u8> {
let len = data.len();
let mut output = vec![0; len];
output[..len - byte_shift].copy_from_slice(&data[byte_shift..]);
output
}
fn shift_bytes_and_bits_left(data: &[u8], byte_shift: usize, bit_shift: usize) -> Vec<u8> {
use itertools::Position::*;
debug_assert!((1..8).contains(&bit_shift),
"Bit shifts of 0 can't be handled by this impl and everything else should be part of the byteshift"
);
data.iter()
.copied()
.skip(byte_shift)
.circular_tuple_windows::<(u8, u8)>()
.with_position()
.map(|(pos, (lhs, rhs))| match pos {
Last | Only => lhs << bit_shift,
_ => (lhs << bit_shift) | (rhs >> (8 - bit_shift)),
})
.chain(iter::repeat(0).take(byte_shift))
.collect::<Vec<u8>>()
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;

View file

@ -1,13 +1,10 @@
use super::{get_input_num_type, get_number_bytes, InputNumType, NumberBytes}; use super::{get_input_num_type, get_number_bytes, InputNumType, NumberBytes};
use itertools::Itertools;
use nu_cmd_base::input_handler::{operate, CmdArgument}; use nu_cmd_base::input_handler::{operate, CmdArgument};
use nu_engine::command_prelude::*; use nu_engine::command_prelude::*;
use std::iter;
struct Arguments { struct Arguments {
signed: bool, signed: bool,
bits: usize, bits: Spanned<usize>,
number_size: NumberBytes, number_size: NumberBytes,
} }
@ -71,7 +68,9 @@ impl Command for BitsShr {
input: PipelineData, input: PipelineData,
) -> Result<PipelineData, ShellError> { ) -> Result<PipelineData, ShellError> {
let head = call.head; let head = call.head;
let bits: usize = call.req(engine_state, stack, 0)?; // This restricts to a positive shift value (our underlying operations do not
// permit them)
let bits: Spanned<usize> = call.req(engine_state, stack, 0)?;
let signed = call.has_flag(engine_state, stack, "signed")?; let signed = call.has_flag(engine_state, stack, "signed")?;
let number_bytes: Option<Spanned<usize>> = let number_bytes: Option<Spanned<usize>> =
call.get_flag(engine_state, stack, "number-bytes")?; call.get_flag(engine_state, stack, "number-bytes")?;
@ -121,6 +120,8 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value {
number_size, number_size,
bits, bits,
} = *args; } = *args;
let bits_span = bits.span;
let bits = bits.item;
match input { match input {
Value::Int { val, .. } => { Value::Int { val, .. } => {
@ -129,6 +130,19 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value {
let bits = bits as u32; let bits = bits as u32;
let input_num_type = get_input_num_type(val, signed, number_size); let input_num_type = get_input_num_type(val, signed, number_size);
if !input_num_type.is_permitted_bit_shift(bits) {
return Value::error(
ShellError::IncorrectValue {
msg: format!(
"Trying to shift by more than the available bits (permitted < {})",
input_num_type.num_bits()
),
val_span: bits_span,
call_span: span,
},
span,
);
}
let int = match input_num_type { let int = match input_num_type {
One => ((val as u8) >> bits) as i64, One => ((val as u8) >> bits) as i64,
Two => ((val as u16) >> bits) as i64, Two => ((val as u16) >> bits) as i64,
@ -147,21 +161,27 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value {
let bit_shift = bits % 8; let bit_shift = bits % 8;
let len = val.len(); let len = val.len();
use itertools::Position::*; // This check is done for symmetry with the int case and the previous
let bytes = iter::repeat(0) // implementation would overflow byte indices leading to unexpected output
.take(byte_shift) // lengths
.chain( if bits > len * 8 {
val.iter() return Value::error(
.copied() ShellError::IncorrectValue {
.circular_tuple_windows::<(u8, u8)>() msg: format!(
.with_position() "Trying to shift by more than the available bits ({})",
.map(|(pos, (lhs, rhs))| match pos { len * 8
First | Only => lhs >> bit_shift, ),
_ => (lhs >> bit_shift) | (rhs << bit_shift), val_span: bits_span,
}) call_span: span,
.take(len - byte_shift), },
) span,
.collect::<Vec<u8>>(); );
}
let bytes = if bit_shift == 0 {
shift_bytes_right(val, byte_shift)
} else {
shift_bytes_and_bits_right(val, byte_shift, bit_shift)
};
Value::binary(bytes, span) Value::binary(bytes, span)
} }
@ -178,6 +198,35 @@ fn action(input: &Value, args: &Arguments, span: Span) -> Value {
), ),
} }
} }
fn shift_bytes_right(data: &[u8], byte_shift: usize) -> Vec<u8> {
let len = data.len();
let mut output = vec![0; len];
output[byte_shift..].copy_from_slice(&data[..len - byte_shift]);
output
}
fn shift_bytes_and_bits_right(data: &[u8], byte_shift: usize, bit_shift: usize) -> Vec<u8> {
debug_assert!(
bit_shift > 0 && bit_shift < 8,
"bit_shift should be in the range (0, 8)"
);
let len = data.len();
let mut output = vec![0; len];
for i in byte_shift..len {
let shifted_bits = data[i - byte_shift] >> bit_shift;
let carried_bits = if i > byte_shift {
data[i - byte_shift - 1] << (8 - bit_shift)
} else {
0
};
let shifted_byte = shifted_bits | carried_bits;
output[i] = shifted_byte;
}
output
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {

View file

@ -3317,7 +3317,18 @@ impl Value {
pub fn bit_shl(&self, op: Span, rhs: &Value, span: Span) -> Result<Value, ShellError> { pub fn bit_shl(&self, op: Span, rhs: &Value, span: Span) -> Result<Value, ShellError> {
match (self, rhs) { match (self, rhs) {
(Value::Int { val: lhs, .. }, Value::Int { val: rhs, .. }) => { (Value::Int { val: lhs, .. }, Value::Int { val: rhs, .. }) => {
Ok(Value::int(*lhs << rhs, span)) // Currently we disallow negative operands like Rust's `Shl`
// Cheap guarding with TryInto<u32>
if let Some(val) = (*rhs).try_into().ok().and_then(|rhs| lhs.checked_shl(rhs)) {
Ok(Value::int(val, span))
} else {
Err(ShellError::OperatorOverflow {
msg: "right operand to bit-shl exceeds available bits in underlying data"
.into(),
span,
help: format!("Limit operand to 0 <= rhs < {}", i64::BITS),
})
}
} }
(Value::Custom { val: lhs, .. }, rhs) => { (Value::Custom { val: lhs, .. }, rhs) => {
lhs.operation(span, Operator::Bits(Bits::ShiftLeft), op, rhs) lhs.operation(span, Operator::Bits(Bits::ShiftLeft), op, rhs)
@ -3335,7 +3346,18 @@ impl Value {
pub fn bit_shr(&self, op: Span, rhs: &Value, span: Span) -> Result<Value, ShellError> { pub fn bit_shr(&self, op: Span, rhs: &Value, span: Span) -> Result<Value, ShellError> {
match (self, rhs) { match (self, rhs) {
(Value::Int { val: lhs, .. }, Value::Int { val: rhs, .. }) => { (Value::Int { val: lhs, .. }, Value::Int { val: rhs, .. }) => {
Ok(Value::int(*lhs >> rhs, span)) // Currently we disallow negative operands like Rust's `Shr`
// Cheap guarding with TryInto<u32>
if let Some(val) = (*rhs).try_into().ok().and_then(|rhs| lhs.checked_shr(rhs)) {
Ok(Value::int(val, span))
} else {
Err(ShellError::OperatorOverflow {
msg: "right operand to bit-shr exceeds available bits in underlying data"
.into(),
span,
help: format!("Limit operand to 0 <= rhs < {}", i64::BITS),
})
}
} }
(Value::Custom { val: lhs, .. }, rhs) => { (Value::Custom { val: lhs, .. }, rhs) => {
lhs.operation(span, Operator::Bits(Bits::ShiftRight), op, rhs) lhs.operation(span, Operator::Bits(Bits::ShiftRight), op, rhs)

View file

@ -1,4 +1,4 @@
use crate::repl::tests::{run_test, TestResult}; use crate::repl::tests::{fail_test, run_test, TestResult};
#[test] #[test]
fn bits_and() -> TestResult { fn bits_and() -> TestResult {
@ -56,6 +56,33 @@ fn bits_shift_left() -> TestResult {
run_test("2 | bits shl 3", "16") run_test("2 | bits shl 3", "16")
} }
#[test]
fn bits_shift_left_negative_operand() -> TestResult {
fail_test("8 | bits shl -2", "positive value")
}
#[test]
fn bits_shift_left_exceeding1() -> TestResult {
// We have no type accepting more than 64 bits so guaranteed fail
fail_test("8 | bits shl 65", "more than the available bits")
}
#[test]
fn bits_shift_left_exceeding2() -> TestResult {
// Explicitly specifying 2 bytes, but 16 is already the max
fail_test(
"8 | bits shl --number-bytes 2 16",
"more than the available bits",
)
}
#[test]
fn bits_shift_left_exceeding3() -> TestResult {
// This is purely down to the current autodetect feature limiting to the smallest integer
// type thus assuming a u8
fail_test("8 | bits shl 9", "more than the available bits")
}
#[test] #[test]
fn bits_shift_left_negative() -> TestResult { fn bits_shift_left_negative() -> TestResult {
run_test("-3 | bits shl 5", "-96") run_test("-3 | bits shl 5", "-96")
@ -69,11 +96,79 @@ fn bits_shift_left_list() -> TestResult {
) )
} }
#[test]
fn bits_shift_left_binary1() -> TestResult {
run_test(
"0x[01 30 80] | bits shl 3 | into bits",
"00001001 10000100 00000000",
)
}
#[test]
fn bits_shift_left_binary2() -> TestResult {
// Whole byte case
run_test(
"0x[01 30 80] | bits shl 8 | into bits",
"00110000 10000000 00000000",
)
}
#[test]
fn bits_shift_left_binary3() -> TestResult {
// Compared to the int case this is made inclusive of the bit count
run_test(
"0x[01 30 80] | bits shl 24 | into bits",
"00000000 00000000 00000000",
)
}
#[test]
fn bits_shift_left_binary4() -> TestResult {
// Shifting by both bytes and bits
run_test(
"0x[01 30 80] | bits shl 15 | into bits",
"01000000 00000000 00000000",
)
}
#[test]
fn bits_shift_left_binary_exceeding() -> TestResult {
// Compared to the int case this is made inclusive of the bit count
fail_test("0x[01 30] | bits shl 17 | into bits", "")
}
#[test] #[test]
fn bits_shift_right() -> TestResult { fn bits_shift_right() -> TestResult {
run_test("8 | bits shr 2", "2") run_test("8 | bits shr 2", "2")
} }
#[test]
fn bits_shift_right_negative_operand() -> TestResult {
fail_test("8 | bits shr -2", "positive value")
}
#[test]
fn bits_shift_right_exceeding1() -> TestResult {
// We have no type accepting more than 64 bits so guaranteed fail
fail_test("8 | bits shr 65", "more than the available bits")
}
#[test]
fn bits_shift_right_exceeding2() -> TestResult {
// Explicitly specifying 2 bytes, but 16 is already the max
fail_test(
"8 | bits shr --number-bytes 2 16",
"more than the available bits",
)
}
#[test]
fn bits_shift_right_exceeding3() -> TestResult {
// This is purely down to the current autodetect feature limiting to the smallest integer
// type thus assuming a u8
fail_test("8 | bits shr 9", "more than the available bits")
}
#[test] #[test]
fn bits_shift_right_negative() -> TestResult { fn bits_shift_right_negative() -> TestResult {
run_test("-32 | bits shr 2", "-8") run_test("-32 | bits shr 2", "-8")
@ -87,6 +182,47 @@ fn bits_shift_right_list() -> TestResult {
) )
} }
#[test]
fn bits_shift_right_binary1() -> TestResult {
run_test(
"0x[01 30 80] | bits shr 3 | into bits",
"00000000 00100110 00010000",
)
}
#[test]
fn bits_shift_right_binary2() -> TestResult {
// Whole byte case
run_test(
"0x[01 30 80] | bits shr 8 | into bits",
"00000000 00000001 00110000",
)
}
#[test]
fn bits_shift_right_binary3() -> TestResult {
// Compared to the int case this is made inclusive of the bit count
run_test(
"0x[01 30 80] | bits shr 24 | into bits",
"00000000 00000000 00000000",
)
}
#[test]
fn bits_shift_right_binary4() -> TestResult {
// Shifting by both bytes and bits
run_test(
"0x[01 30 80] | bits shr 15 | into bits",
"00000000 00000000 00000010",
)
}
#[test]
fn bits_shift_right_binary_exceeding() -> TestResult {
// Compared to the int case this is made inclusive of the bit count
fail_test("0x[01 30] | bits shr 17 | into bits", "available bits (16)")
}
#[test] #[test]
fn bits_rotate_left() -> TestResult { fn bits_rotate_left() -> TestResult {
run_test("2 | bits rol 3", "16") run_test("2 | bits rol 3", "16")

View file

@ -35,6 +35,31 @@ fn bit_shl() -> TestResult {
run_test("5 bit-shl 1", "10") run_test("5 bit-shl 1", "10")
} }
#[test]
fn bit_shr_overflow() -> TestResult {
fail_test("16 bit-shr 10000", "exceeds available bits")
}
#[test]
fn bit_shl_overflow() -> TestResult {
fail_test("5 bit-shl 10000000", "exceeds available bits")
}
#[test]
fn bit_shl_neg_operand() -> TestResult {
// This would overflow the `u32` in the right hand side to 2
fail_test(
"9 bit-shl -9_223_372_036_854_775_806",
"exceeds available bits",
)
}
#[test]
fn bit_shr_neg_operand() -> TestResult {
// This would overflow the `u32` in the right hand side
fail_test("9 bit-shr -2", "exceeds available bits")
}
#[test] #[test]
fn bit_shl_add() -> TestResult { fn bit_shl_add() -> TestResult {
run_test("2 bit-shl 1 + 2", "16") run_test("2 bit-shl 1 + 2", "16")