From dd626e78c7e0b929e6b7ebef775aa41220b18101 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 12 Jul 2024 17:11:12 +0200 Subject: [PATCH] Fix cloning Symbols not increasing their ref count --- crates/intern/src/symbol.rs | 90 ++++++++++++++++++++--------- crates/intern/src/symbol/symbols.rs | 1 + 2 files changed, 63 insertions(+), 28 deletions(-) diff --git a/crates/intern/src/symbol.rs b/crates/intern/src/symbol.rs index 1517f516fc..9f7a788d00 100644 --- a/crates/intern/src/symbol.rs +++ b/crates/intern/src/symbol.rs @@ -5,7 +5,7 @@ use std::{ borrow::Borrow, fmt, hash::{BuildHasherDefault, Hash, Hasher}, - mem, + mem::{self, ManuallyDrop}, ptr::NonNull, sync::OnceLock, }; @@ -25,6 +25,15 @@ const _: () = assert!(std::mem::align_of::>() == std::mem::align_of::<& const _: () = assert!(std::mem::size_of::>>() == std::mem::size_of::<&&str>()); const _: () = assert!(std::mem::align_of::>>() == std::mem::align_of::<&&str>()); +const _: () = + assert!(std::mem::size_of::<*const *const str>() == std::mem::size_of::()); +const _: () = + assert!(std::mem::align_of::<*const *const str>() == std::mem::align_of::()); + +const _: () = assert!(std::mem::size_of::>>() == std::mem::size_of::()); +const _: () = + assert!(std::mem::align_of::>>() == std::mem::align_of::()); + /// A pointer that points to a pointer to a `str`, it may be backed as a `&'static &'static str` or /// `Arc>` but its size is that of a thin pointer. The active variant is encoded as a tag /// in the LSB of the alignment niche. @@ -40,19 +49,24 @@ impl TaggedArcPtr { const BOOL_BITS: usize = true as usize; const fn non_arc(r: &'static &'static str) -> Self { - Self { - // SAFETY: The pointer is non-null as it is derived from a reference - // Ideally we would call out to `pack_arc` but for a `false` tag, unfortunately the - // packing stuff requires reading out the pointer to an integer which is not supported - // in const contexts, so here we make use of the fact that for the non-arc version the - // tag is false (0) and thus does not need touching the actual pointer value.ext) - packed: unsafe { - NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut()) - }, - } + assert!( + mem::align_of::<&'static &'static str>().trailing_zeros() as usize > Self::BOOL_BITS + ); + // SAFETY: The pointer is non-null as it is derived from a reference + // Ideally we would call out to `pack_arc` but for a `false` tag, unfortunately the + // packing stuff requires reading out the pointer to an integer which is not supported + // in const contexts, so here we make use of the fact that for the non-arc version the + // tag is false (0) and thus does not need touching the actual pointer value.ext) + + let packed = + unsafe { NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut()) }; + Self { packed } } fn arc(arc: Arc>) -> Self { + assert!( + mem::align_of::<&'static &'static str>().trailing_zeros() as usize > Self::BOOL_BITS + ); Self { packed: Self::pack_arc( // Safety: `Arc::into_raw` always returns a non null pointer @@ -63,12 +77,14 @@ impl TaggedArcPtr { /// Retrieves the tag. #[inline] - pub(crate) fn try_as_arc_owned(self) -> Option>> { + pub(crate) fn try_as_arc_owned(self) -> Option>>> { // Unpack the tag from the alignment niche let tag = Strict::addr(self.packed.as_ptr()) & Self::BOOL_BITS; if tag != 0 { // Safety: We checked that the tag is non-zero -> true, so we are pointing to the data offset of an `Arc` - Some(unsafe { Arc::from_raw(self.pointer().as_ptr().cast::>()) }) + Some(ManuallyDrop::new(unsafe { + Arc::from_raw(self.pointer().as_ptr().cast::>()) + })) } else { None } @@ -122,10 +138,11 @@ impl TaggedArcPtr { } } -#[derive(PartialEq, Eq, Hash, Clone, Debug)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Symbol { repr: TaggedArcPtr, } + const _: () = assert!(std::mem::size_of::() == std::mem::size_of::>()); const _: () = assert!(std::mem::align_of::() == std::mem::align_of::>()); @@ -185,19 +202,27 @@ impl Symbol { fn drop_slow(arc: &Arc>) { let (mut shard, hash) = Self::select_shard(arc); - if Arc::count(arc) != 2 { - // Another thread has interned another copy - return; + match Arc::count(arc) { + 0 => unreachable!(), + 1 => unreachable!(), + 2 => (), + _ => { + // Another thread has interned another copy + return; + } } - match shard.raw_entry_mut().from_key_hashed_nocheck::(hash, arc.as_ref()) { - RawEntryMut::Occupied(occ) => occ.remove_entry(), - RawEntryMut::Vacant(_) => unreachable!(), - } - .0 - .0 - .try_as_arc_owned() - .unwrap(); + ManuallyDrop::into_inner( + match shard.raw_entry_mut().from_key_hashed_nocheck::(hash, arc.as_ref()) { + RawEntryMut::Occupied(occ) => occ.remove_entry(), + RawEntryMut::Vacant(_) => unreachable!(), + } + .0 + .0 + .try_as_arc_owned() + .unwrap(), + ); + debug_assert_eq!(Arc::count(&arc), 1); // Shrink the backing storage if the shard is less than 50% occupied. if shard.len() * 2 < shard.capacity() { @@ -219,7 +244,13 @@ impl Drop for Symbol { Self::drop_slow(&arc); } // decrement the ref count - drop(arc); + ManuallyDrop::into_inner(arc); + } +} + +impl Clone for Symbol { + fn clone(&self) -> Self { + Self { repr: increase_arc_refcount(self.repr) } } } @@ -228,8 +259,7 @@ fn increase_arc_refcount(repr: TaggedArcPtr) -> TaggedArcPtr { return repr; }; // increase the ref count - mem::forget(arc.clone()); - mem::forget(arc); + mem::forget(Arc::clone(&arc)); repr } @@ -265,6 +295,7 @@ mod tests { let base_len = MAP.get().unwrap().len(); let hello = Symbol::intern("hello"); let world = Symbol::intern("world"); + let more_worlds = world.clone(); let bang = Symbol::intern("!"); let q = Symbol::intern("?"); assert_eq!(MAP.get().unwrap().len(), base_len + 4); @@ -275,6 +306,7 @@ mod tests { drop(q); assert_eq!(MAP.get().unwrap().len(), base_len + 3); let default = Symbol::intern("default"); + let many_worlds = world.clone(); assert_eq!(MAP.get().unwrap().len(), base_len + 3); assert_eq!( "hello default world!", @@ -285,6 +317,8 @@ mod tests { "hello world!", format!("{} {}{}", hello.as_str(), world.as_str(), bang.as_str()) ); + drop(many_worlds); + drop(more_worlds); drop(hello); drop(world); drop(bang); diff --git a/crates/intern/src/symbol/symbols.rs b/crates/intern/src/symbol/symbols.rs index 04bf4dff50..af1af0536c 100644 --- a/crates/intern/src/symbol/symbols.rs +++ b/crates/intern/src/symbol/symbols.rs @@ -10,6 +10,7 @@ use crate::{ symbol::{SymbolProxy, TaggedArcPtr}, Symbol, }; + macro_rules! define_symbols { (@WITH_NAME: $($alias:ident = $value:literal),* $(,)? @PLAIN: $($name:ident),* $(,)?) => { $(