Implement rough symbol interning infra

Lukas Wirth 2024-07-12 12:26:23 +02:00
parent ffbc5ad993
commit 6275eb140e
6 changed files with 543 additions and 1 deletion

.typos.toml

@@ -14,6 +14,8 @@ extend-ignore-re = [
"\\w*\\.{3,4}\\w*",
'"flate2"',
"raison d'être",
"inout",
"optin"
]
[default.extend-words]

Cargo.lock (generated)

@@ -835,6 +835,7 @@ dependencies = [
"dashmap",
"hashbrown",
"rustc-hash",
"sptr",
"triomphe",
]
@@ -1885,6 +1886,12 @@ dependencies = [
"vfs",
]
[[package]]
name = "sptr"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b9b39299b249ad65f3b7e96443bad61c02ca5cd3589f46cb6d610a0fd6c0d6a"
[[package]]
name = "stable_deref_trait"
version = "1.2.0"

crates/intern/Cargo.toml

@@ -18,6 +18,7 @@ dashmap.workspace = true
hashbrown.workspace = true
rustc-hash.workspace = true
triomphe.workspace = true
sptr = "0.3.2"
[lints]
workspace = true

crates/intern/src/lib.rs

@@ -20,6 +20,9 @@ type Guard<T> = dashmap::RwLockWriteGuard<
HashMap<Arc<T>, SharedValue<()>, BuildHasherDefault<FxHasher>>,
>;
mod symbol;
pub use self::symbol::{symbols, Symbol};
pub struct Interned<T: Internable + ?Sized> {
arc: Arc<T>,
}
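
From a consumer's point of view, the `pub use` above exposes the new API at the crate root. A minimal usage sketch, assuming a crate that depends on `intern` (the calls mirror the API defined in symbol.rs below):

    use intern::{symbols, Symbol};

    fn main() {
        // Runtime interning: freed again when the last `Symbol` is dropped.
        let hello = Symbol::intern("hello");
        assert_eq!(hello.as_str(), "hello");
        // Compile-time declared symbol: lives in the prefilled map, never freed.
        let pre = symbols::default;
        assert_eq!(pre.as_str(), "default");
    }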

crates/intern/src/symbol.rs (new file)

@@ -0,0 +1,293 @@
//! Attempt at flexible symbol interning, allowing strings to be interned and freed at runtime
//! while also supporting compile-time declared symbols that will never be freed.
use std::{
borrow::Borrow,
fmt,
hash::{BuildHasherDefault, Hash, Hasher},
mem,
ptr::NonNull,
sync::OnceLock,
};
use dashmap::{DashMap, SharedValue};
use hashbrown::{hash_map::RawEntryMut, HashMap};
use rustc_hash::FxHasher;
use sptr::Strict;
use triomphe::Arc;
pub mod symbols;
// some asserts for layout compatibility
const _: () = assert!(std::mem::size_of::<Box<str>>() == std::mem::size_of::<&str>());
const _: () = assert!(std::mem::align_of::<Box<str>>() == std::mem::align_of::<&str>());
const _: () = assert!(std::mem::size_of::<Arc<Box<str>>>() == std::mem::size_of::<&&str>());
const _: () = assert!(std::mem::align_of::<Arc<Box<str>>>() == std::mem::align_of::<&&str>());
/// A pointer that points to a pointer to a `str`. It may be backed by a `&'static &'static str`
/// or an `Arc<Box<str>>`, but its size is that of a thin pointer. The active variant is encoded
/// as a tag in the LSB of the alignment niche. (A standalone sketch of the tagging scheme
/// follows the `impl` block below.)
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
struct TaggedArcPtr {
packed: NonNull<*const str>,
}
unsafe impl Send for TaggedArcPtr {}
unsafe impl Sync for TaggedArcPtr {}
impl TaggedArcPtr {
const BOOL_BITS: usize = true as usize;
const fn non_arc(r: &&str) -> Self {
Self {
// SAFETY: The pointer is non-null as it is derived from a reference.
// Ideally we would call out to `pack_arc` here, but for a `false` tag the
// packing requires reading the pointer out as an integer, which is not
// supported in const contexts. So we rely on the fact that the non-arc
// variant's tag is `false` (0), meaning the pointer value needs no change.
packed: unsafe {
NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut())
},
}
}
fn arc(arc: Arc<Box<str>>) -> Self {
Self {
packed: Self::pack_arc(
// Safety: `Arc::into_raw` always returns a non null pointer
unsafe { NonNull::new_unchecked(Arc::into_raw(arc).cast_mut().cast()) },
),
}
}
/// Retrieves the backing `Arc` if the tag marks this as the `Arc` variant.
#[inline]
pub(crate) fn try_as_arc_owned(self) -> Option<Arc<Box<str>>> {
// Unpack the tag from the alignment niche
let tag = Strict::addr(self.packed.as_ptr()) & Self::BOOL_BITS;
if tag != 0 {
// Safety: We checked that the tag is non-zero -> true, so we are pointing to the data offset of an `Arc`
Some(unsafe { Arc::from_raw(self.pointer().as_ptr().cast::<Box<str>>()) })
} else {
None
}
}
#[inline]
const fn pack_arc(ptr: NonNull<*const str>) -> NonNull<*const str> {
let packed_tag = true as usize;
// can't use this strict provenance stuff here due to trait methods not being const
// unsafe {
// // Safety: The pointer is derived from a non-null
// NonNull::new_unchecked(Strict::map_addr(ptr.as_ptr(), |addr| {
// // Safety:
// // - The pointer is `NonNull` => its address is `NonZero<usize>`
// // - `P::BITS` least significant bits are always zero (`Pointer` contract)
// // - `T::BITS <= P::BITS` (from `Self::ASSERTION`)
// //
// // Thus `addr >> T::BITS` is guaranteed to be non-zero.
// //
// // `{non_zero} | packed_tag` can't make the value zero.
// (addr >> Self::BOOL_BITS) | packed_tag
// }))
// }
// so what follows is roughly what the above looks like but inlined
let self_addr = unsafe { core::mem::transmute::<*const _, usize>(ptr.as_ptr()) };
let addr = self_addr | packed_tag;
let dest_addr = addr as isize;
let offset = dest_addr.wrapping_sub(self_addr as isize);
// SAFETY: The resulting pointer is guaranteed to be NonNull as we only modify the niche bit
unsafe { NonNull::new_unchecked(ptr.as_ptr().cast::<u8>().wrapping_offset(offset).cast()) }
}
#[inline]
pub(crate) fn pointer(self) -> NonNull<*const str> {
// SAFETY: The resulting pointer is guaranteed to be NonNull as we only modify the niche bit
unsafe {
NonNull::new_unchecked(Strict::map_addr(self.packed.as_ptr(), |addr| {
addr & !Self::BOOL_BITS
}))
}
}
#[inline]
pub(crate) fn as_str(&self) -> &str {
// SAFETY: We always point to a pointer to a str no matter what variant is active
unsafe { *self.pointer().as_ptr().cast::<&str>() }
}
}
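// A standalone sketch (not part of this commit) of the LSB-tagging scheme used
// above, assuming only that the pointee is at least 2-byte aligned so the low
// address bit is free. The real code goes through `sptr::Strict` to remain
// provenance-correct; plain `as` casts are used here purely for brevity.
#[cfg(test)]
fn _lsb_tagging_sketch() {
    let raw = Box::into_raw(Box::new(42u64)); // 8-byte aligned => low bits are 0
    let addr = raw as usize;
    assert_eq!(addr & 1, 0); // the alignment niche we borrow for the tag
    let tagged = addr | 1; // pack: set the LSB to flag the "arc" variant
    let untagged = tagged & !1; // unpack: mask the tag off again
    assert_eq!(untagged, addr); // the original pointer is recovered exactly
    drop(unsafe { Box::from_raw(untagged as *mut u64) }); // free the allocation
}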
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct Symbol {
repr: TaggedArcPtr,
}
const _: () = assert!(std::mem::size_of::<Symbol>() == std::mem::size_of::<NonNull<()>>());
const _: () = assert!(std::mem::align_of::<Symbol>() == std::mem::align_of::<NonNull<()>>());
static MAP: OnceLock<DashMap<SymbolProxy, (), BuildHasherDefault<FxHasher>>> = OnceLock::new();
impl Symbol {
pub fn intern(s: &str) -> Self {
let (mut shard, hash) = Self::select_shard(s);
// Atomically,
// - check if `s` is already in the map
// - if so, copy out its entry, conditionally bumping the backing `Arc`, and return it
// - if not, put it into a box and then into an `Arc`, insert it, bump the ref-count, and
//   return the copy
// This needs to be atomic (locking the shard) to avoid races with other threads, which
// could insert the same object between us looking it up and inserting it.
match shard.raw_entry_mut().from_key_hashed_nocheck(hash, s) {
RawEntryMut::Occupied(occ) => Self { repr: increase_arc_refcount(occ.key().0) },
RawEntryMut::Vacant(vac) => Self {
repr: increase_arc_refcount(
vac.insert_hashed_nocheck(
hash,
SymbolProxy(TaggedArcPtr::arc(Arc::new(Box::<str>::from(s)))),
SharedValue::new(()),
)
.0
.0,
),
},
}
}
pub fn as_str(&self) -> &str {
self.repr.as_str()
}
#[inline]
fn select_shard(
s: &str,
) -> (
dashmap::RwLockWriteGuard<
'static,
HashMap<SymbolProxy, SharedValue<()>, BuildHasherDefault<FxHasher>>,
>,
u64,
) {
let storage = MAP.get_or_init(symbols::prefill);
let hash = {
let mut hasher = std::hash::BuildHasher::build_hasher(storage.hasher());
s.hash(&mut hasher);
hasher.finish()
};
let shard_idx = storage.determine_shard(hash as usize);
let shard = &storage.shards()[shard_idx];
(shard.write(), hash)
}
#[cold]
fn drop_slow(arc: &Arc<Box<str>>) {
let (mut shard, hash) = Self::select_shard(arc);
if Arc::count(arc) != 2 {
// Another thread has interned another copy
return;
}
match shard.raw_entry_mut().from_key_hashed_nocheck::<str>(hash, arc.as_ref()) {
RawEntryMut::Occupied(occ) => occ.remove_entry(),
RawEntryMut::Vacant(_) => unreachable!(),
}
.0
.0
.try_as_arc_owned()
.unwrap();
// Shrink the backing storage if the shard is less than 50% occupied.
if shard.len() * 2 < shard.capacity() {
shard.shrink_to_fit();
}
}
}
impl Drop for Symbol {
#[inline]
fn drop(&mut self) {
let Some(arc) = self.repr.try_as_arc_owned() else {
return;
};
// When the last `Symbol` pointing at this string is dropped, remove the object from the global map.
if Arc::count(&arc) == 2 {
// Only `self` and the global map point to the object.
Self::drop_slow(&arc);
}
// decrement the ref count
drop(arc);
}
}
fn increase_arc_refcount(repr: TaggedArcPtr) -> TaggedArcPtr {
let Some(arc) = repr.try_as_arc_owned() else {
return repr;
};
// increase the ref count
mem::forget(arc.clone());
mem::forget(arc);
repr
}
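// A tiny sketch (not part of this commit) of the clone-and-forget pattern used
// by `increase_arc_refcount`: cloning bumps the strong count and `mem::forget`
// suppresses the matching decrement, leaving one extra reference that the
// corresponding `Symbol::drop` releases later.
#[cfg(test)]
fn _refcount_bump_sketch() {
    let arc = Arc::new(Box::<str>::from("example"));
    assert_eq!(Arc::count(&arc), 1);
    mem::forget(arc.clone()); // +1 strong reference, never decremented here
    assert_eq!(Arc::count(&arc), 2);
    // In this sketch the extra reference simply leaks; in the real code it is
    // owned by the returned `Symbol` and released in its `Drop` impl.
}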
impl fmt::Display for Symbol {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.as_str().fmt(f)
}
}
// only exists so we can use `from_key_hashed_nocheck` with a &str
#[derive(Debug, PartialEq, Eq)]
struct SymbolProxy(TaggedArcPtr);
impl Hash for SymbolProxy {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.as_str().hash(state);
}
}
impl Borrow<str> for SymbolProxy {
fn borrow(&self) -> &str {
self.0.as_str()
}
}
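// Why the proxy works (sketch, not part of this commit): `Borrow<str>` promises
// that hashing and equality of the owned key agree with those of the borrowed
// `&str`. It is the same contract that lets a `HashMap<String, _>` be queried
// with a plain `&str` without allocating:
#[cfg(test)]
fn _borrow_lookup_sketch() {
    let mut map = std::collections::HashMap::new();
    map.insert(String::from("hello"), ());
    assert!(map.contains_key("hello")); // `String: Borrow<str>` enables this
}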
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn smoke_test() {
Symbol::intern("isize");
let base_len = MAP.get().unwrap().len();
let hello = Symbol::intern("hello");
let world = Symbol::intern("world");
let bang = Symbol::intern("!");
let q = Symbol::intern("?");
assert_eq!(MAP.get().unwrap().len(), base_len + 4);
let bang2 = Symbol::intern("!");
assert_eq!(MAP.get().unwrap().len(), base_len + 4);
drop(bang2);
assert_eq!(MAP.get().unwrap().len(), base_len + 4);
drop(q);
assert_eq!(MAP.get().unwrap().len(), base_len + 3);
let default = Symbol::intern("default");
assert_eq!(MAP.get().unwrap().len(), base_len + 3);
assert_eq!(
"hello default world!",
format!("{} {} {}{}", hello.as_str(), default.as_str(), world.as_str(), bang.as_str())
);
drop(default);
assert_eq!(
"hello world!",
format!("{} {}{}", hello.as_str(), world.as_str(), bang.as_str())
);
drop(hello);
drop(world);
drop(bang);
assert_eq!(MAP.get().unwrap().len(), base_len);
}
}

crates/intern/src/symbol/symbols.rs (new file)

@@ -0,0 +1,236 @@
#![allow(non_upper_case_globals)]
use std::hash::{BuildHasherDefault, Hash as _, Hasher as _};
use dashmap::{DashMap, SharedValue};
use rustc_hash::FxHasher;
use crate::{
symbol::{SymbolProxy, TaggedArcPtr},
Symbol,
};
macro_rules! define_symbols {
($($name:ident),* $(,)?) => {
$(
pub const $name: Symbol = Symbol { repr: TaggedArcPtr::non_arc(&stringify!($name)) };
)*
pub(super) fn prefill() -> DashMap<SymbolProxy, (), BuildHasherDefault<FxHasher>> {
let mut dashmap_ = <DashMap<SymbolProxy, (), BuildHasherDefault<FxHasher>>>::with_hasher(BuildHasherDefault::default());
let hash_thing_ = |hasher_: &BuildHasherDefault<FxHasher>, it_: &SymbolProxy| {
let mut hasher_ = std::hash::BuildHasher::build_hasher(hasher_);
it_.hash(&mut hasher_);
hasher_.finish()
};
{
$(
let proxy_ = SymbolProxy($name.repr);
let hash_ = hash_thing_(dashmap_.hasher(), &proxy_);
let shard_idx_ = dashmap_.determine_shard(hash_ as usize);
dashmap_.shards_mut()[shard_idx_].get_mut().raw_entry_mut().from_hash(hash_, |k| k == &proxy_).insert(proxy_, SharedValue::new(()));
)*
}
dashmap_
}
};
}
define_symbols! {
add_assign,
add,
alloc,
as_str,
asm,
assert,
bench,
bitand_assign,
bitand,
bitor_assign,
bitor,
bitxor_assign,
bitxor,
bool,
Box,
boxed,
branch,
call_mut,
call_once,
call,
Center,
cfg_accessible,
cfg_attr,
cfg_eval,
cfg,
char,
Clone,
column,
compile_error,
concat_bytes,
concat_idents,
concat,
const_format_args,
Copy,
core_panic,
core,
crate_type,
Debug,
default,
Default,
deref_mut,
deref,
derive_const,
derive,
div_assign,
div,
doc,
drop,
env,
eq,
Eq,
f128,
f16,
f32,
f64,
feature,
file,
filter_map,
fmt,
fn_mut,
fn_once,
format_args_nl,
format_args,
format,
from_usize,
future_trait,
future,
Future,
ge,
global_allocator,
global_asm,
gt,
Hash,
i128,
i16,
i32,
i64,
i8,
Implied,
include_bytes,
include_str,
include,
index_mut,
index,
Index,
into_future,
IntoFuture,
IntoIter,
IntoIterator,
is_empty,
Is,
isize,
Item,
iter_mut,
iter,
Iterator,
le,
Left,
len,
line,
llvm_asm,
log_syntax,
lt,
macro_rules,
module_path,
mul_assign,
mul,
ne,
neg,
Neg,
new_binary,
new_debug,
new_display,
new_lower_exp,
new_lower_hex,
new_octal,
new_pointer,
new_upper_exp,
new_upper_hex,
new_v1_formatted,
new,
next,
no_core,
no_std,
none,
None,
not,
Not,
Ok,
ops,
option_env,
option,
Option,
Ord,
Output,
owned_box,
panic_2015,
panic_2021,
Param,
partial_ord,
PartialEq,
PartialOrd,
pieces,
poll,
prelude,
quote,
r#fn,
Range,
RangeFrom,
RangeFull,
RangeInclusive,
RangeTo,
RangeToInclusive,
recursion_limit,
register_attr,
register_tool,
rem_assign,
rem,
result,
Result,
Right,
rust_2015,
rust_2018,
rust_2021,
rust_2024,
shl_assign,
shl,
shr_assign,
shr,
std_panic,
std,
str,
string,
String,
stringify,
sub_assign,
sub,
Target,
test_case,
test,
trace_macros,
Try,
u128,
u16,
u32,
u64,
u8,
Unknown,
unreachable_2015,
unreachable_2021,
unreachable,
unsafe_cell,
usize,
v1,
va_list
}
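
Interning a string that is prefilled above resolves to the occupied, non-`Arc` entry, so the global map does not grow and dropping the returned `Symbol` frees nothing. A short sketch of that property (the same reason the smoke test interns "isize" before taking `base_len`):

    use intern::{symbols, Symbol};

    fn main() {
        let a = symbols::usize; // compile-time declared, tag bit unset
        let b = Symbol::intern("usize"); // resolves to the prefilled entry
        assert_eq!(a, b); // both carry the same tagged pointer
        drop(b); // non-Arc variant: dropping is a no-op
    }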