shred: improve performance by switching to StdRng

Also contains some cleanup
This commit is contained in:
Terts Diepraam 2023-03-10 14:17:48 +01:00
parent 73d5c4474b
commit 4e3d50064e

View file

@ -10,7 +10,7 @@
use clap::{crate_version, Arg, ArgAction, Command}; use clap::{crate_version, Arg, ArgAction, Command};
use rand::prelude::SliceRandom; use rand::prelude::SliceRandom;
use rand::Rng; use rand::{rngs::StdRng, Rng, SeedableRng};
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::fs; use std::fs;
use std::fs::{File, OpenOptions}; use std::fs::{File, OpenOptions};
@ -55,10 +55,10 @@ const PATTERNS: [&[u8]; 22] = [
b"\xEE", b"\xEE",
]; ];
#[derive(Clone, Copy)] #[derive(Clone)]
enum PassType<'a> { enum PassType<'a> {
Pattern(&'a [u8]), Pattern(&'a [u8]),
Random, Random(Box<StdRng>),
} }
// Used to generate all possible filenames of a certain length using NAME_CHARSET as an alphabet // Used to generate all possible filenames of a certain length using NAME_CHARSET as an alphabet
@ -116,60 +116,43 @@ impl Iterator for FilenameGenerator {
// Used to generate blocks of bytes of size <= BLOCK_SIZE based on either a give pattern // Used to generate blocks of bytes of size <= BLOCK_SIZE based on either a give pattern
// or randomness // or randomness
struct BytesGenerator<'a> { struct BytesGenerator {
total_bytes: u64, total_bytes: u64,
bytes_generated: Cell<u64>, bytes_generated: u64,
block_size: usize, block_size: usize,
exact: bool, // if false, every block's size is block_size exact: bool, // if false, every block's size is block_size
gen_type: PassType<'a>,
rng: Option<RefCell<rand::rngs::ThreadRng>>,
bytes: [u8; BLOCK_SIZE], bytes: [u8; BLOCK_SIZE],
} }
impl<'a> BytesGenerator<'a> { impl BytesGenerator {
fn new(total_bytes: u64, gen_type: PassType<'a>, exact: bool) -> BytesGenerator { fn new(exact: bool) -> Self {
let rng = match gen_type {
PassType::Random => Some(RefCell::new(rand::thread_rng())),
PassType::Pattern(_) => None,
};
let bytes = [0; BLOCK_SIZE]; let bytes = [0; BLOCK_SIZE];
BytesGenerator { Self {
total_bytes, total_bytes: 0,
bytes_generated: Cell::new(0u64), bytes_generated: 0,
block_size: BLOCK_SIZE, block_size: BLOCK_SIZE,
exact, exact,
gen_type,
rng,
bytes, bytes,
} }
} }
pub fn reset(&mut self, total_bytes: u64, gen_type: PassType<'a>) { pub fn reset(&mut self, total_bytes: u64) {
if let PassType::Random = gen_type {
if self.rng.is_none() {
self.rng = Some(RefCell::new(rand::thread_rng()));
}
}
self.total_bytes = total_bytes; self.total_bytes = total_bytes;
self.gen_type = gen_type; self.bytes_generated = 0;
self.bytes_generated.set(0);
} }
pub fn next(&mut self) -> Option<&[u8]> { pub fn next_pass(&mut self, pass: &mut PassType) -> Option<&[u8]> {
// We go over the total_bytes limit when !self.exact and total_bytes isn't a multiple // We go over the total_bytes limit when !self.exact and total_bytes isn't a multiple
// of self.block_size // of self.block_size
if self.bytes_generated.get() >= self.total_bytes { if self.bytes_generated >= self.total_bytes {
return None; return None;
} }
let this_block_size = if !self.exact { let this_block_size = if !self.exact {
self.block_size self.block_size
} else { } else {
let bytes_left = self.total_bytes - self.bytes_generated.get(); let bytes_left = self.total_bytes - self.bytes_generated;
if bytes_left >= self.block_size as u64 { if bytes_left >= self.block_size as u64 {
self.block_size self.block_size
} else { } else {
@ -179,16 +162,15 @@ impl<'a> BytesGenerator<'a> {
let bytes = &mut self.bytes[..this_block_size]; let bytes = &mut self.bytes[..this_block_size];
match self.gen_type { match pass {
PassType::Random => { PassType::Random(rng) => {
let mut rng = self.rng.as_ref().unwrap().borrow_mut();
rng.fill(bytes); rng.fill(bytes);
} }
PassType::Pattern(pattern) => { PassType::Pattern(pattern) => {
let skip = if self.bytes_generated.get() == 0 { let skip = if self.bytes_generated == 0 {
0 0
} else { } else {
(pattern.len() as u64 % self.bytes_generated.get()) as usize (pattern.len() as u64 % self.bytes_generated) as usize
}; };
// Copy the pattern in chunks rather than simply one byte at a time // Copy the pattern in chunks rather than simply one byte at a time
@ -205,8 +187,7 @@ impl<'a> BytesGenerator<'a> {
} }
}; };
let new_bytes_generated = self.bytes_generated.get() + this_block_size as u64; self.bytes_generated += this_block_size as u64;
self.bytes_generated.set(new_bytes_generated);
Some(bytes) Some(bytes)
} }
@ -421,13 +402,13 @@ fn get_size(size_str_opt: Option<String>) -> Option<u64> {
Some(coefficient * unit) Some(coefficient * unit)
} }
fn pass_name(pass_type: PassType) -> String { fn pass_name(pass_type: &PassType) -> String {
match pass_type { match pass_type {
PassType::Random => String::from("random"), PassType::Random(_) => String::from("random"),
PassType::Pattern(bytes) => { PassType::Pattern(bytes) => {
let mut s: String = String::new(); let mut s: String = String::new();
while s.len() < 6 { while s.len() < 6 {
for b in bytes { for b in *bytes {
let readable: String = format!("{b:x}"); let readable: String = format!("{b:x}");
s.push_str(&readable); s.push_str(&readable);
} }
@ -484,16 +465,15 @@ fn wipe_file(
} }
// Fill up our pass sequence // Fill up our pass sequence
let mut pass_sequence: Vec<PassType> = Vec::new(); let mut pass_sequence = Vec::new();
if n_passes <= 3 { if n_passes <= 3 {
// Only random passes if n_passes <= 3 // Only random passes if n_passes <= 3
for _ in 0..n_passes { for _ in 0..n_passes {
pass_sequence.push(PassType::Random); pass_sequence.push(PassType::Random(Box::new(StdRng::from_entropy())));
} }
} } else {
// First fill it with Patterns, shuffle it, then evenly distribute Random // First fill it with Patterns, shuffle it, then evenly distribute Random
else {
let n_full_arrays = n_passes / PATTERNS.len(); // How many times can we go through all the patterns? let n_full_arrays = n_passes / PATTERNS.len(); // How many times can we go through all the patterns?
let remainder = n_passes % PATTERNS.len(); // How many do we get through on our last time through? let remainder = n_passes % PATTERNS.len(); // How many do we get through on our last time through?
@ -511,7 +491,8 @@ fn wipe_file(
let n_random = 3 + n_passes / 10; // Minimum 3 random passes; ratio of 10 after let n_random = 3 + n_passes / 10; // Minimum 3 random passes; ratio of 10 after
// Evenly space random passes; ensures one at the beginning and end // Evenly space random passes; ensures one at the beginning and end
for i in 0..n_random { for i in 0..n_random {
pass_sequence[i * (n_passes - 1) / (n_random - 1)] = PassType::Random; pass_sequence[i * (n_passes - 1) / (n_random - 1)] =
PassType::Random(Box::new(StdRng::from_entropy()));
} }
} }
@ -520,46 +501,49 @@ fn wipe_file(
pass_sequence.push(PassType::Pattern(b"\x00")); pass_sequence.push(PassType::Pattern(b"\x00"));
} }
{ let total_passes: usize = pass_sequence.len();
let total_passes: usize = pass_sequence.len(); let mut file: File = OpenOptions::new()
let mut file: File = OpenOptions::new() .write(true)
.write(true) .truncate(false)
.truncate(false) .open(path)
.open(path) .map_err_context(|| format!("{}: failed to open for writing", path.maybe_quote()))?;
.map_err_context(|| format!("{}: failed to open for writing", path.maybe_quote()))?;
// NOTE: it does not really matter what we set for total_bytes and gen_type here, so just // NOTE: it does not really matter what we set for total_bytes and gen_type here, so just
// use bogus values // use bogus values
let mut generator = BytesGenerator::new(0, PassType::Pattern(&[]), exact); let mut generator = BytesGenerator::new(exact);
for (i, pass_type) in pass_sequence.iter().enumerate() { let size = match size {
if verbose { Some(size) => size,
let pass_name: String = pass_name(*pass_type); None => get_file_size(path)?,
if total_passes.to_string().len() == 1 { };
println!(
"{}: {}: pass {}/{} ({})... ", for (i, pass_type) in pass_sequence.into_iter().enumerate() {
util_name(), if verbose {
path.maybe_quote(), let pass_name: String = pass_name(&pass_type);
i + 1, if total_passes.to_string().len() == 1 {
total_passes, println!(
pass_name "{}: {}: pass {}/{} ({})... ",
); util_name(),
} else { path.maybe_quote(),
println!( i + 1,
"{}: {}: pass {:2.0}/{:2.0} ({})... ", total_passes,
util_name(), pass_name
path.maybe_quote(), );
i + 1, } else {
total_passes, println!(
pass_name "{}: {}: pass {:2.0}/{:2.0} ({})... ",
); util_name(),
} path.maybe_quote(),
i + 1,
total_passes,
pass_name
);
} }
// size is an optional argument for exactly how many bytes we want to shred
show_if_err!(do_pass(&mut file, path, &mut generator, *pass_type, size)
.map_err_context(|| format!("{}: File write pass failed", path.maybe_quote())));
// Ignore failed writes; just keep trying
} }
// size is an optional argument for exactly how many bytes we want to shred
show_if_err!(do_pass(&mut file, &mut generator, pass_type, size)
.map_err_context(|| format!("{}: File write pass failed", path.maybe_quote())));
// Ignore failed writes; just keep trying
} }
if remove { if remove {
@ -569,21 +553,17 @@ fn wipe_file(
Ok(()) Ok(())
} }
fn do_pass<'a>( fn do_pass(
file: &mut File, file: &mut File,
path: &Path, generator: &mut BytesGenerator,
generator: &mut BytesGenerator<'a>, mut pass_type: PassType,
generator_type: PassType<'a>, file_size: u64,
given_file_size: Option<u64>,
) -> Result<(), io::Error> { ) -> Result<(), io::Error> {
file.rewind()?; file.rewind()?;
// Use the given size or the whole file if not specified generator.reset(file_size);
let size: u64 = given_file_size.unwrap_or(get_file_size(path)?);
generator.reset(size, generator_type); while let Some(block) = generator.next_pass(&mut pass_type) {
while let Some(block) = generator.next() {
file.write_all(block)?; file.write_all(block)?;
} }