diff --git a/Cargo.lock b/Cargo.lock index 4a8b6f7..964004e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,6 +75,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "atomic" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d818003e740b63afc82337e3160717f4f63078720a810b7b903e70a5d1d2994" +dependencies = [ + "bytemuck", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -110,9 +119,23 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "byteorder" @@ -433,6 +456,7 @@ dependencies = [ name = "imctk-ids" version = "0.1.0" dependencies = [ + "bytemuck", "hashbrown 0.14.5", "imctk-derive", "imctk-transparent", @@ -506,6 +530,18 @@ dependencies = [ "zwohash", ] +[[package]] +name = "imctk_union_find" +version = "0.1.0" +dependencies = [ + "atomic", + "imctk-ids", + "imctk-lit", + "priority-queue", + "rand", + "rand_pcg", +] + [[package]] name = "indenter" version = "0.3.3" @@ -670,6 +706,17 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "priority-queue" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714c75db297bc88a63783ffc6ab9f830698a6705aa0201416931759ef4c8183d" +dependencies = [ + "autocfg", + "equivalent", + "indexmap", +] + [[package]] name = "proc-macro-crate" version = "3.2.0" diff --git a/Cargo.toml b/Cargo.toml index 1fa671e..0c5247d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "imctk", "lit", "stable_set", + "union_find", # comment to force multi-line layout ] diff --git a/derive/src/id.rs b/derive/src/id.rs index fd76a6c..944332e 100644 --- a/derive/src/id.rs +++ b/derive/src/id.rs @@ -28,6 +28,7 @@ pub fn derive_id(input: DeriveInput, internal_generic_id: bool) -> syn::Result syn::Result { /// /// Users of this trait may depend on implementing types following these requirements for upholding /// their own safety invariants. -pub unsafe trait Id: Copy + Ord + Hash + Send + Sync + Debug { +pub unsafe trait Id: Copy + Ord + Hash + Send + Sync + Debug + NoUninit { /// An [`Id`] type that has the same representation and index range as this type. /// /// This is provided to enable writing generic code that during monomorphization is only @@ -161,6 +162,12 @@ pub unsafe trait Id: Copy + Ord + Hash + Send + Sync + Debug { #[repr(transparent)] pub struct GenericId(Repr); +// SAFETY: #[repr(transparent)] and the only field is explicitly required to be NoUninit +unsafe impl NoUninit for GenericId where + Repr: Id + NoUninit +{ +} + impl Debug for GenericId { #[inline(always)] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/ids/src/id/id_types.rs b/ids/src/id/id_types.rs index 1240ef6..8dfb1ec 100644 --- a/ids/src/id/id_types.rs +++ b/ids/src/id/id_types.rs @@ -1,12 +1,12 @@ mod id8 { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxHighNibbleU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxHighNibbleU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; /// [`Id`] type representing indices in the range `0..0xf0`. #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(transparent)] pub struct Id8(NonMaxHighNibbleU8); @@ -177,13 +177,13 @@ mod id8 { mod id16 { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; /// [`Id`] type representing indices in the range `0..0xff00`. #[cfg(target_endian = "little")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(2))] pub struct Id16 { lsb: u8, @@ -192,7 +192,7 @@ mod id16 { #[cfg(target_endian = "big")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(2))] pub struct Id16 { msb: NonMaxU8, @@ -366,13 +366,13 @@ mod id16 { mod id32 { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; /// [`Id`] type representing indices in the range `0..0xff00_0000`. #[cfg(target_endian = "little")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(4))] pub struct Id32 { lsbs: [u8; 3], @@ -381,7 +381,7 @@ mod id32 { #[cfg(target_endian = "big")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(4))] pub struct Id32 { msb: NonMaxU8, @@ -562,13 +562,13 @@ mod id32 { mod id64 { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; /// [`Id`] type representing indices in the range `0..0xff00_0000_0000_0000`. #[cfg(target_endian = "little")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(8))] pub struct Id64 { lsbs: [u8; 7], @@ -577,7 +577,7 @@ mod id64 { #[cfg(target_endian = "big")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(8))] pub struct Id64 { msb: NonMaxU8, @@ -758,7 +758,7 @@ mod id64 { mod id_size { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxMsbU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxMsbU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; const LSBS: usize = (usize::BITS as usize / 8) - 1; @@ -766,7 +766,7 @@ mod id_size { /// [`Id`] type representing indices in the range `0..=isize::MAX as usize`. #[cfg(target_endian = "little")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[cfg_attr(target_pointer_width = "16", repr(C, align(2)))] #[cfg_attr(target_pointer_width = "32", repr(C, align(4)))] #[cfg_attr(target_pointer_width = "64", repr(C, align(8)))] @@ -777,7 +777,7 @@ mod id_size { #[cfg(target_endian = "big")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[cfg_attr(target_pointer_width = "16", repr(C, align(2)))] #[cfg_attr(target_pointer_width = "32", repr(C, align(4)))] #[cfg_attr(target_pointer_width = "64", repr(C, align(8)))] diff --git a/ids/src/id/u8_range_types.rs b/ids/src/id/u8_range_types.rs index e5ba744..69a607f 100644 --- a/ids/src/id/u8_range_types.rs +++ b/ids/src/id/u8_range_types.rs @@ -1,5 +1,7 @@ +use crate::NoUninit; + #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts -#[derive(Clone, Copy)] +#[derive(Clone, Copy, NoUninit)] #[repr(u8)] pub enum NonMaxU8 { Val00 = 0x00, @@ -260,7 +262,7 @@ pub enum NonMaxU8 { } #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts -#[derive(Clone, Copy)] +#[derive(Clone, Copy, NoUninit)] #[repr(u8)] pub enum NonMaxHighNibbleU8 { Val00 = 0x00, @@ -506,7 +508,7 @@ pub enum NonMaxHighNibbleU8 { } #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts -#[derive(Clone, Copy)] +#[derive(Clone, Copy, NoUninit)] #[repr(u8)] pub enum NonMaxMsbU8 { Val00 = 0x00, diff --git a/ids/src/id_alloc.rs b/ids/src/id_alloc.rs new file mode 100644 index 0000000..630ed8f --- /dev/null +++ b/ids/src/id_alloc.rs @@ -0,0 +1,54 @@ +//! An allocator for IDs that allows concurrent access from multiple threads. +use std::sync::atomic::Ordering::Relaxed; +use std::{marker::PhantomData, sync::atomic::AtomicUsize}; + +use crate::{Id, IdRange}; + +/// An allocator for IDs that allows concurrent access from multiple threads. +pub struct IdAlloc { + counter: AtomicUsize, + _phantom: PhantomData, +} + +impl Default for IdAlloc { + fn default() -> Self { + Self::new() + } +} + +/// `IdAllocError` indicates that there are not enough IDs remaining. +#[derive(Clone, Copy, Debug)] +pub struct IdAllocError; + +impl IdAlloc { + /// Constructs a new ID allocator. + pub const fn new() -> Self { + Self { + counter: AtomicUsize::new(0), + _phantom: PhantomData, + } + } + fn alloc_indices(&self, n: usize) -> Result { + self.counter + .fetch_update(Relaxed, Relaxed, |current_id| { + current_id + .checked_add(n) + .filter(|&index| index <= T::MAX_ID_INDEX.saturating_add(1)) + }) + .map_err(|_| IdAllocError) + } + /// Allocates a single ID. + pub fn alloc(&self) -> Result { + self.alloc_indices(1).map(|index| { + // SAFETY: the precondition was checked by `alloc_indices` + unsafe { T::from_id_index_unchecked(index) } + }) + } + /// Allocates a contiguous range of the specified size. + pub fn alloc_range(&self, n: usize) -> Result, IdAllocError> { + self.alloc_indices(n).map(|start| { + // SAFETY: the precondition was checked by `alloc_indices` + unsafe { IdRange::from_index_range_unchecked(start..start + n) } + }) + } +} diff --git a/ids/src/id_vec.rs b/ids/src/id_vec.rs index ff04a19..9c40049 100644 --- a/ids/src/id_vec.rs +++ b/ids/src/id_vec.rs @@ -312,7 +312,7 @@ impl Clone for IdVec { } } -impl Default for IdVec { +impl Default for IdVec { #[inline(always)] fn default() -> Self { Self { diff --git a/ids/src/lib.rs b/ids/src/lib.rs index 931307a..da26bd5 100644 --- a/ids/src/lib.rs +++ b/ids/src/lib.rs @@ -16,6 +16,7 @@ mod id; mod id_range; pub mod id_vec; pub mod indexed_id_vec; +pub mod id_alloc; pub mod id_set_seq; @@ -31,3 +32,9 @@ pub use imctk_derive::Id; pub use id::{ConstIdFromIdIndex, GenericId, Id, Id16, Id32, Id64, Id8, IdSize}; pub use id_range::IdRange; + +pub use id_alloc::IdAlloc; + +// re-export this so that others can use it without depending on bytemuck explicitly +// in particular needed for #[derive(Id)] +pub use bytemuck::NoUninit; \ No newline at end of file diff --git a/ids/tests/test_id.rs b/ids/tests/test_id.rs index 6df140f..0292768 100644 --- a/ids/tests/test_id.rs +++ b/ids/tests/test_id.rs @@ -383,7 +383,7 @@ fn conversion_usize() { #[derive(Id)] #[repr(transparent)] -pub struct NewtypePhantom(usize, PhantomData); +pub struct NewtypePhantom(usize, PhantomData); impl std::fmt::Debug for NewtypePhantom { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/lit/Cargo.toml b/lit/Cargo.toml index 925485a..f9b5cf0 100644 --- a/lit/Cargo.toml +++ b/lit/Cargo.toml @@ -18,4 +18,4 @@ imctk-transparent = { version = "0.1.0", path = "../transparent" } imctk-util = { version = "0.1.0", path = "../util" } log = "0.4.21" table_seq = { version = "0.1.0", path = "../table_seq" } -zwohash = "0.1.2" +zwohash = "0.1.2" \ No newline at end of file diff --git a/union_find/Cargo.toml b/union_find/Cargo.toml new file mode 100644 index 0000000..dfc7004 --- /dev/null +++ b/union_find/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "imctk_union_find" +version = "0.1.0" +edition = "2021" + +publish = false + +[dependencies] +imctk-ids = { version = "0.1.0", path = "../ids" } +imctk-lit = { version = "0.1.0", path = "../lit" } +atomic = "0.6" +priority-queue = "*" + +[dev-dependencies] +rand = "*" +rand_pcg = "*" + +[lints] +workspace = true diff --git a/union_find/src/element.rs b/union_find/src/element.rs new file mode 100644 index 0000000..d224e15 --- /dev/null +++ b/union_find/src/element.rs @@ -0,0 +1,57 @@ +//! A trait for "elements" that can be split into an "atom" and a "polarity". + +use imctk_lit::{Lit, Var}; + +/// A trait for "elements" that can be split into an "atom" and a "polarity". +/// +/// This lets code generically manipulate variables and literals, and further serves to abstract over their concrete representation. +/// +/// The two most common case this trait is used for are: +/// 1) The element and the atom are both `Var`. In this case there is only one polarity and the trait implementation is trivial. +/// 2) The element is `Lit` and the atom is `Var`. Here there are two polarities (`+` and `-`) to keep track of. +/// +/// Mathematically, implementing this trait signifies that elements can be written as pairs `(a, p)` with an atom `a` and a polarity `p`. +/// The polarities are assumed to form a group `(P, *, 1)`. The trait operations then correspond to: +/// 1) `from_atom(a) = (a, 1)` +/// 2) `atom((a, p)) = a` +/// 3) `apply_pol_of((a, p), (b, q)) = (a, p * q)` +/// +/// Currently, code assumes that `P` is either trivial or isomorphic to `Z_2`. +/// +/// Code using this trait may assume the following axioms to hold: +/// 1) `from_atom(atom(x)) == x` +/// 2) `apply_pol_of(atom(x), x) == x` +/// 3) `apply_pol_of(apply_pol_of(x, y), y) == x` +// TODO: add missing axioms +pub trait Element { + /// Constructs an element from an atom by applying positive polarity. + fn from_atom(atom: Atom) -> Self; + /// Returns the atom corresponding to an element. + fn atom(self) -> Atom; + /// Multiplies `self` by the polarity of `other`, i.e. conceptually `apply_pol_of(self, other) = self ^ pol(other)`. + fn apply_pol_of(self, other: Self) -> Self; +} + +impl Element for T { + fn from_atom(atom: T) -> Self { + atom + } + fn atom(self) -> T { + self + } + fn apply_pol_of(self, _other: T) -> Self { + self + } +} + +impl Element for Lit { + fn from_atom(atom: Var) -> Self { + atom.as_lit() + } + fn atom(self) -> Var { + self.var() + } + fn apply_pol_of(self, other: Self) -> Self { + self ^ other.pol() + } +} diff --git a/union_find/src/lib.rs b/union_find/src/lib.rs new file mode 100644 index 0000000..f4f90ca --- /dev/null +++ b/union_find/src/lib.rs @@ -0,0 +1,13 @@ +//! This crate defines a structure [`UnionFind`] that allows tracking of equivalences between generic elements +//! and a structure [`TrackedUnionFind`] that provides the same functionality but augmented by change tracking. + +#[doc(inline)] +pub use element::Element; +#[doc(inline)] +pub use tracked_union_find::TrackedUnionFind; +#[doc(inline)] +pub use union_find::UnionFind; + +pub mod element; +pub mod tracked_union_find; +pub mod union_find; diff --git a/union_find/src/tests/test_tracked_union_find.rs b/union_find/src/tests/test_tracked_union_find.rs new file mode 100644 index 0000000..33ba01d --- /dev/null +++ b/union_find/src/tests/test_tracked_union_find.rs @@ -0,0 +1,287 @@ +#![allow(missing_docs, dead_code)] +use super::*; +use imctk_lit::{Lit, Var}; +use rand::prelude::*; +use std::collections::HashMap; + +fn change_eq(c1: &Change, c2: &Change) -> bool { + match (c1, c2) { + ( + Change::Union { + new_repr: a1, + merged_repr: b1, + }, + Change::Union { + new_repr: a2, + merged_repr: b2, + }, + ) => a1 == a2 && b1 == b2, + ( + Change::MakeRepr { + new_repr: a1, + old_repr: b1, + }, + Change::MakeRepr { + new_repr: a2, + old_repr: b2, + }, + ) => a1 == a2 && b1 == b2, + (Change::Renumber(r1), Change::Renumber(r2)) => Arc::as_ptr(r1) == Arc::as_ptr(r2), + _ => false, + } +} + +struct CTUnionFind { + dut: TrackedUnionFind, + uf: UnionFind, + logs: HashMap>>, +} + +impl> CTUnionFind { + fn new() -> Self { + CTUnionFind { + dut: TrackedUnionFind::new(), + uf: UnionFind::new(), + logs: HashMap::new(), + } + } + fn find(&mut self, e: Elem) -> Elem { + let dut_result = self.dut.find(e); + let uf_result = self.uf.find(e); + assert_eq!(dut_result, uf_result); + uf_result + } + fn union_full(&mut self, elems: [Elem; 2]) -> (bool, [Elem; 2]) { + let (uf_ok, [uf_ra, uf_rb]) = self.uf.union_full(elems); + if uf_ok { + let new_repr = uf_ra.atom(); + let merged_repr = uf_rb.apply_pol_of(uf_ra); + for vec in self.logs.values_mut() { + vec.push(Change::Union { + new_repr, + merged_repr, + }); + } + } + let (dut_ok, [dut_ra, dut_rb]) = self.dut.union_full(elems); + assert_eq!((dut_ok, [dut_ra, dut_rb]), (uf_ok, [uf_ra, uf_rb])); + (uf_ok, [uf_ra, uf_rb]) + } + fn make_repr(&mut self, new_repr: Atom) -> Elem { + let uf_old_repr = self.uf.make_repr(new_repr); + if uf_old_repr.atom() != new_repr { + for vec in self.logs.values_mut() { + vec.push(Change::MakeRepr { + new_repr, + old_repr: uf_old_repr, + }); + } + } else { + assert!(uf_old_repr == Elem::from_atom(new_repr)); + } + let dut_old_repr = self.dut.make_repr(new_repr); + assert_eq!(dut_old_repr, uf_old_repr); + uf_old_repr + } + fn start_observing(&mut self) -> ObserverToken { + let token = self.dut.start_observing(); + assert!(self.logs.insert(token.observer_id, Vec::new()).is_none()); + token + } + fn clone_token(&mut self, token: &ObserverToken) -> ObserverToken { + let new_token = self.dut.clone_token(token); + let cloned_log = self.logs.get(&token.observer_id).unwrap().clone(); + assert!(self + .logs + .insert(new_token.observer_id, cloned_log) + .is_none()); + new_token + } + fn stop_observing(&mut self, token: ObserverToken) { + assert!(self.logs.remove(&token.observer_id).is_some()); + self.dut.stop_observing(token); + } + fn drain_changes_with_fn( + &mut self, + token: &mut ObserverToken, + mut f: impl FnMut(&[Change]), + ) { + let log = self.logs.get_mut(&token.observer_id).unwrap(); + let mut log_iter = log.drain(..); + self.dut.drain_changes_with_fn(token, |changes, _| { + f(changes); + for c1 in changes.iter() { + let Some(c2) = log_iter.next() else { + panic!("not enough changes"); + }; + assert!(change_eq(c1, &c2)); + } + }); + assert!(log_iter.next().is_none()); + } + fn drain_some_changes( + &mut self, + token: &mut ObserverToken, + calculate_count: impl FnOnce(usize) -> usize, + ) -> usize { + let log = self.logs.get_mut(&token.observer_id).unwrap(); + let count = calculate_count(log.len()); + let log_iter = log.drain(..count); + let mut gen = token.generation; + let mut dut_iter = self.dut.drain_changes(token); + for c2 in log_iter { + let Some(c1) = dut_iter.next() else { + panic!("not enough changes"); + }; + assert!(change_eq(c1, &c2)); + if let Change::Renumber(renumbering) = c1 { + assert_eq!(renumbering.old_generation, gen); + gen = renumbering.new_generation; + } + } + dut_iter.stop(); + assert_eq!(token.generation, gen); + count + } + fn repr_reduction(&self) -> (IdVec>, IdVec) { + let mut forward: IdVec> = IdVec::default(); + let mut reverse: IdVec = IdVec::default(); + for (atom, repr) in self.uf.iter() { + let new_repr = *forward.grow_for_key(repr.atom()).get_or_insert_with(|| { + let new_repr = reverse.push(Elem::from_atom(repr.atom())).0; + Elem::from_atom(new_repr) + }); + forward + .grow_for_key(atom) + .replace(new_repr.apply_pol_of(repr)); + } + (forward, reverse) + } + fn renumber(&mut self) -> Arc> { + let (forward, reverse) = self.repr_reduction(); + let renumbering = self.dut.renumber(forward, reverse); + for vec in self.logs.values_mut() { + vec.push(Change::Renumber(renumbering.clone())); + } + self.uf = UnionFind::new(); + renumbering + } +} + +macro_rules! weighted_choose { + ($rng:expr, $($name:ident: $weight:expr => $body:expr),+) => { + { + enum Branches { $( $name, )* } + let weights = [$((Branches::$name, $weight)),+]; + match weights.choose_weighted($rng, |x| x.1).unwrap().0 { + $(Branches::$name => $body),* + } + } + } +} + +#[test] +fn test_suite() { + let mut u: CTUnionFind = CTUnionFind::new(); + let mut rng = rand_pcg::Pcg64::seed_from_u64(25); + let mut active_tokens: Vec = Vec::new(); + let max_var = 2000; + let verbosity = 2; + for _ in 0..2000 { + weighted_choose! {&mut rng, + Union: 8.0 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let b = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.union_full([a, b]); + if verbosity > 1 { + println!("union({a}, {b}) = {result:?}"); + } + }, + Find: 1.0 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.find(a); + if verbosity > 1 { + println!("find({a}) = {result}"); + } + }, + MakeRepr: 2.0 => { + let a = Var::from_index(rng.gen_range(0..=max_var)); + u.make_repr(a); + if verbosity > 1 { + println!("make_repr({a})"); + } + }, + StartObserving: 1.0 => { + let token = u.start_observing(); + if verbosity > 0 { + println!("start observing {token:?}"); + } + active_tokens.push(token); + }, + StopObserving: 1.0 => { + if let Some(index) = (0..active_tokens.len()).choose(&mut rng) { + let token = active_tokens.swap_remove(index); + if verbosity > 0 { + println!("stop observing {token:?}"); + } + u.stop_observing(token); + } + }, + CloneToken: 1.0 => { + if let Some(token) = active_tokens.iter().choose(&mut rng) { + let new_token = u.clone_token(token); + if verbosity > 0 { + println!("clone {token:?} -> {new_token:?}"); + } + active_tokens.push(new_token); + } + }, + DrainAllChanges: 2.0 => { + if let Some(token) = active_tokens.iter_mut().choose(&mut rng) { + let mut count = 0; + let mut gen = token.generation; + u.drain_changes_with_fn(token, |changes| { + for change in changes { + if let Change::Renumber(renumbering) = change { + assert_eq!(renumbering.old_generation, gen); + gen = renumbering.new_generation; + } + count += 1; + } + }); + assert_eq!(gen, token.generation); + if verbosity > 0 { + println!("drained all ({count}) changes from {token:?}"); + } + } + }, + DrainSomeChanges: 2.0 => { + if let Some(token) = active_tokens.iter_mut().choose(&mut rng) { + let count = u.drain_some_changes(token, |total| (0..=total).choose(&mut rng).unwrap()); + if verbosity > 0 { + println!("drained {count} changes from {token:?}"); + } + } + }, + Renumber: 0.25 => { + let renumbering = u.renumber(); + if verbosity > 0 { + println!("renumbered all variables, now on generation {} ({} old variables, {} new variables)", + u.dut.generation.0, + renumbering.forward.len(), + renumbering.reverse.len() + ); + } + } + } + } +} + +#[test] +#[should_panic] +fn test_token_error() { + let mut tuf1: TrackedUnionFind = Default::default(); + let mut tuf2: TrackedUnionFind = Default::default(); + let mut token = tuf1.start_observing(); + tuf2.drain_changes(&mut token); +} diff --git a/union_find/src/tests/test_union_find.rs b/union_find/src/tests/test_union_find.rs new file mode 100644 index 0000000..5f8a53e --- /dev/null +++ b/union_find/src/tests/test_union_find.rs @@ -0,0 +1,176 @@ +#![allow(dead_code, missing_docs)] + +use super::*; +use imctk_ids::id_set_seq::IdSetSeq; +use imctk_lit::{Lit, Var}; +use rand::prelude::*; +use std::collections::{HashSet, VecDeque}; + +#[derive(Default)] +struct CheckedUnionFind { + dut: UnionFind, + equivs: IdSetSeq, +} + +impl> UnionFind { + fn debug_print_tree( + children: &IdVec>, + atom: Atom, + prefix: &str, + self_char: &str, + further_char: &str, + pol: bool, + ) { + println!( + "{prefix}{self_char}{}{:?}", + if pol { "!" } else { "" }, + atom + ); + let my_children = children.get(atom).unwrap(); + for (index, &child) in my_children.iter().enumerate() { + let last = index == my_children.len() - 1; + let self_char = if last { "└" } else { "├" }; + let next_further_char = if last { " " } else { "│" }; + Self::debug_print_tree( + children, + child.atom(), + &(prefix.to_string() + further_char), + self_char, + next_further_char, + pol ^ (child != Elem::from_atom(child.atom())), + ); + } + } + fn debug_print(&self) { + let mut children: IdVec> = Default::default(); + for atom in self.parent.keys() { + let parent = self.read_parent(atom); + children.grow_for_key(atom); + if atom != parent.atom() { + children + .grow_for_key(parent.atom()) + .push(Elem::from_atom(atom).apply_pol_of(parent)); + } else { + assert!(Elem::from_atom(atom) == parent); + } + } + for atom in self.parent.keys() { + if atom == self.read_parent(atom).atom() { + Self::debug_print_tree(&children, atom, "", "", " ", false); + } + } + } +} +#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] +enum VarRel { + Equiv, + AntiEquiv, + NotEquiv, +} + +impl> CheckedUnionFind { + fn new() -> Self { + CheckedUnionFind { + dut: Default::default(), + equivs: Default::default(), + } + } + fn ref_equal(&mut self, start: Elem, goal: Elem) -> VarRel { + let mut seen: HashSet = Default::default(); + let mut queue: VecDeque = [start].into(); + while let Some(place) = queue.pop_front() { + if place.atom() == goal.atom() { + if place == goal { + return VarRel::Equiv; + } else { + return VarRel::AntiEquiv; + } + } + seen.insert(place.atom()); + for &next in self.equivs.grow_for(place.atom()).iter() { + if !seen.contains(&next.atom()) { + queue.push_back(next.apply_pol_of(place)); + } + } + } + VarRel::NotEquiv + } + fn find(&mut self, lit: Elem) -> Elem { + let out = self.dut.find(lit); + assert!(self.ref_equal(lit, out) == VarRel::Equiv); + out + } + fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { + let (ok, [ra, rb]) = self.dut.union_full(lits); + assert_eq!(self.ref_equal(lits[0], ra), VarRel::Equiv); + assert_eq!(self.ref_equal(lits[1], rb), VarRel::Equiv); + assert_eq!(ok, self.ref_equal(lits[0], lits[1]) == VarRel::NotEquiv); + assert_eq!(self.dut.find_root(lits[0]), ra); + if ok { + assert_eq!(self.dut.find_root(lits[1]), ra); + self.equivs + .grow_for(lits[0].atom()) + .insert(lits[1].apply_pol_of(lits[0])); + self.equivs + .grow_for(lits[1].atom()) + .insert(lits[0].apply_pol_of(lits[1])); + } else { + assert_eq!(self.dut.find_root(lits[1]).atom(), ra.atom()); + } + (ok, [ra, rb]) + } + fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + fn make_repr(&mut self, lit: Atom) { + self.dut.make_repr(lit); + assert_eq!( + self.dut.find_root(Elem::from_atom(lit)), + Elem::from_atom(lit) + ); + self.check(); + } + fn check(&mut self) { + for atom in self.dut.parent.keys() { + let parent = self.dut.read_parent(atom); + assert_eq!(self.ref_equal(Elem::from_atom(atom), parent), VarRel::Equiv); + let root = self.dut.find_root(Elem::from_atom(atom)); + for &child in self.equivs.grow_for(atom).iter() { + assert_eq!(root, self.dut.find_root(child)); + } + } + } +} + +#[test] +fn test_suite() { + let mut u: CheckedUnionFind = CheckedUnionFind::new(); + let mut rng = rand_pcg::Pcg64::seed_from_u64(25); + let max_var = 2000; + for _ in 0..2000 { + match rng.gen_range(0..10) { + 0..=4 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let b = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.union_full([a, b]); + println!("union({a}, {b}) = {result:?}"); + } + 5..=7 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.find(a); + println!("find({a}) = {result}"); + } + 8 => { + u.check(); + } + 9 => { + let a = Var::from_index(rng.gen_range(0..=max_var)); + u.make_repr(a); + println!("make_repr({a})"); + } + _ => {} + } + } + u.check(); + //u.dut.debug_print(); +} diff --git a/union_find/src/tracked_union_find.rs b/union_find/src/tracked_union_find.rs new file mode 100644 index 0000000..343f74a --- /dev/null +++ b/union_find/src/tracked_union_find.rs @@ -0,0 +1,581 @@ +//! A `TrackedUnionFind` augments a [`UnionFind`] structure with change tracking. +use std::{cmp::Reverse, collections::VecDeque, iter::FusedIterator, mem::ManuallyDrop, sync::Arc}; + +use imctk_ids::{id_vec::IdVec, Id, Id64, IdAlloc}; +use priority_queue::PriorityQueue; + +use crate::{Element, UnionFind}; + +#[cfg(test)] +#[path = "tests/test_tracked_union_find.rs"] +mod test_tracked_union_find; + +/// Globally unique ID for a `TrackedUnionFind`. +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct TrackedUnionFindId(Id64); +/// Generation number for a `TrackedUnionFind` +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct Generation(u64); +/// Observer ID for a `TrackedUnionFind` (not globally unique) +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct ObserverId(Id64); + +/// An `ObserverToken` represents an observer of a `TrackedUnionFind`. +pub struct ObserverToken { + tuf_id: TrackedUnionFindId, + generation: Generation, + observer_id: ObserverId, +} + +impl ObserverToken { + /// Returns the ID of the associated `TrackedUnionFind`. + pub fn tuf_id(&self) -> TrackedUnionFindId { + self.tuf_id + } + /// Returns the ID of the generation that this observer is on. + pub fn generation(&self) -> Generation { + self.generation + } + /// Returns the ID of the observer. + /// + /// Note that observer IDs are local to each `TrackedUnionFind`. + pub fn observer_id(&self) -> ObserverId { + self.observer_id + } + /// Returns `true` iff `self` and `other` belong to the same `TrackedUnionFind`. + /// + /// NB: This does **not** imply that variable IDs are compatible, for that you want `is_compatible`. + pub fn is_same_tuf(&self, other: &ObserverToken) -> bool { + self.tuf_id == other.tuf_id + } + /// Returns `true` iff `self` and `other` have compatible variable IDs. + /// + /// This is equivalent to whether they are on the same generation of the same `TrackedUnionFind`. + pub fn is_compatible(&self, other: &ObserverToken) -> bool { + self.tuf_id == other.tuf_id && self.generation == other.generation + } +} + +impl std::fmt::Debug for ObserverToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ObserverToken") + .field("tuf_id", &self.tuf_id.0) + .field("generation", &self.generation.0) + .field("observer_id", &self.observer_id.0) + .finish() + } +} + +/// `Renumbering` represents a renumbering of all variables in `TrackedUnionFind`. +/// +/// A renumbering stores a forward and a reverse mapping, and the old and new generation IDs. +/// +/// The forward mapping maps each old variable to an optional new variable (variables may be deleted). +/// It is a requirement that equivalent old variables are either both mapped to the same new variable or both deleted. +/// +/// The reverse mapping maps each new variable to its old representative. +/// The new set of variables is required to be contiguous, hence `reverse` is a total mapping. +pub struct Renumbering { + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, +} + +impl std::fmt::Debug for Renumbering { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Renumbering") + .field("forward", &self.forward) + .field("reverse", &self.reverse) + .field("old_generation", &self.old_generation) + .field("new_generation", &self.new_generation) + .finish() + } +} + +impl> Renumbering { + /// Returns the inverse of a reassignment of variables. + pub fn get_reverse( + forward: &IdVec>, + union_find: &UnionFind, + ) -> IdVec { + let mut reverse: IdVec> = IdVec::default(); + for (old, &new_opt) in forward { + if let Some(new) = new_opt { + reverse + .grow_for_key(new.atom()) + .replace(union_find.find(Elem::from_atom(old)).apply_pol_of(new)); + } + } + IdVec::from_vec(reverse.iter().map(|x| x.1.unwrap()).collect()) + } + /// Returns `true` iff the arguments are inverses of each other. + pub fn is_inverse(forward: &IdVec>, reverse: &IdVec) -> bool { + reverse.iter().all(|(new, &old)| { + if let Some(&Some(e)) = forward.get(old.atom()) { + Elem::from_atom(new) == e.apply_pol_of(old) + } else { + false + } + }) + } + /// Creates a renumbering without checking whether the arguments are valid. + pub fn new_unchecked( + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, + ) -> Self { + Renumbering { + forward, + reverse, + old_generation, + new_generation, + } + } + /// Creates a new renumbering from the given forward and reverse assignment. + pub fn new( + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, + ) -> Self { + debug_assert!(new_generation > old_generation); + debug_assert!(Self::is_inverse(&forward, &reverse)); + Self::new_unchecked(forward, reverse, old_generation, new_generation) + } + /// Returns the new variable corresponding to the given old variable, if it exists. + pub fn old_to_new(&self, old: Elem) -> Option { + self.forward + .get(old.atom()) + .copied() + .flatten() + .map(|e| e.apply_pol_of(old)) + } + /// Returns the old variable corresponding to the given new variable, if it exists. + pub fn new_to_old(&self, new: Elem) -> Option { + self.reverse.get(new.atom()).map(|&e| e.apply_pol_of(new)) + } + /// Returns `true` iff the given renumbering satisfies the constraint that equivalent variables are mapped identically. + pub fn is_repr_reduction(&self, union_find: &UnionFind) -> bool { + union_find.lowest_unused_atom() <= self.forward.next_unused_key() + && self.forward.iter().all(|(old, &new)| { + let repr = union_find.find(Elem::from_atom(old)); + let repr_new = self.old_to_new(repr); + repr_new == new + }) + } +} + +/// `Change` represents a single change of a `TrackedUnionFind` +#[allow(missing_docs)] // dont want to document every subfield +#[derive(Clone)] +pub enum Change { + /// A `union` operation. The set with representative `merged_repr` is merged into the set with representative `new_repr`. + Union { new_repr: Atom, merged_repr: Elem }, + /// A `make_repr` operation. `new_repr` is promoted to be the representative of its set, replacing `old_repr`. + MakeRepr { new_repr: Atom, old_repr: Elem }, + /// A renumbering operation. + Renumber(Arc>), +} + +impl std::fmt::Debug for Change { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Union { + new_repr, + merged_repr, + } => f + .debug_struct("Union") + .field("new_repr", new_repr) + .field("merged_repr", merged_repr) + .finish(), + Self::MakeRepr { new_repr, old_repr } => f + .debug_struct("MakeRepr") + .field("new_repr", new_repr) + .field("old_repr", old_repr) + .finish(), + Self::Renumber(arg0) => f.debug_tuple("Renumber").field(arg0).finish(), + } + } +} + +/// A `TrackedUnionFind` augments a [`UnionFind`] structure with change tracking. +pub struct TrackedUnionFind { + /// Globally unique ID. + tuf_id: TrackedUnionFindId, + union_find: UnionFind, + /// Log of changes with new changes appended to the end. + /// Indices into this log are relative positions. + /// The start of this log has relative position 0 and absolute position `log_start`. + log: VecDeque>, + observer_id_alloc: IdAlloc, + /// This stores absolute positions of each observers in the log, + /// and also allows retrieving the minimum absolute position of any observer. + /// + /// `truncate_log` will reset that minimum position to be at relative position 0. + observers: PriorityQueue>, + /// Offset between absolute and relative positions in the log. + /// + /// absolute position = relative position + `log_start` + log_start: u64, + /// Incremented on every renumbering. + generation: Generation, +} + +static TUF_ID_ALLOC: IdAlloc = IdAlloc::new(); + +impl From> for TrackedUnionFind { + fn from(union_find: UnionFind) -> Self { + Self { + union_find, + tuf_id: TUF_ID_ALLOC.alloc().unwrap(), + log: Default::default(), + observer_id_alloc: Default::default(), + observers: Default::default(), + log_start: 0, + generation: Generation(0), + } + } +} + +impl Default for TrackedUnionFind { + fn default() -> Self { + UnionFind::default().into() + } +} + +impl TrackedUnionFind { + /// Returns a shared reference to the contained `UnionFind`. + pub fn get_union_find(&self) -> &UnionFind { + &self.union_find + } + /// Returns the contained `UnionFind`. All change tracking data is lost. + pub fn into_union_find(self) -> UnionFind { + self.union_find + } +} + +impl> TrackedUnionFind { + /// Returns an element's representative. See [`UnionFind::find`]. + pub fn find(&self, elem: Elem) -> Elem { + self.union_find.find(elem) + } + /// Declares two elements to be equivalent. See [`UnionFind::union_full`]. + pub fn union_full(&mut self, elems: [Elem; 2]) -> (bool, [Elem; 2]) { + let (ok, roots) = self.union_find.union_full(elems); + if ok && !self.observers.is_empty() { + let new_repr = roots[0].atom(); + let merged_repr = roots[1].apply_pol_of(roots[0]); + self.log.push_back(Change::Union { + new_repr, + merged_repr, + }) + } + (ok, roots) + } + /// Declares two elements to be equivalent. See [`UnionFind::union`]. + pub fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + /// Promotes an atom to be a representative. See [`UnionFind::make_repr`]. + pub fn make_repr(&mut self, new_repr: Atom) -> Elem { + let old_repr = self.union_find.make_repr(new_repr); + if old_repr.atom() != new_repr && !self.observers.is_empty() { + self.log.push_back(Change::MakeRepr { new_repr, old_repr }); + } + old_repr + } + /// Renumbers all the variables in the `UnionFind`. + /// The provided mapping **must** meet all the preconditions listed for [`Renumbering`]. + /// + /// This resets the `UnionFind` to the trivial state (`find(a) == a` for all `a`) and increments the generation ID. + /// + /// This method will panic in debug mode if said preconditions are not met. + pub fn renumber( + &mut self, + forward: IdVec>, + reverse: IdVec, + ) -> Arc> { + let old_generation = self.generation; + let new_generation = Generation(old_generation.0 + 1); + self.generation = new_generation; + let renumbering = Arc::new(Renumbering::new( + forward, + reverse, + old_generation, + new_generation, + )); + debug_assert!(renumbering.is_repr_reduction(&self.union_find)); + if !self.observers.is_empty() { + self.log.push_back(Change::Renumber(renumbering.clone())); + } + self.union_find = UnionFind::new(); + renumbering + } +} + +impl TrackedUnionFind { + /// Constructs a new, empty `TrackedUnionFind`. + pub fn new() -> Self { + Self::default() + } + fn log_end(&self) -> u64 { + self.log_start + self.log.len() as u64 + } + /// Creates a new `ObserverToken` that can be used to track all changes since the call to this method. + /// + /// Conceptually, each observer has its own private log. + /// Any changes that happen to the `UnionFind` will be recorded into the logs of all currently active observers. + /// (In the actual implementation, only a single log is kept). + /// + /// After use, the `ObserverToken` must be disposed of with a call to `stop_observing`, otherwise + /// the memory corresponding to old log entries cannot be reclaimed until the `TrackedUnionFind` is dropped. + pub fn start_observing(&mut self) -> ObserverToken { + let observer_id = self.observer_id_alloc.alloc().unwrap(); + self.observers.push(observer_id, Reverse(self.log_end())); + ObserverToken { + tuf_id: self.tuf_id, + generation: self.generation, + observer_id, + } + } + /// Clones an `ObserverToken`, conceptually cloning the token's private log. + pub fn clone_token(&mut self, token: &ObserverToken) -> ObserverToken { + assert!(token.tuf_id == self.tuf_id); + let new_observer_id = self.observer_id_alloc.alloc().unwrap(); + let pos = *self.observers.get_priority(&token.observer_id).unwrap(); + self.observers.push(new_observer_id, pos); + ObserverToken { + tuf_id: self.tuf_id, + generation: token.generation, + observer_id: new_observer_id, + } + } + /// Deletes an `ObserverToken` and its associated state. + /// + /// You must call this to allow the `TrackedUnionFind` to allow memory to be reclaimed. + pub fn stop_observing(&mut self, token: ObserverToken) { + assert!(token.tuf_id == self.tuf_id); + self.observers.remove(&token.observer_id); + self.truncate_log(); + } + fn truncate_log(&mut self) { + if let Some((_, &Reverse(new_start))) = self.observers.peek() { + if new_start > self.log_start { + let delete = (new_start - self.log_start).try_into().unwrap(); + drop(self.log.drain(0..delete)); + self.log_start = new_start; + } + } else { + self.log_start = self.log_end(); + self.log.clear(); + } + } + fn observer_rel_pos(&self, token: &ObserverToken) -> usize { + assert!(token.tuf_id == self.tuf_id); + let abs_pos = self.observers.get_priority(&token.observer_id).unwrap().0; + debug_assert!(abs_pos >= self.log_start); + (abs_pos - self.log_start).try_into().unwrap() + } + fn observer_inc_pos(&mut self, token: &ObserverToken, by: u64) { + let (min, max) = (self.log_start, self.log_end()); + self.observers + .change_priority_by(&token.observer_id, |pos| { + let new = pos.0 + by; + debug_assert!(new >= min && new <= max); + *pos = Reverse(new); + }); + self.truncate_log(); + } + fn observer_set_rel_pos(&mut self, token: &ObserverToken, rel_pos: usize) { + debug_assert!(rel_pos <= self.log.len()); + let abs_pos = self.log_start + rel_pos as u64; + self.observers + .change_priority(&token.observer_id, Reverse(abs_pos)); + self.truncate_log(); + } + #[allow(clippy::type_complexity)] + fn change_slices( + &self, + token: &ObserverToken, + ) -> ( + &[Change], + &[Change], + &UnionFind, + ) { + let rel_pos = self.observer_rel_pos(token); + let (first, second) = self.log.as_slices(); + if rel_pos >= first.len() { + (&second[rel_pos - first.len()..], &[], &self.union_find) + } else { + (&first[rel_pos..], second, &self.union_find) + } + } + fn observer_has_seen(&self, token: &mut ObserverToken, changes: &[Change]) { + for change in changes { + if let Change::Renumber(renumbering) = change { + debug_assert!(token.generation == renumbering.old_generation); + token.generation = renumbering.new_generation; + } + } + } + /// Calls the provided function `f` with the content of the token's private log and clears the log. + /// + /// Because the log is not necessarily contiguous in memory, `f` may be called multiple times. + /// + /// The slice argument to `f` is guaranteed to be non-empty. + /// To allow looking up representatives `f` is also provided with a shared reference to the `UnionFind`. + /// + /// The method assumes that you will immediately process any `Renumbering` operations in the log + /// and will update the token's generation field. + /// + /// Returns `true` iff `f` has been called at least once. + pub fn drain_changes_with_fn( + &mut self, + token: &mut ObserverToken, + mut f: impl FnMut(&[Change], &UnionFind), + ) -> bool { + let (first, second, union_find) = self.change_slices(token); + if !first.is_empty() { + f(first, union_find); + self.observer_has_seen(token, first); + if !second.is_empty() { + f(second, union_find); + self.observer_has_seen(token, second); + } + let drained = first.len() + second.len(); + self.observer_inc_pos(token, drained as u64); + true + } else { + false + } + } + /// Returns a draining iterator that returns and deletes entries from the token's private log. + /// + /// Dropping this iterator will clear any unread entries, call `stop` if this is undesirable. + /// + /// You must not leak the returned iterator. Otherwise log entries may be observed multiple times and appear duplicated. + pub fn drain_changes<'a>( + &mut self, + token: &'a mut ObserverToken, + ) -> DrainChanges<'_, 'a, Atom, Elem> { + assert!(token.tuf_id == self.tuf_id); + let rel_pos = self.observer_rel_pos(token); + DrainChanges { + tuf: self, + token, + rel_pos, + } + } +} + +/// A draining iterator. +/// +/// Since this is a lending iterator, it does not implement the standard `Iterator` trait, +/// but its `map` and `cloned` methods will create a standard iterator. +pub struct DrainChanges<'a, 'b, Atom, Elem> { + tuf: &'a mut TrackedUnionFind, + token: &'b mut ObserverToken, + rel_pos: usize, +} + +/// A draining iterator that has been mapped. +pub struct DrainChangesMap<'a, 'b, Atom, Elem, F> { + inner: DrainChanges<'a, 'b, Atom, Elem>, + f: F, +} + +impl<'a, 'b, Atom, Elem> DrainChanges<'a, 'b, Atom, Elem> { + /// Returns a reference to the first entry in the token's private log, without deleting it. + /// + /// Returns `None` if the log is empty. + pub fn peek(&mut self) -> Option<&Change> { + self.tuf.log.get(self.rel_pos) + } + /// Returns a reference to the first entry in the token's private log. The entry will be deleted after its use. + /// + /// If this returns a `Renumbering`, it is assumed that you will process it and the token's generation number will be updated. + /// + /// Returns `None` if the log is empty. If `next` returned `None`, it will never return any more entries (the iterator is fused). + /// + /// As reflected by the lifetimes, the API only guarantees that the returned reference until the next call of any method of this iterator. + /// (In practice, deletion is more lazy and happens on drop). + #[allow(clippy::should_implement_trait)] + pub fn next(&mut self) -> Option<&Change> { + let ret = self.tuf.log.get(self.rel_pos); + if let Some(change) = ret { + self.tuf + .observer_has_seen(self.token, std::slice::from_ref(change)); + self.rel_pos += 1; + } + ret + } + /// Drops the iterator but without deleting unread entries. + pub fn stop(self) { + self.tuf.observer_set_rel_pos(self.token, self.rel_pos); + let _ = ManuallyDrop::new(self); + } + /// Returns `(n, Some(n))` where `n` is the number of unread entries. + /// + /// This method is designed to be compatible with the standard iterator method of the same name. + pub fn size_hint(&self) -> (usize, Option) { + let count = self.tuf.log.len() - self.rel_pos; + (count, Some(count)) + } + /// Creates a new iterator by lazily calling `f` on every change. + #[must_use] + pub fn map(self, f: F) -> DrainChangesMap<'a, 'b, Atom, Elem, F> + where + F: FnMut(&Change) -> B, + { + DrainChangesMap { inner: self, f } + } +} + +impl<'a, 'b, Atom: Clone, Elem: Clone> DrainChanges<'a, 'b, Atom, Elem> { + /// Create a standard iterator by cloning every entry. + #[must_use] + #[allow(clippy::type_complexity)] + pub fn cloned( + self, + ) -> DrainChangesMap<'a, 'b, Atom, Elem, fn(&Change) -> Change> { + self.map(|x| x.clone()) + } +} + +impl Drop for DrainChanges<'_, '_, Atom, Elem> { + fn drop(&mut self) { + // mark any renumberings as seen + while self.next().is_some() {} + self.tuf + .observer_set_rel_pos(self.token, self.tuf.log.len()); + } +} + +impl Iterator for DrainChangesMap<'_, '_, Atom, Elem, F> +where + F: FnMut(&Change) -> B, +{ + type Item = B; + + fn next(&mut self) -> Option { + self.inner.next().map(&mut self.f) + } + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl ExactSizeIterator for DrainChangesMap<'_, '_, Atom, Elem, F> where + F: FnMut(&Change) -> B +{ +} + +impl FusedIterator for DrainChangesMap<'_, '_, Atom, Elem, F> where + F: FnMut(&Change) -> B +{ +} diff --git a/union_find/src/union_find.rs b/union_find/src/union_find.rs new file mode 100644 index 0000000..8473d68 --- /dev/null +++ b/union_find/src/union_find.rs @@ -0,0 +1,241 @@ +//! `UnionFind` efficiently tracks equivalences between variables. +use crate::Element; +use atomic::Atomic; +use imctk_ids::{id_vec::IdVec, Id, IdRange}; +use std::sync::atomic::Ordering; + +#[cfg(test)] +#[path = "tests/test_union_find.rs"] +mod test_union_find; + +/// `UnionFind` efficiently tracks equivalences between variables. +/// +/// Given "elements" of type `Elem`, this structure keeps track of any known equivalences between these elements. +/// Equivalences are assumed to be *transitive*, i.e. if `x = y` and `y = z`, then `x = z` is assumed to be true, and in fact, +/// automatically discovered by this structure. +/// +/// Unlike a standard union find data structure, this version also keeps track of the *polarity* of elements. +/// For example, you might always have pairs of elements `+x` and `-x` that are exact opposites of each other. +/// If equivalences between `+x` and `-y` are then discovered, the structure also understands that `-x` and `+y` are similarly equivalent. +/// +/// An element without its polarity is called an *atom*. +/// The [`Element`](Element) trait is required on `Elem` so that this structure can relate elements, atoms and polarities. +/// +/// For each set of elements that are equivalent up to a polarity change, this structure keeps track of a *representative*. +/// Each element starts out as its own representative, and if two elements are declared equivalent, the representative of one becomes the representative of both. +/// The method `find` returns the representative for any element. +/// +/// To declare two elements as equivalent, use the `union` or `union_full` methods, see their documentation for details on their use. +/// +/// NB: Since this structure stores atoms in an `IdVec`, the atoms used should ideally be a contiguous set starting at `Atom::MIN_ID`. +/// +/// ## Example ## +/// ``` +/// use imctk_lit::{Var, Lit}; +/// use imctk_union_find::UnionFind; +/// +/// let mut union_find: UnionFind = UnionFind::new(); +/// let lit = |n| Var::from_index(n).as_lit(); +/// +/// assert_eq!(union_find.find(lit(4)), lit(4)); +/// +/// union_find.union([lit(3), lit(4)]); +/// assert_eq!(union_find.find(lit(4)), lit(3)); +/// +/// union_find.union([lit(1), !lit(2)]); +/// union_find.union([lit(2), lit(3)]); +/// assert_eq!(union_find.find(lit(1)), lit(1)); +/// assert_eq!(union_find.find(lit(4)), !lit(1)); +/// +/// ``` +pub struct UnionFind { + parent: IdVec>, +} + +impl Default for UnionFind { + fn default() -> Self { + UnionFind { + parent: Default::default(), + } + } +} + +impl Clone for UnionFind { + fn clone(&self) -> Self { + let new_parent = self + .parent + .values() + .iter() + .map(|p| Atomic::new(p.load(Ordering::Relaxed))) + .collect(); + Self { + parent: IdVec::from_vec(new_parent), + } + } +} + +impl UnionFind { + /// Constructs an empty `UnionFind`. + /// + /// The returned struct obeys `find(a) == a` for all `a`. + pub fn new() -> Self { + UnionFind::default() + } +} + +impl> UnionFind { + /// Constructs an empty `UnionFind`, with room for `capacity` elements. + pub fn with_capacity(capacity: usize) -> Self { + UnionFind { + parent: IdVec::from_vec(Vec::with_capacity(capacity)), + } + } + /// Clears all equivalences, but retains any allocated memory. + pub fn clear(&mut self) { + self.parent.clear(); + } + fn read_parent(&self, atom: Atom) -> Elem { + if let Some(parent_cell) = self.parent.get(atom) { + // This load is allowed to reorder with stores from `update_parent`, see there for details. + parent_cell.load(Ordering::Relaxed) + } else { + Elem::from_atom(atom) + } + } + // Important: Only semantically trivial changes are allowed using this method!! + // Specifically, update_parent(atom, parent) should only be called if `parent` is already an ancestor of `atom` + // Otherwise, concurrent calls to `read_parent` (which are explicitly allowed!) could return incorrect results. + fn update_parent(&self, atom: Atom, parent: Elem) { + if let Some(parent_cell) = self.parent.get(atom) { + parent_cell.store(parent, Ordering::Relaxed); + } else { + // can only get here if the precondition or a data structure invariant is violated + panic!("shouldn't happen: update_parent called with out of bounds argument"); + } + } + // Unlike `update_parent`, this is safe for arbitrary updates, since it requires &mut self. + fn write_parent(&mut self, atom: Atom, parent: Elem) { + if let Some(parent_cell) = self.parent.get(atom) { + parent_cell.store(parent, Ordering::Relaxed); + } else { + debug_assert!(self.parent.next_unused_key() <= atom); + while self.parent.next_unused_key() < atom { + let next_elem = Elem::from_atom(self.parent.next_unused_key()); + self.parent.push(Atomic::new(next_elem)); + } + self.parent.push(Atomic::new(parent)); + } + } + fn find_root(&self, mut elem: Elem) -> Elem { + loop { + // If we interleave with a call to `update_parent`, the parent may change + // under our feet, but it's okay because in that case we get an ancestor instead, + // which just skips some iterations of the loop! + let parent = self.read_parent(elem.atom()).apply_pol_of(elem); + if elem == parent { + return elem; + } + debug_assert!(elem.atom() != parent.atom()); + elem = parent; + } + } + // Worst-case `find` performance is linear. To keep amortised time complexity logarithmic, + // we memoise the result of `find_root` by calling `update_parent` on every element + // we traversed. + fn update_root(&self, mut elem: Elem, root: Elem) { + // Loop invariant: `root` is the representative of `elem`. + loop { + // Like in `find_root`, this may interleave with `update_root` calls, and we may skip some steps, + // which is okay because the other thread will do the updates instead. + let parent = self.read_parent(elem.atom()).apply_pol_of(elem); + if parent == root { + break; + } + // By the loop invariant, this just sets `elem`'s parent to its representative, + // which satisfies the precondition for `update_parent`. Further if two threads end up + // here simultaneously, they will both set to the same representative, + // therefore the change is idempotent. + self.update_parent(elem.atom(), root.apply_pol_of(elem)); + elem = parent; + } + } + /// Returns the representative for an element. Elements are equivalent iff they have the same representative. + /// + /// Elements `a` and `b` are equivalent up to a polarity change iff they obey `find(a) = find(b) ^ p` for some polarity `p`. + /// + /// This operation is guaranteed to return `elem` itself for arguments `elem >= lowest_unused_atom()`. + /// + /// The amortised time complexity of this operation is **O**(log N). + pub fn find(&self, elem: Elem) -> Elem { + let root = self.find_root(elem); + self.update_root(elem, root); + root + } + /// Declares two elements to be equivalent. The new representative of both is the representative of the first element. + /// + /// If the elements are already equivalent or cannot be made equivalent (are equivalent up to a sign change), + /// the operation returns `false` without making any changes. Otherwise it returns `true`. + /// + /// In both cases it also returns the original representatives of both arguments. + /// + /// The amortised time complexity of this operation is **O**(log N). + pub fn union_full(&mut self, elems: [Elem; 2]) -> (bool, [Elem; 2]) { + let [a, b] = elems; + let ra = self.find(a); + let rb = self.find(b); + if ra.atom() == rb.atom() { + (false, [ra, rb]) + } else { + // The first write is only needed to ensure that the parent table actually contains `a` + // and is a no-op otherwise. + self.write_parent(ra.atom(), Elem::from_atom(ra.atom())); + self.write_parent(rb.atom(), ra.apply_pol_of(rb)); + (true, [ra, rb]) + } + } + /// Declares two elements to be equivalent. The new representative of both is the representative of the first element. + /// + /// If the elements are already equivalent or cannot be made equivalent (are equivalent up to polarity), + /// the operation returns `false` without making any changes. Otherwise it returns `true`. + /// + /// The amortised time complexity of this operation is **O**(log N). + pub fn union(&mut self, elems: [Elem; 2]) -> bool { + self.union_full(elems).0 + } + /// Sets `atom` to be its own representative, and updates other representatives to preserve all existing equivalences. + /// + /// The amortised time complexity of this operation is **O**(log N). + pub fn make_repr(&mut self, atom: Atom) -> Elem { + let root = self.find(Elem::from_atom(atom)); + self.write_parent(atom, Elem::from_atom(atom)); + self.write_parent(root.atom(), Elem::from_atom(atom).apply_pol_of(root)); + root + } + /// Returns the lowest `Atom` value for which no equivalences are known. + /// + /// It is guaranteed that `find(a) == a` if `a >= lowest_unused_atom`, but the converse may not hold. + pub fn lowest_unused_atom(&self) -> Atom { + self.parent.next_unused_key() + } + /// Returns an iterator that yields all tracked atoms and their representatives. + pub fn iter(&self) -> impl '_ + Iterator { + IdRange::from(Atom::MIN_ID..self.lowest_unused_atom()) + .iter() + .map(|atom| (atom, self.find(Elem::from_atom(atom)))) + } +} + +impl> std::fmt::Debug for UnionFind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // prints non-trivial sets of equivalent elements, always printing the representative first + let mut sets = std::collections::HashMap::>::new(); + for (atom, repr) in self.iter() { + if Elem::from_atom(atom) != repr { + sets.entry(repr.atom()) + .or_insert_with(|| vec![Elem::from_atom(repr.atom())]) + .push(Elem::from_atom(atom).apply_pol_of(repr)); + } + } + f.debug_set().entries(sets.values()).finish() + } +}