From e66f420260331dd598ca826bcb56257ed3dcd157 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Fri, 11 Oct 2024 17:49:35 +0100 Subject: [PATCH 1/7] add dsf crate --- Cargo.lock | 34 +++ Cargo.toml | 1 + dsf/Cargo.toml | 20 ++ dsf/src/lib.rs | 2 + dsf/src/tracked_union_find.rs | 474 ++++++++++++++++++++++++++++++++++ dsf/src/union_find.rs | 319 +++++++++++++++++++++++ ids/src/id_vec.rs | 2 +- lit/Cargo.toml | 1 + lit/src/lit.rs | 2 + lit/src/var.rs | 2 + 10 files changed, 856 insertions(+), 1 deletion(-) create mode 100644 dsf/Cargo.toml create mode 100644 dsf/src/lib.rs create mode 100644 dsf/src/tracked_union_find.rs create mode 100644 dsf/src/union_find.rs diff --git a/Cargo.lock b/Cargo.lock index 4a8b6f7..75eb901 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,6 +75,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "atomic" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d818003e740b63afc82337e3160717f4f63078720a810b7b903e70a5d1d2994" +dependencies = [ + "bytemuck", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -229,6 +238,19 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "dsf" +version = "0.1.0" +dependencies = [ + "atomic", + "bytemuck", + "imctk-ids", + "imctk-lit", + "priority-queue", + "rand", + "rand_pcg", +] + [[package]] name = "encode_unicode" version = "0.3.6" @@ -470,6 +492,7 @@ dependencies = [ name = "imctk-lit" version = "0.1.0" dependencies = [ + "bytemuck", "flussab-aiger", "hashbrown 0.14.5", "imctk-derive", @@ -670,6 +693,17 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "priority-queue" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714c75db297bc88a63783ffc6ab9f830698a6705aa0201416931759ef4c8183d" +dependencies = [ + "autocfg", + "equivalent", + "indexmap", +] + [[package]] name = "proc-macro-crate" version = "3.2.0" diff --git a/Cargo.toml b/Cargo.toml index 1fa671e..6ae2fc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "imctk", "lit", "stable_set", + "dsf", # comment to force multi-line layout ] diff --git a/dsf/Cargo.toml b/dsf/Cargo.toml new file mode 100644 index 0000000..bee1698 --- /dev/null +++ b/dsf/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "dsf" +version = "0.1.0" +edition = "2021" + +publish = false + +[dependencies] +imctk-ids = { version = "0.1.0", path = "../ids" } +imctk-lit = { version = "0.1.0", path = "../lit" } +bytemuck = "*" +atomic = "0.6" +priority-queue = "*" + +[dev-dependencies] +rand = "*" +rand_pcg = "*" + +[lints] +workspace = true diff --git a/dsf/src/lib.rs b/dsf/src/lib.rs new file mode 100644 index 0000000..0505e00 --- /dev/null +++ b/dsf/src/lib.rs @@ -0,0 +1,2 @@ +pub mod union_find; +pub mod tracked_union_find; \ No newline at end of file diff --git a/dsf/src/tracked_union_find.rs b/dsf/src/tracked_union_find.rs new file mode 100644 index 0000000..0f985d4 --- /dev/null +++ b/dsf/src/tracked_union_find.rs @@ -0,0 +1,474 @@ +#![allow(missing_docs)] +#![allow(clippy::type_complexity)] +use std::{cmp::Reverse, collections::VecDeque, mem::ManuallyDrop, sync::Arc}; + +use atomic::Atomic; +use bytemuck::NoUninit; +use imctk_ids::{id_vec::IdVec, Id, Id64}; +use priority_queue::PriorityQueue; + +use crate::union_find::{Element, UnionFind}; + +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct TrackedUnionFindId(Id64); +// SAFETY: trust me bro +unsafe impl NoUninit for TrackedUnionFindId {} +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct Generation(u64); +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct ObserverId(Id64); +// SAFETY: trust me bro +unsafe impl NoUninit for ObserverId {} + +#[derive(Debug)] +pub struct ObserverToken { + tuf_id: TrackedUnionFindId, + generation: Generation, + observer_id: ObserverId, +} + +pub struct Renumbering { + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, +} + +impl std::fmt::Debug for Renumbering { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Renumbering") + .field("forward", &self.forward) + .field("reverse", &self.reverse) + .field("old_generation", &self.old_generation) + .field("new_generation", &self.new_generation) + .finish() + } +} + +impl + NoUninit> Renumbering { + pub fn get_reverse(forward: &IdVec>, union_find: &UnionFind) -> IdVec { + let mut reverse: IdVec> = IdVec::default(); + for (old, &new_opt) in forward { + if let Some(new) = new_opt { + reverse + .grow_for_key(new.atom()) + .replace(union_find.find(Elem::from_atom(old)).apply_pol_of(new)); + } + } + IdVec::from_vec(reverse.iter().map(|x| x.1.unwrap()).collect()) + } + pub fn is_inverse(forward: &IdVec>, reverse: &IdVec) -> bool { + reverse.iter().all(|(new, &old)| { + if let Some(&Some(e)) = forward.get(old.atom()) { + Elem::from_atom(new) == e.apply_pol_of(old) + } else { + false + } + }) + } + pub fn new_unchecked( + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, + ) -> Self { + Renumbering { + forward, + reverse, + old_generation, + new_generation, + } + } + pub fn new( + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, + ) -> Self { + debug_assert!(new_generation > old_generation); + debug_assert!(Self::is_inverse(&forward, &reverse)); + Self::new_unchecked(forward, reverse, old_generation, new_generation) + } + pub fn old_to_new(&self, old: Elem) -> Option { + self.forward + .get(old.atom()) + .copied() + .flatten() + .map(|e| e.apply_pol_of(old)) + } + pub fn new_to_old(&self, new: Elem) -> Option { + self.reverse.get(new.atom()).map(|&e| e.apply_pol_of(new)) + } + pub fn is_repr_reduction(&self, union_find: &UnionFind) -> bool { + union_find.lowest_unused_atom() <= self.forward.next_unused_key() + && self.forward.iter().all(|(old, &new)| { + let repr = union_find.find(Elem::from_atom(old)); + let repr_new = self.old_to_new(repr); + repr_new == new + }) + } +} + +#[derive(Clone)] +pub enum Change { + Union { new_repr: Atom, merged_repr: Elem }, + MakeRepr { new_repr: Atom, old_repr: Elem }, + Renumber(Arc>), +} + +impl std::fmt::Debug for Change { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Union { + new_repr, + merged_repr, + } => f + .debug_struct("Union") + .field("new_repr", new_repr) + .field("merged_repr", merged_repr) + .finish(), + Self::MakeRepr { new_repr, old_repr } => f + .debug_struct("MakeRepr") + .field("new_repr", new_repr) + .field("old_repr", old_repr) + .finish(), + Self::Renumber(arg0) => f.debug_tuple("Renumber").field(arg0).finish(), + } + } +} + +pub struct TrackedUnionFind { + tuf_id: TrackedUnionFindId, + union_find: UnionFind, + log: VecDeque>, + observer_id_alloc: IdAlloc, + observers: PriorityQueue>, + log_start: u64, + generation: Generation, +} + +pub struct IdAlloc { + counter: Atomic, +} + +impl Default for IdAlloc { + fn default() -> Self { + Self::new() + } +} + +impl IdAlloc { + const fn new() -> Self { + Self { + counter: Atomic::new(T::MIN_ID), + } + } + pub fn alloc_block(&self, n: usize) -> T { + use atomic::Ordering::Relaxed; + debug_assert!(n > 0); + let mut current_id = self.counter.load(Relaxed); + loop { + let new_id = current_id + .id_index() + .checked_add(n) + .and_then(T::try_from_id_index) + .expect("not enough IDs remaining"); + match self + .counter + .compare_exchange_weak(current_id, new_id, Relaxed, Relaxed) + { + Ok(_) => return current_id, + Err(id) => current_id = id, + } + } + } + pub fn alloc(&self) -> T { + self.alloc_block(1) + } +} + +static TUF_ID_ALLOC: IdAlloc = IdAlloc::new(); + +impl Default for TrackedUnionFind { + fn default() -> Self { + Self { + tuf_id: TUF_ID_ALLOC.alloc(), + union_find: Default::default(), + log: Default::default(), + observer_id_alloc: Default::default(), + observers: Default::default(), + log_start: 0, + generation: Generation(0), + } + } +} + +impl + NoUninit> TrackedUnionFind { + pub fn find(&self, elem: Elem) -> Elem { + self.union_find.find(elem) + } + pub fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { + let (ok, roots) = self.union_find.union_full(lits); + if ok && !self.observers.is_empty() { + let new_repr = roots[0].atom(); + let merged_repr = roots[1].apply_pol_of(roots[0]); + self.log.push_back(Change::Union { + new_repr, + merged_repr, + }) + } + (ok, roots) + } + pub fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + pub fn make_repr(&mut self, new_repr: Atom) -> Elem { + let old_repr = self.union_find.make_repr(new_repr); + if old_repr.atom() != new_repr && !self.observers.is_empty() { + self.log.push_back(Change::MakeRepr { new_repr, old_repr }); + } + old_repr + } + pub fn renumber(&mut self, forward: IdVec>, reverse: IdVec) { + let old_generation = self.generation; + let new_generation = Generation(old_generation.0 + 1); + self.generation = new_generation; + let renumbering = Renumbering::new(forward, reverse, old_generation, new_generation); + debug_assert!(renumbering.is_repr_reduction(&self.union_find)); + if !self.observers.is_empty() { + self.log.push_back(Change::Renumber(Arc::new(renumbering))); + } + self.union_find = UnionFind::new(); + } +} + +impl TrackedUnionFind { + pub fn new() -> Self { + Self::default() + } + fn log_end(&self) -> u64 { + self.log_start + self.log.len() as u64 + } + pub fn start_observing(&mut self) -> ObserverToken { + let observer_id = self.observer_id_alloc.alloc(); + self.observers.push(observer_id, Reverse(self.log_end())); + ObserverToken { + tuf_id: self.tuf_id, + generation: self.generation, + observer_id, + } + } + pub fn clone_token(&mut self, token: &ObserverToken) -> ObserverToken { + assert!(token.tuf_id == self.tuf_id); + let new_observer_id = self.observer_id_alloc.alloc(); + let pos = *self.observers.get_priority(&token.observer_id).unwrap(); + self.observers.push(new_observer_id, pos); + ObserverToken { + tuf_id: self.tuf_id, + generation: token.generation, + observer_id: new_observer_id, + } + } + pub fn stop_observing(&mut self, token: ObserverToken) { + assert!(token.tuf_id == self.tuf_id); + self.observers.remove(&token.observer_id); + self.truncate_log(); + } + fn truncate_log(&mut self) { + if let Some((_, &Reverse(new_start))) = self.observers.peek() { + if new_start > self.log_start { + let delete = (new_start - self.log_start).try_into().unwrap(); + drop(self.log.drain(0..delete)); + println!("dropped {delete} entries"); + self.log_start = new_start; + } + } else { + self.log_start = self.log_end(); + self.log.clear(); + println!("dropped all entries"); + } + } + fn observer_rel_pos(&self, token: &ObserverToken) -> usize { + assert!(token.tuf_id == self.tuf_id); + let abs_pos = self.observers.get_priority(&token.observer_id).unwrap().0; + debug_assert!(abs_pos >= self.log_start); + (abs_pos - self.log_start).try_into().unwrap() + } + fn observer_inc_pos(&mut self, token: &ObserverToken, by: u64) { + let (min, max) = (self.log_start, self.log_end()); + self.observers + .change_priority_by(&token.observer_id, |pos| { + let new = pos.0 + by; + debug_assert!(new >= min && new <= max); + *pos = Reverse(new); + }); + self.truncate_log(); + } + fn observer_set_rel_pos(&mut self, token: &ObserverToken, rel_pos: usize) { + debug_assert!(rel_pos <= self.log.len()); + let abs_pos = self.log_start + rel_pos as u64; + self.observers + .change_priority(&token.observer_id, Reverse(abs_pos)); + self.truncate_log(); + } + fn change_slices( + &self, + token: &ObserverToken, + ) -> ( + &[Change], + &[Change], + &UnionFind, + ) { + let rel_pos = self.observer_rel_pos(token); + let (first, second) = self.log.as_slices(); + if rel_pos >= first.len() { + (&second[rel_pos - first.len()..], &[], &self.union_find) + } else { + (&first[rel_pos..], second, &self.union_find) + } + } + fn observer_has_seen(&self, token: &mut ObserverToken, changes: &[Change]) { + for change in changes { + if let Change::Renumber(renumbering) = change { + debug_assert!(token.generation == renumbering.old_generation); + token.generation = renumbering.new_generation; + } + } + } + pub fn drain_changes_with_fn( + &mut self, + token: &mut ObserverToken, + mut f: impl FnMut(&[Change], &UnionFind), + ) -> bool { + let (first, second, union_find) = self.change_slices(token); + if !first.is_empty() { + f(first, union_find); + self.observer_has_seen(token, first); + if !second.is_empty() { + f(second, union_find); + self.observer_has_seen(token, second); + } + let drained = first.len() + second.len(); + self.observer_inc_pos(token, drained as u64); + true + } else { + false + } + } + pub fn drain_changes<'a>( + &mut self, + token: &'a mut ObserverToken, + ) -> DrainChanges<'_, 'a, Atom, Elem> { + assert!(token.tuf_id == self.tuf_id); + let rel_pos = self.observer_rel_pos(token); + DrainChanges { + tuf: self, + token, + rel_pos, + } + } +} + +pub struct DrainChanges<'a, 'b, Atom, Elem> { + tuf: &'a mut TrackedUnionFind, + token: &'b mut ObserverToken, + rel_pos: usize, +} + +pub struct DrainChangesMap<'a, 'b, Atom, Elem, F> { + inner: DrainChanges<'a, 'b, Atom, Elem>, + f: F, +} + +impl<'a, 'b, Atom, Elem> DrainChanges<'a, 'b, Atom, Elem> { + pub fn peek(&mut self) -> Option<&Change> { + self.tuf.log.get(self.rel_pos) + } + #[allow(clippy::should_implement_trait)] + pub fn next(&mut self) -> Option<&Change> { + let ret = self.tuf.log.get(self.rel_pos); + if let Some(change) = ret { + self.tuf + .observer_has_seen(self.token, std::slice::from_ref(change)); + self.rel_pos += 1; + } + ret + } + pub fn stop(self) { + self.tuf.observer_set_rel_pos(self.token, self.rel_pos); + let _ = ManuallyDrop::new(self); + } + pub fn size_hint(&self) -> (usize, Option) { + let count = self.tuf.log.len() - self.rel_pos; + (count, Some(count)) + } + pub fn map(self, f: F) -> DrainChangesMap<'a, 'b, Atom, Elem, F> + where + F: FnMut(&Change) -> B, + { + DrainChangesMap { inner: self, f } + } +} + +impl<'a, 'b, Atom: Clone, Elem: Clone> DrainChanges<'a, 'b, Atom, Elem> { + pub fn cloned( + self, + ) -> DrainChangesMap<'a, 'b, Atom, Elem, fn(&Change) -> Change> { + self.map(|x| x.clone()) + } +} + +impl Drop for DrainChanges<'_, '_, Atom, Elem> { + fn drop(&mut self) { + self.tuf + .observer_set_rel_pos(self.token, self.tuf.log.len()); + } +} + +impl Iterator for DrainChangesMap<'_, '_, Atom, Elem, F> +where + F: FnMut(&Change) -> B, +{ + type Item = B; + + fn next(&mut self) -> Option { + self.inner.next().map(&mut self.f) + } + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +#[test] +fn test() { + use imctk_lit::{Lit, Var}; + let l = |n| Var::from_index(n).as_lit(); + let mut tuf = TrackedUnionFind::::new(); + let mut token = tuf.start_observing(); + tuf.union([l(3), !l(4)]); + tuf.union([l(8), l(7)]); + let mut token2 = tuf.start_observing(); + tuf.union([l(4), l(5)]); + for change in tuf.drain_changes(&mut token).cloned() { + println!("{change:?}"); + } + println!("---"); + tuf.union([!l(5), l(6)]); + tuf.make_repr(l(4).var()); + let renumber: IdVec> = + IdVec::from_vec(vec![Some(l(0)), None, None, Some(l(1)), Some(!l(1)), Some(!l(1)), Some(l(1)), Some(l(2)), Some(l(2))]); + let reverse = Renumbering::get_reverse(&renumber, &tuf.union_find); + dbg!(&renumber, &reverse); + tuf.renumber(renumber, reverse); + tuf.union([l(0), l(1)]); + let mut iter = tuf.drain_changes(&mut token); + println!("{:?}", iter.next()); + iter.stop(); + println!("---"); + for change in tuf.drain_changes(&mut token2).cloned() { + println!("{change:?}"); + } +} diff --git a/dsf/src/union_find.rs b/dsf/src/union_find.rs new file mode 100644 index 0000000..5036432 --- /dev/null +++ b/dsf/src/union_find.rs @@ -0,0 +1,319 @@ +#![allow(missing_docs)] +use std::sync::atomic::Ordering; + +use atomic::Atomic; +use bytemuck::NoUninit; +use imctk_ids::{id_vec::IdVec, Id}; +use imctk_lit::{Lit, Var}; + +pub trait Element { + fn from_atom(atom: Atom) -> Self; + fn atom(self) -> Atom; + fn apply_pol_of(self, other: Self) -> Self; +} + +impl Element for T { + fn from_atom(atom: T) -> Self { + atom + } + fn atom(self) -> T { + self + } + fn apply_pol_of(self, _other: T) -> Self { + self + } +} + +impl Element for Lit { + fn from_atom(atom: Var) -> Self { + atom.as_lit() + } + fn atom(self) -> Var { + self.var() + } + fn apply_pol_of(self, other: Self) -> Self { + self ^ other.pol() + } +} + +pub struct UnionFind { + parent: IdVec>, +} + +impl Default for UnionFind { + fn default() -> Self { + UnionFind { + parent: Default::default(), + } + } +} + +impl Clone for UnionFind { + fn clone(&self) -> Self { + let new_parent = self + .parent + .values() + .iter() + .map(|p| Atomic::new(p.load(Ordering::Relaxed))) + .collect(); + Self { + parent: IdVec::from_vec(new_parent), + } + } +} + +impl + NoUninit> UnionFind { + pub fn new() -> Self { + UnionFind::default() + } + fn read_parent(&self, atom: Atom) -> Elem { + self.parent + .get(atom) + .map(|p| p.load(Ordering::Relaxed)) + .unwrap_or(Elem::from_atom(atom)) + } + fn update_parent(&self, atom: Atom, parent: Elem) { + let Some(parent_ref) = self.parent.get(atom) else { + panic!("shouldn't happen: update_parent called with out of bounds argument"); + }; + parent_ref.store(parent, Ordering::Relaxed); + } + fn write_parent(&mut self, atom: Atom, parent: Elem) { + if let Some(parent_ref) = self.parent.get(atom) { + parent_ref.store(parent, Ordering::Relaxed); + } else { + debug_assert!(self.parent.next_unused_key() <= atom); + while self.parent.next_unused_key() < atom { + self.parent + .push(Atomic::new(Elem::from_atom(self.parent.next_unused_key()))); + } + self.parent.push(Atomic::new(parent)); + } + } + fn find_root(&self, mut elem: Elem) -> Elem { + loop { + let parent = self.read_parent(elem.atom()).apply_pol_of(elem); + if elem == parent { + return elem; + } + debug_assert!(elem.atom() != parent.atom()); + elem = parent; + } + } + fn update_root(&self, mut elem: Elem, root: Elem) { + loop { + let parent = self.read_parent(elem.atom()).apply_pol_of(elem); + if parent == root { + break; + } + self.update_parent(elem.atom(), root.apply_pol_of(elem)); + elem = parent; + } + } + pub fn find(&self, lit: Elem) -> Elem { + let root = self.find_root(lit); + self.update_root(lit, root); + root + } + pub fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { + let [a, b] = lits; + let ra = self.find(a); + let rb = self.find(b); + if ra.atom() == rb.atom() { + (false, [ra, rb]) + } else { + self.write_parent(rb.atom(), ra.apply_pol_of(rb)); + (true, [ra, rb]) + } + } + pub fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + pub fn make_repr(&mut self, atom: Atom) -> Elem { + let root = self.find(Elem::from_atom(atom)); + self.write_parent(atom, Elem::from_atom(atom)); + self.write_parent(root.atom(), Elem::from_atom(atom).apply_pol_of(root)); + root + } + pub fn lowest_unused_atom(&self) -> Atom { + self.parent.next_unused_key() + } +} + +#[cfg(test)] +#[allow(dead_code)] +mod tests { + use super::*; + use imctk_ids::id_set_seq::IdSetSeq; + use rand::prelude::*; + use std::collections::{HashSet, VecDeque}; + + #[derive(Default)] + struct CheckedUnionFind { + dut: UnionFind, + equivs: IdSetSeq, + } + + impl + NoUninit> UnionFind { + fn debug_print_tree( + children: &IdVec>, + atom: Atom, + prefix: &str, + self_char: &str, + further_char: &str, + pol: bool, + ) { + println!( + "{prefix}{self_char}{}{:?}", + if pol { "!" } else { "" }, + atom + ); + let my_children = children.get(atom).unwrap(); + for (index, &child) in my_children.iter().enumerate() { + let last = index == my_children.len() - 1; + let self_char = if last { "└" } else { "├" }; + let next_further_char = if last { " " } else { "│" }; + Self::debug_print_tree( + children, + child.atom(), + &(prefix.to_string() + further_char), + self_char, + next_further_char, + pol ^ (child != Elem::from_atom(child.atom())), + ); + } + } + fn debug_print(&self) { + let mut children: IdVec> = Default::default(); + for atom in self.parent.keys() { + let parent = self.read_parent(atom); + children.grow_for_key(atom); + if atom != parent.atom() { + children + .grow_for_key(parent.atom()) + .push(Elem::from_atom(atom).apply_pol_of(parent)); + } else { + assert!(Elem::from_atom(atom) == parent); + } + } + for atom in self.parent.keys() { + if atom == self.read_parent(atom).atom() { + Self::debug_print_tree(&children, atom, "", "", " ", false); + } + } + } + } + #[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] + enum VarRel { + Equiv, + AntiEquiv, + NotEquiv, + } + + impl + NoUninit> CheckedUnionFind { + fn new() -> Self { + CheckedUnionFind { + dut: Default::default(), + equivs: Default::default(), + } + } + fn ref_equal(&mut self, start: Elem, goal: Elem) -> VarRel { + let mut seen: HashSet = Default::default(); + let mut queue: VecDeque = [start].into(); + while let Some(place) = queue.pop_front() { + if place.atom() == goal.atom() { + if place == goal { + return VarRel::Equiv; + } else { + return VarRel::AntiEquiv; + } + } + seen.insert(place.atom()); + for &next in self.equivs.grow_for(place.atom()).iter() { + if !seen.contains(&next.atom()) { + queue.push_back(next.apply_pol_of(place)); + } + } + } + VarRel::NotEquiv + } + fn find(&mut self, lit: Elem) -> Elem { + let out = self.dut.find(lit); + assert!(self.ref_equal(lit, out) == VarRel::Equiv); + out + } + fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { + let (ok, [ra, rb]) = self.dut.union_full(lits); + assert_eq!(self.ref_equal(lits[0], ra), VarRel::Equiv); + assert_eq!(self.ref_equal(lits[1], rb), VarRel::Equiv); + assert_eq!(ok, self.ref_equal(lits[0], lits[1]) == VarRel::NotEquiv); + assert_eq!(self.dut.find_root(lits[0]), ra); + if ok { + assert_eq!(self.dut.find_root(lits[1]), ra); + self.equivs + .grow_for(lits[0].atom()) + .insert(lits[1].apply_pol_of(lits[0])); + self.equivs + .grow_for(lits[1].atom()) + .insert(lits[0].apply_pol_of(lits[1])); + } else { + assert_eq!(self.dut.find_root(lits[1]).atom(), ra.atom()); + } + (ok, [ra, rb]) + } + fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + fn make_repr(&mut self, lit: Atom) { + self.dut.make_repr(lit); + assert_eq!( + self.dut.find_root(Elem::from_atom(lit)), + Elem::from_atom(lit) + ); + self.check(); + } + fn check(&mut self) { + for atom in self.dut.parent.keys() { + let parent = self.dut.read_parent(atom); + assert_eq!(self.ref_equal(Elem::from_atom(atom), parent), VarRel::Equiv); + let root = self.dut.find_root(Elem::from_atom(atom)); + for &child in self.equivs.grow_for(atom).iter() { + assert_eq!(root, self.dut.find_root(child)); + } + } + } + } + + #[test] + fn test() { + let mut u: CheckedUnionFind = CheckedUnionFind::new(); + let mut rng = rand_pcg::Pcg64::seed_from_u64(25); + let max_var = 2000; + for i in 0..2000 { + match rng.gen_range(0..10) { + 0..=4 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let b = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.union_full([a, b]); + println!("union({a}, {b}) = {result:?}"); + } + 5..=7 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.find(a); + println!("find({a}) = {result}"); + } + 8 => { + u.check(); + } + 9 => { + let a = Var::from_index(rng.gen_range(0..=max_var)); + u.make_repr(a); + println!("make_repr({a})"); + } + _ => {} + } + } + u.check(); + //u.dut.debug_print(); + } +} diff --git a/ids/src/id_vec.rs b/ids/src/id_vec.rs index ff04a19..9c40049 100644 --- a/ids/src/id_vec.rs +++ b/ids/src/id_vec.rs @@ -312,7 +312,7 @@ impl Clone for IdVec { } } -impl Default for IdVec { +impl Default for IdVec { #[inline(always)] fn default() -> Self { Self { diff --git a/lit/Cargo.toml b/lit/Cargo.toml index 925485a..1b29ee4 100644 --- a/lit/Cargo.toml +++ b/lit/Cargo.toml @@ -19,3 +19,4 @@ imctk-util = { version = "0.1.0", path = "../util" } log = "0.4.21" table_seq = { version = "0.1.0", path = "../table_seq" } zwohash = "0.1.2" +bytemuck = "*" \ No newline at end of file diff --git a/lit/src/lit.rs b/lit/src/lit.rs index 409869c..0aa1a31 100644 --- a/lit/src/lit.rs +++ b/lit/src/lit.rs @@ -24,6 +24,8 @@ use super::{pol::Pol, var::Var}; #[derive(Id, SubtypeCast, NewtypeCast)] pub struct Lit(Id32); +unsafe impl bytemuck::NoUninit for Lit {} + /// Ensure that there is an even number of literals #[allow(clippy::assertions_on_constants)] const _: () = { diff --git a/lit/src/var.rs b/lit/src/var.rs index ce9fc3a..473ebc4 100644 --- a/lit/src/var.rs +++ b/lit/src/var.rs @@ -9,6 +9,8 @@ use super::{lit::Lit, pol::Pol}; #[derive(Id, SubtypeCast, NewtypeCast)] pub struct Var(GenericId<{ Lit::MAX_ID_INDEX / 2 }, ::BaseId>); +unsafe impl bytemuck::NoUninit for Var {} + impl std::fmt::Debug for Var { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) From 4df22e77a6e1b505833c8acbdbddb63cb88ff023 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Tue, 15 Oct 2024 07:57:23 +0100 Subject: [PATCH 2/7] dsf: use atomic fetch update --- dsf/src/tracked_union_find.rs | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/dsf/src/tracked_union_find.rs b/dsf/src/tracked_union_find.rs index 0f985d4..81f47d0 100644 --- a/dsf/src/tracked_union_find.rs +++ b/dsf/src/tracked_union_find.rs @@ -169,21 +169,14 @@ impl IdAlloc { pub fn alloc_block(&self, n: usize) -> T { use atomic::Ordering::Relaxed; debug_assert!(n > 0); - let mut current_id = self.counter.load(Relaxed); - loop { - let new_id = current_id - .id_index() - .checked_add(n) - .and_then(T::try_from_id_index) - .expect("not enough IDs remaining"); - match self - .counter - .compare_exchange_weak(current_id, new_id, Relaxed, Relaxed) - { - Ok(_) => return current_id, - Err(id) => current_id = id, - } - } + self.counter + .fetch_update(Relaxed, Relaxed, |current_id| { + current_id + .id_index() + .checked_add(n) + .and_then(T::try_from_id_index) + }) + .expect("not enough IDs remaining") } pub fn alloc(&self) -> T { self.alloc_block(1) From cb9e0ec4968c0056dd7d2d35948cf65b14688ada Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Tue, 15 Oct 2024 10:02:10 +0100 Subject: [PATCH 3/7] add bytemuck::NoUninit trait to all Id types --- Cargo.lock | 22 ++++++++++++++++++---- derive/src/id.rs | 5 +++++ ids/Cargo.toml | 1 + ids/src/id.rs | 9 ++++++++- ids/src/id/id_types.rs | 28 ++++++++++++++-------------- ids/src/id/u8_range_types.rs | 8 +++++--- ids/src/lib.rs | 4 ++++ ids/tests/test_id.rs | 2 +- lit/Cargo.toml | 3 +-- lit/src/lit.rs | 2 -- lit/src/var.rs | 2 -- 11 files changed, 57 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 75eb901..98b0d04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,9 +119,23 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "byteorder" @@ -455,7 +469,8 @@ dependencies = [ name = "imctk-ids" version = "0.1.0" dependencies = [ - "hashbrown 0.14.5", + "bytemuck", + "hashbrown", "imctk-derive", "imctk-transparent", "rand", @@ -492,7 +507,6 @@ dependencies = [ name = "imctk-lit" version = "0.1.0" dependencies = [ - "bytemuck", "flussab-aiger", "hashbrown 0.14.5", "imctk-derive", diff --git a/derive/src/id.rs b/derive/src/id.rs index fd76a6c..944332e 100644 --- a/derive/src/id.rs +++ b/derive/src/id.rs @@ -28,6 +28,7 @@ pub fn derive_id(input: DeriveInput, internal_generic_id: bool) -> syn::Result syn::Result { /// /// Users of this trait may depend on implementing types following these requirements for upholding /// their own safety invariants. -pub unsafe trait Id: Copy + Ord + Hash + Send + Sync + Debug { +pub unsafe trait Id: Copy + Ord + Hash + Send + Sync + Debug + NoUninit { /// An [`Id`] type that has the same representation and index range as this type. /// /// This is provided to enable writing generic code that during monomorphization is only @@ -161,6 +162,12 @@ pub unsafe trait Id: Copy + Ord + Hash + Send + Sync + Debug { #[repr(transparent)] pub struct GenericId(Repr); +// SAFETY: #[repr(transparent)] and the only field is explicitly required to be NoUninit +unsafe impl NoUninit for GenericId where + Repr: Id + NoUninit +{ +} + impl Debug for GenericId { #[inline(always)] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/ids/src/id/id_types.rs b/ids/src/id/id_types.rs index 1240ef6..8dfb1ec 100644 --- a/ids/src/id/id_types.rs +++ b/ids/src/id/id_types.rs @@ -1,12 +1,12 @@ mod id8 { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxHighNibbleU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxHighNibbleU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; /// [`Id`] type representing indices in the range `0..0xf0`. #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(transparent)] pub struct Id8(NonMaxHighNibbleU8); @@ -177,13 +177,13 @@ mod id8 { mod id16 { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; /// [`Id`] type representing indices in the range `0..0xff00`. #[cfg(target_endian = "little")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(2))] pub struct Id16 { lsb: u8, @@ -192,7 +192,7 @@ mod id16 { #[cfg(target_endian = "big")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(2))] pub struct Id16 { msb: NonMaxU8, @@ -366,13 +366,13 @@ mod id16 { mod id32 { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; /// [`Id`] type representing indices in the range `0..0xff00_0000`. #[cfg(target_endian = "little")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(4))] pub struct Id32 { lsbs: [u8; 3], @@ -381,7 +381,7 @@ mod id32 { #[cfg(target_endian = "big")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(4))] pub struct Id32 { msb: NonMaxU8, @@ -562,13 +562,13 @@ mod id32 { mod id64 { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; /// [`Id`] type representing indices in the range `0..0xff00_0000_0000_0000`. #[cfg(target_endian = "little")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(8))] pub struct Id64 { lsbs: [u8; 7], @@ -577,7 +577,7 @@ mod id64 { #[cfg(target_endian = "big")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[repr(C, align(8))] pub struct Id64 { msb: NonMaxU8, @@ -758,7 +758,7 @@ mod id64 { mod id_size { use imctk_transparent::SubtypeCast; - use crate::id::{u8_range_types::NonMaxMsbU8, ConstIdFromIdIndex, GenericId, Id}; + use crate::id::{u8_range_types::NonMaxMsbU8, ConstIdFromIdIndex, GenericId, Id, NoUninit}; use core::{fmt, fmt::Debug, hash::Hash}; const LSBS: usize = (usize::BITS as usize / 8) - 1; @@ -766,7 +766,7 @@ mod id_size { /// [`Id`] type representing indices in the range `0..=isize::MAX as usize`. #[cfg(target_endian = "little")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[cfg_attr(target_pointer_width = "16", repr(C, align(2)))] #[cfg_attr(target_pointer_width = "32", repr(C, align(4)))] #[cfg_attr(target_pointer_width = "64", repr(C, align(8)))] @@ -777,7 +777,7 @@ mod id_size { #[cfg(target_endian = "big")] #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts - #[derive(Clone, Copy)] + #[derive(Clone, Copy, NoUninit)] #[cfg_attr(target_pointer_width = "16", repr(C, align(2)))] #[cfg_attr(target_pointer_width = "32", repr(C, align(4)))] #[cfg_attr(target_pointer_width = "64", repr(C, align(8)))] diff --git a/ids/src/id/u8_range_types.rs b/ids/src/id/u8_range_types.rs index e5ba744..69a607f 100644 --- a/ids/src/id/u8_range_types.rs +++ b/ids/src/id/u8_range_types.rs @@ -1,5 +1,7 @@ +use crate::NoUninit; + #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts -#[derive(Clone, Copy)] +#[derive(Clone, Copy, NoUninit)] #[repr(u8)] pub enum NonMaxU8 { Val00 = 0x00, @@ -260,7 +262,7 @@ pub enum NonMaxU8 { } #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts -#[derive(Clone, Copy)] +#[derive(Clone, Copy, NoUninit)] #[repr(u8)] pub enum NonMaxHighNibbleU8 { Val00 = 0x00, @@ -506,7 +508,7 @@ pub enum NonMaxHighNibbleU8 { } #[allow(dead_code)] // Only constructed via transmutation and/or pointer casts -#[derive(Clone, Copy)] +#[derive(Clone, Copy, NoUninit)] #[repr(u8)] pub enum NonMaxMsbU8 { Val00 = 0x00, diff --git a/ids/src/lib.rs b/ids/src/lib.rs index 931307a..a65669b 100644 --- a/ids/src/lib.rs +++ b/ids/src/lib.rs @@ -31,3 +31,7 @@ pub use imctk_derive::Id; pub use id::{ConstIdFromIdIndex, GenericId, Id, Id16, Id32, Id64, Id8, IdSize}; pub use id_range::IdRange; + +// re-export this so that others can use it without depending on bytemuck explicitly +// in particular needed for #[derive(Id)] +pub use bytemuck::NoUninit; \ No newline at end of file diff --git a/ids/tests/test_id.rs b/ids/tests/test_id.rs index 6df140f..0292768 100644 --- a/ids/tests/test_id.rs +++ b/ids/tests/test_id.rs @@ -383,7 +383,7 @@ fn conversion_usize() { #[derive(Id)] #[repr(transparent)] -pub struct NewtypePhantom(usize, PhantomData); +pub struct NewtypePhantom(usize, PhantomData); impl std::fmt::Debug for NewtypePhantom { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/lit/Cargo.toml b/lit/Cargo.toml index 1b29ee4..f9b5cf0 100644 --- a/lit/Cargo.toml +++ b/lit/Cargo.toml @@ -18,5 +18,4 @@ imctk-transparent = { version = "0.1.0", path = "../transparent" } imctk-util = { version = "0.1.0", path = "../util" } log = "0.4.21" table_seq = { version = "0.1.0", path = "../table_seq" } -zwohash = "0.1.2" -bytemuck = "*" \ No newline at end of file +zwohash = "0.1.2" \ No newline at end of file diff --git a/lit/src/lit.rs b/lit/src/lit.rs index 0aa1a31..409869c 100644 --- a/lit/src/lit.rs +++ b/lit/src/lit.rs @@ -24,8 +24,6 @@ use super::{pol::Pol, var::Var}; #[derive(Id, SubtypeCast, NewtypeCast)] pub struct Lit(Id32); -unsafe impl bytemuck::NoUninit for Lit {} - /// Ensure that there is an even number of literals #[allow(clippy::assertions_on_constants)] const _: () = { diff --git a/lit/src/var.rs b/lit/src/var.rs index 473ebc4..ce9fc3a 100644 --- a/lit/src/var.rs +++ b/lit/src/var.rs @@ -9,8 +9,6 @@ use super::{lit::Lit, pol::Pol}; #[derive(Id, SubtypeCast, NewtypeCast)] pub struct Var(GenericId<{ Lit::MAX_ID_INDEX / 2 }, ::BaseId>); -unsafe impl bytemuck::NoUninit for Var {} - impl std::fmt::Debug for Var { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) From 5192375f140b6e1e73ef1218c0e0d91908077606 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Tue, 15 Oct 2024 18:11:24 +0100 Subject: [PATCH 4/7] clean up dsf crate; move IdAlloc to ids crate --- Cargo.lock | 1 - dsf/Cargo.toml | 1 - dsf/src/element.rs | 57 ++++ dsf/src/lib.rs | 13 +- dsf/src/tests/test_tracked_union_find.rs | 32 ++ dsf/src/tests/test_union_find.rs | 176 +++++++++++ dsf/src/tracked_union_find.rs | 256 +++++++++------ dsf/src/union_find.rs | 378 +++++++++-------------- ids/src/id_alloc.rs | 54 ++++ ids/src/lib.rs | 3 + 10 files changed, 654 insertions(+), 317 deletions(-) create mode 100644 dsf/src/element.rs create mode 100644 dsf/src/tests/test_tracked_union_find.rs create mode 100644 dsf/src/tests/test_union_find.rs create mode 100644 ids/src/id_alloc.rs diff --git a/Cargo.lock b/Cargo.lock index 98b0d04..0da409e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -257,7 +257,6 @@ name = "dsf" version = "0.1.0" dependencies = [ "atomic", - "bytemuck", "imctk-ids", "imctk-lit", "priority-queue", diff --git a/dsf/Cargo.toml b/dsf/Cargo.toml index bee1698..f4fd701 100644 --- a/dsf/Cargo.toml +++ b/dsf/Cargo.toml @@ -8,7 +8,6 @@ publish = false [dependencies] imctk-ids = { version = "0.1.0", path = "../ids" } imctk-lit = { version = "0.1.0", path = "../lit" } -bytemuck = "*" atomic = "0.6" priority-queue = "*" diff --git a/dsf/src/element.rs b/dsf/src/element.rs new file mode 100644 index 0000000..7bb6809 --- /dev/null +++ b/dsf/src/element.rs @@ -0,0 +1,57 @@ +//! A trait for "elements" that can be split into an "atom" and a "polarity". + +use imctk_lit::{Lit, Var}; + +/// A trait for "elements" that can be split into an "atom" and a "polarity". +/// +/// This lets code generically manipulate variables and literals, and further serves to abstract over their concrete representation. +/// +/// The two most common case this trait is used for are: +/// 1) The element and the atom are both `Var`. In this case there is only one polarity and the trait implementation is trivial. +/// 2) The element is `Lit` and the atom is `Var`. Here there are two polarities (`+` and `-`) to keep track of. +/// +/// Mathematically, implementing this trait signifies that elements can be written as pairs `(a, p)` with an atom `a` and a polarity `p`. +/// The polarities are assumed to form a group `(P, *, 1)`. The trait operations then correspond to: +/// 1) `from_atom(a) = (a, 1)` +/// 2) `atom((a, p)) = a` +/// 3) `apply_pol_of((a, p), (b, q)) = (a, p * q)` +/// +/// Currently, code assumes that `P` is either trivial or isomorphic to `Z_2`. +/// +/// Code using this trait may assume the following axioms to hold: +/// 1) `from_atom(atom(x)) == x` +/// 2) `apply_pol_of(atom(x), x) == x` +/// 3) `apply_pol_of(apply_pol_of(x, y), y) == x` +// TODO: add missing axioms +pub trait Element { + /// Constructs an element from an atom by applying positive polarity. + fn from_atom(atom: Atom) -> Self; + /// Returns the atom corresponding to an element. + fn atom(self) -> Atom; + /// Multiplies `self` by the polarity of `other`, i.e. conceptually `apply_pol_of(self, other) = self ^ pol(other)`. + fn apply_pol_of(self, other: Self) -> Self; +} + +impl Element for T { + fn from_atom(atom: T) -> Self { + atom + } + fn atom(self) -> T { + self + } + fn apply_pol_of(self, _other: T) -> Self { + self + } +} + +impl Element for Lit { + fn from_atom(atom: Var) -> Self { + atom.as_lit() + } + fn atom(self) -> Var { + self.var() + } + fn apply_pol_of(self, other: Self) -> Self { + self ^ other.pol() + } +} \ No newline at end of file diff --git a/dsf/src/lib.rs b/dsf/src/lib.rs index 0505e00..f4f90ca 100644 --- a/dsf/src/lib.rs +++ b/dsf/src/lib.rs @@ -1,2 +1,13 @@ +//! This crate defines a structure [`UnionFind`] that allows tracking of equivalences between generic elements +//! and a structure [`TrackedUnionFind`] that provides the same functionality but augmented by change tracking. + +#[doc(inline)] +pub use element::Element; +#[doc(inline)] +pub use tracked_union_find::TrackedUnionFind; +#[doc(inline)] +pub use union_find::UnionFind; + +pub mod element; +pub mod tracked_union_find; pub mod union_find; -pub mod tracked_union_find; \ No newline at end of file diff --git a/dsf/src/tests/test_tracked_union_find.rs b/dsf/src/tests/test_tracked_union_find.rs new file mode 100644 index 0000000..24f8d5c --- /dev/null +++ b/dsf/src/tests/test_tracked_union_find.rs @@ -0,0 +1,32 @@ +use super::*; +use imctk_lit::{Lit, Var}; + +#[test] +fn test() { + let l = |n| Var::from_index(n).as_lit(); + let mut tuf = TrackedUnionFind::::new(); + let mut token = tuf.start_observing(); + tuf.union([l(3), !l(4)]); + tuf.union([l(8), l(7)]); + let mut token2 = tuf.start_observing(); + tuf.union([l(4), l(5)]); + for change in tuf.drain_changes(&mut token).cloned() { + println!("{change:?}"); + } + println!("---"); + tuf.union([!l(5), l(6)]); + tuf.make_repr(l(4).var()); + let renumber: IdVec> = + IdVec::from_vec(vec![Some(l(0)), None, None, Some(l(1)), Some(!l(1)), Some(!l(1)), Some(l(1)), Some(l(2)), Some(l(2))]); + let reverse = Renumbering::get_reverse(&renumber, &tuf.union_find); + dbg!(&renumber, &reverse); + tuf.renumber(renumber, reverse); + tuf.union([l(0), l(1)]); + let mut iter = tuf.drain_changes(&mut token); + println!("{:?}", iter.next()); + iter.stop(); + println!("---"); + for change in tuf.drain_changes(&mut token2).cloned() { + println!("{change:?}"); + } +} diff --git a/dsf/src/tests/test_union_find.rs b/dsf/src/tests/test_union_find.rs new file mode 100644 index 0000000..142a245 --- /dev/null +++ b/dsf/src/tests/test_union_find.rs @@ -0,0 +1,176 @@ +#![allow(dead_code, missing_docs)] + +use super::*; +use imctk_lit::{Var, Lit}; +use imctk_ids::id_set_seq::IdSetSeq; +use rand::prelude::*; +use std::collections::{HashSet, VecDeque}; + +#[derive(Default)] +struct CheckedUnionFind { + dut: UnionFind, + equivs: IdSetSeq, +} + +impl> UnionFind { + fn debug_print_tree( + children: &IdVec>, + atom: Atom, + prefix: &str, + self_char: &str, + further_char: &str, + pol: bool, + ) { + println!( + "{prefix}{self_char}{}{:?}", + if pol { "!" } else { "" }, + atom + ); + let my_children = children.get(atom).unwrap(); + for (index, &child) in my_children.iter().enumerate() { + let last = index == my_children.len() - 1; + let self_char = if last { "└" } else { "├" }; + let next_further_char = if last { " " } else { "│" }; + Self::debug_print_tree( + children, + child.atom(), + &(prefix.to_string() + further_char), + self_char, + next_further_char, + pol ^ (child != Elem::from_atom(child.atom())), + ); + } + } + fn debug_print(&self) { + let mut children: IdVec> = Default::default(); + for atom in self.parent.keys() { + let parent = self.read_parent(atom); + children.grow_for_key(atom); + if atom != parent.atom() { + children + .grow_for_key(parent.atom()) + .push(Elem::from_atom(atom).apply_pol_of(parent)); + } else { + assert!(Elem::from_atom(atom) == parent); + } + } + for atom in self.parent.keys() { + if atom == self.read_parent(atom).atom() { + Self::debug_print_tree(&children, atom, "", "", " ", false); + } + } + } +} +#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] +enum VarRel { + Equiv, + AntiEquiv, + NotEquiv, +} + +impl> CheckedUnionFind { + fn new() -> Self { + CheckedUnionFind { + dut: Default::default(), + equivs: Default::default(), + } + } + fn ref_equal(&mut self, start: Elem, goal: Elem) -> VarRel { + let mut seen: HashSet = Default::default(); + let mut queue: VecDeque = [start].into(); + while let Some(place) = queue.pop_front() { + if place.atom() == goal.atom() { + if place == goal { + return VarRel::Equiv; + } else { + return VarRel::AntiEquiv; + } + } + seen.insert(place.atom()); + for &next in self.equivs.grow_for(place.atom()).iter() { + if !seen.contains(&next.atom()) { + queue.push_back(next.apply_pol_of(place)); + } + } + } + VarRel::NotEquiv + } + fn find(&mut self, lit: Elem) -> Elem { + let out = self.dut.find(lit); + assert!(self.ref_equal(lit, out) == VarRel::Equiv); + out + } + fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { + let (ok, [ra, rb]) = self.dut.union_full(lits); + assert_eq!(self.ref_equal(lits[0], ra), VarRel::Equiv); + assert_eq!(self.ref_equal(lits[1], rb), VarRel::Equiv); + assert_eq!(ok, self.ref_equal(lits[0], lits[1]) == VarRel::NotEquiv); + assert_eq!(self.dut.find_root(lits[0]), ra); + if ok { + assert_eq!(self.dut.find_root(lits[1]), ra); + self.equivs + .grow_for(lits[0].atom()) + .insert(lits[1].apply_pol_of(lits[0])); + self.equivs + .grow_for(lits[1].atom()) + .insert(lits[0].apply_pol_of(lits[1])); + } else { + assert_eq!(self.dut.find_root(lits[1]).atom(), ra.atom()); + } + (ok, [ra, rb]) + } + fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + fn make_repr(&mut self, lit: Atom) { + self.dut.make_repr(lit); + assert_eq!( + self.dut.find_root(Elem::from_atom(lit)), + Elem::from_atom(lit) + ); + self.check(); + } + fn check(&mut self) { + for atom in self.dut.parent.keys() { + let parent = self.dut.read_parent(atom); + assert_eq!(self.ref_equal(Elem::from_atom(atom), parent), VarRel::Equiv); + let root = self.dut.find_root(Elem::from_atom(atom)); + for &child in self.equivs.grow_for(atom).iter() { + assert_eq!(root, self.dut.find_root(child)); + } + } + } +} + +#[test] +fn test() { + let mut u: CheckedUnionFind = CheckedUnionFind::new(); + let mut rng = rand_pcg::Pcg64::seed_from_u64(25); + let max_var = 2000; + for i in 0..2000 { + match rng.gen_range(0..10) { + 0..=4 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let b = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.union_full([a, b]); + println!("union({a}, {b}) = {result:?}"); + } + 5..=7 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.find(a); + println!("find({a}) = {result}"); + } + 8 => { + u.check(); + } + 9 => { + let a = Var::from_index(rng.gen_range(0..=max_var)); + u.make_repr(a); + println!("make_repr({a})"); + } + _ => {} + } + } + u.check(); + //u.dut.debug_print(); +} diff --git a/dsf/src/tracked_union_find.rs b/dsf/src/tracked_union_find.rs index 81f47d0..64bbfe5 100644 --- a/dsf/src/tracked_union_find.rs +++ b/dsf/src/tracked_union_find.rs @@ -1,28 +1,29 @@ -#![allow(missing_docs)] -#![allow(clippy::type_complexity)] -use std::{cmp::Reverse, collections::VecDeque, mem::ManuallyDrop, sync::Arc}; +//! A `TrackedUnionFind` augments a [`UnionFind`] structure with change tracking. +use std::{cmp::Reverse, collections::VecDeque, iter::FusedIterator, mem::ManuallyDrop, sync::Arc}; -use atomic::Atomic; -use bytemuck::NoUninit; -use imctk_ids::{id_vec::IdVec, Id, Id64}; +use imctk_ids::{id_vec::IdVec, Id, Id64, IdAlloc}; use priority_queue::PriorityQueue; -use crate::union_find::{Element, UnionFind}; +use crate::{Element, UnionFind}; +#[cfg(test)] +#[path = "tests/test_tracked_union_find.rs"] +mod test_tracked_union_find; + +/// Globally unique ID for a `TrackedUnionFind`. #[derive(Id, Debug)] #[repr(transparent)] pub struct TrackedUnionFindId(Id64); -// SAFETY: trust me bro -unsafe impl NoUninit for TrackedUnionFindId {} +/// Generation number for a `TrackedUnionFind` #[derive(Id, Debug)] #[repr(transparent)] pub struct Generation(u64); +/// Observer ID for a `TrackedUnionFind` (not globally unique) #[derive(Id, Debug)] #[repr(transparent)] pub struct ObserverId(Id64); -// SAFETY: trust me bro -unsafe impl NoUninit for ObserverId {} +/// An `ObserverToken` represents an observer of a `TrackedUnionFind`. #[derive(Debug)] pub struct ObserverToken { tuf_id: TrackedUnionFindId, @@ -30,6 +31,44 @@ pub struct ObserverToken { observer_id: ObserverId, } +impl ObserverToken { + /// Returns the ID of the associated `TrackedUnionFind`. + pub fn tuf_id(&self) -> TrackedUnionFindId { + self.tuf_id + } + /// Returns the ID of the generation that this observer is on. + pub fn generation(&self) -> Generation { + self.generation + } + /// Returns the ID of the observer. + /// + /// Note that observer IDs are local to each `TrackedUnionFind`. + pub fn observer_id(&self) -> ObserverId { + self.observer_id + } + /// Returns `true` iff `self` and `other` belong to the same `TrackedUnionFind`. + /// + /// NB: This does **not** imply that variable IDs are compatible, for that you want `is_compatible`. + pub fn is_same_tuf(&self, other: &ObserverToken) -> bool { + self.tuf_id == other.tuf_id + } + /// Returns `true` iff `self` and `other` have compatible variable IDs. + /// + /// This is equivalent to whether they are on the same generation of the same `TrackedUnionFind`. + pub fn is_compatible(&self, other: &ObserverToken) -> bool { + self.tuf_id == other.tuf_id && self.generation == other.generation + } +} + +/// `Renumbering` represents a renumbering of all variables in `TrackedUnionFind`. +/// +/// A renumbering stores a forward and a reverse mapping, and the old and new generation IDs. +/// +/// The forward mapping maps each old variable to an optional new variable (variables may be deleted). +/// It is a requirement that equivalent old variables are either both mapped to the same new variable or both deleted. +/// +/// The reverse mapping maps each new variable to its old representative. +/// The new set of variables is required to be contiguous, hence `reverse` is a total mapping. pub struct Renumbering { forward: IdVec>, reverse: IdVec, @@ -48,8 +87,12 @@ impl std::fmt::Debug for Renumbering { } } -impl + NoUninit> Renumbering { - pub fn get_reverse(forward: &IdVec>, union_find: &UnionFind) -> IdVec { +impl> Renumbering { + /// Returns the inverse of a reassignment of variables. + pub fn get_reverse( + forward: &IdVec>, + union_find: &UnionFind, + ) -> IdVec { let mut reverse: IdVec> = IdVec::default(); for (old, &new_opt) in forward { if let Some(new) = new_opt { @@ -60,6 +103,7 @@ impl + NoUninit> Renumbering { } IdVec::from_vec(reverse.iter().map(|x| x.1.unwrap()).collect()) } + /// Returns `true` iff the arguments are inverses of each other. pub fn is_inverse(forward: &IdVec>, reverse: &IdVec) -> bool { reverse.iter().all(|(new, &old)| { if let Some(&Some(e)) = forward.get(old.atom()) { @@ -69,6 +113,7 @@ impl + NoUninit> Renumbering { } }) } + /// Creates a renumbering without checking whether the arguments are valid. pub fn new_unchecked( forward: IdVec>, reverse: IdVec, @@ -82,6 +127,7 @@ impl + NoUninit> Renumbering { new_generation, } } + /// Creates a new renumbering from the given forward and reverse assignment. pub fn new( forward: IdVec>, reverse: IdVec, @@ -92,6 +138,7 @@ impl + NoUninit> Renumbering { debug_assert!(Self::is_inverse(&forward, &reverse)); Self::new_unchecked(forward, reverse, old_generation, new_generation) } + /// Returns the new variable corresponding to the given old variable, if it exists. pub fn old_to_new(&self, old: Elem) -> Option { self.forward .get(old.atom()) @@ -99,9 +146,11 @@ impl + NoUninit> Renumbering { .flatten() .map(|e| e.apply_pol_of(old)) } + /// Returns the old variable corresponding to the given new variable, if it exists. pub fn new_to_old(&self, new: Elem) -> Option { self.reverse.get(new.atom()).map(|&e| e.apply_pol_of(new)) } + /// Returns `true` iff the given renumbering satisfies the constraint that equivalent variables are mapped identically. pub fn is_repr_reduction(&self, union_find: &UnionFind) -> bool { union_find.lowest_unused_atom() <= self.forward.next_unused_key() && self.forward.iter().all(|(old, &new)| { @@ -112,10 +161,15 @@ impl + NoUninit> Renumbering { } } +/// `Change` represents a single change of a `TrackedUnionFind` +#[allow(missing_docs)] // dont want to document every subfield #[derive(Clone)] pub enum Change { + /// A `union` operation. The set with representative `merged_repr` is merged into the set with representative `new_repr`. Union { new_repr: Atom, merged_repr: Elem }, + /// A `make_repr` operation. `new_repr` is promoted to be the representative of its set, replacing `old_repr`. MakeRepr { new_repr: Atom, old_repr: Elem }, + /// A renumbering operation. Renumber(Arc>), } @@ -140,6 +194,7 @@ impl std::fmt::Debug for Change { } } +/// A `TrackedUnionFind` augments a [`UnionFind`] structure with change tracking. pub struct TrackedUnionFind { tuf_id: TrackedUnionFindId, union_find: UnionFind, @@ -150,46 +205,13 @@ pub struct TrackedUnionFind { generation: Generation, } -pub struct IdAlloc { - counter: Atomic, -} - -impl Default for IdAlloc { - fn default() -> Self { - Self::new() - } -} - -impl IdAlloc { - const fn new() -> Self { - Self { - counter: Atomic::new(T::MIN_ID), - } - } - pub fn alloc_block(&self, n: usize) -> T { - use atomic::Ordering::Relaxed; - debug_assert!(n > 0); - self.counter - .fetch_update(Relaxed, Relaxed, |current_id| { - current_id - .id_index() - .checked_add(n) - .and_then(T::try_from_id_index) - }) - .expect("not enough IDs remaining") - } - pub fn alloc(&self) -> T { - self.alloc_block(1) - } -} - static TUF_ID_ALLOC: IdAlloc = IdAlloc::new(); -impl Default for TrackedUnionFind { - fn default() -> Self { +impl From> for TrackedUnionFind { + fn from(union_find: UnionFind) -> Self { Self { - tuf_id: TUF_ID_ALLOC.alloc(), - union_find: Default::default(), + union_find, + tuf_id: TUF_ID_ALLOC.alloc().unwrap(), log: Default::default(), observer_id_alloc: Default::default(), observers: Default::default(), @@ -199,12 +221,31 @@ impl Default for TrackedUnionFind { } } -impl + NoUninit> TrackedUnionFind { +impl Default for TrackedUnionFind { + fn default() -> Self { + UnionFind::default().into() + } +} + +impl TrackedUnionFind { + /// Returns a shared reference to the contained `UnionFind`. + pub fn get_union_find(&self) -> &UnionFind { + &self.union_find + } + /// Returns the contained `UnionFind`. All change tracking data is lost. + pub fn into_union_find(self) -> UnionFind { + self.union_find + } +} + +impl> TrackedUnionFind { + /// Returns an element's representative. See [`UnionFind::find`]. pub fn find(&self, elem: Elem) -> Elem { self.union_find.find(elem) } - pub fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { - let (ok, roots) = self.union_find.union_full(lits); + /// Declares two elements to be equivalent. See [`UnionFind::union_full`]. + pub fn union_full(&mut self, elems: [Elem; 2]) -> (bool, [Elem; 2]) { + let (ok, roots) = self.union_find.union_full(elems); if ok && !self.observers.is_empty() { let new_repr = roots[0].atom(); let merged_repr = roots[1].apply_pol_of(roots[0]); @@ -215,9 +256,11 @@ impl + NoUninit> TrackedUnionFind } (ok, roots) } + /// Declares two elements to be equivalent. See [`UnionFind::union`]. pub fn union(&mut self, lits: [Elem; 2]) -> bool { self.union_full(lits).0 } + /// Promotes an atom to be a representative. See [`UnionFind::make_repr`]. pub fn make_repr(&mut self, new_repr: Atom) -> Elem { let old_repr = self.union_find.make_repr(new_repr); if old_repr.atom() != new_repr && !self.observers.is_empty() { @@ -225,6 +268,12 @@ impl + NoUninit> TrackedUnionFind } old_repr } + /// Renumbers all the variables in the `UnionFind`. + /// The provided mapping **must** meet all the preconditions listed for [`Renumbering`]. + /// + /// This resets the `UnionFind` to the trivial state (`find(a) == a` for all `a`) and increments the generation ID. + /// + /// This method will panic in debug mode if said preconditions are not met. pub fn renumber(&mut self, forward: IdVec>, reverse: IdVec) { let old_generation = self.generation; let new_generation = Generation(old_generation.0 + 1); @@ -239,14 +288,23 @@ impl + NoUninit> TrackedUnionFind } impl TrackedUnionFind { + /// Constructs a new, empty `TrackedUnionFind`. pub fn new() -> Self { Self::default() } fn log_end(&self) -> u64 { self.log_start + self.log.len() as u64 } + /// Creates a new `ObserverToken` that can be used to track all changes since the call to this method. + /// + /// Conceptually, each observer has its own private log. + /// Any changes that happen to the `UnionFind` will be recorded into the logs of all currently active observers. + /// (In the actual implementation, only a single log is kept). + /// + /// After use, the `ObserverToken` must be disposed of with a call to `stop_observing`, otherwise + /// the memory corresponding to old log entries cannot be reclaimed until the `TrackedUnionFind` is dropped. pub fn start_observing(&mut self) -> ObserverToken { - let observer_id = self.observer_id_alloc.alloc(); + let observer_id = self.observer_id_alloc.alloc().unwrap(); self.observers.push(observer_id, Reverse(self.log_end())); ObserverToken { tuf_id: self.tuf_id, @@ -254,9 +312,10 @@ impl TrackedUnionFind { observer_id, } } + /// Clones an `ObserverToken`, conceptually cloning the token's private log. pub fn clone_token(&mut self, token: &ObserverToken) -> ObserverToken { assert!(token.tuf_id == self.tuf_id); - let new_observer_id = self.observer_id_alloc.alloc(); + let new_observer_id = self.observer_id_alloc.alloc().unwrap(); let pos = *self.observers.get_priority(&token.observer_id).unwrap(); self.observers.push(new_observer_id, pos); ObserverToken { @@ -265,6 +324,9 @@ impl TrackedUnionFind { observer_id: new_observer_id, } } + /// Deletes an `ObserverToken` and its associated state. + /// + /// You must call this to allow the `TrackedUnionFind` to allow memory to be reclaimed. pub fn stop_observing(&mut self, token: ObserverToken) { assert!(token.tuf_id == self.tuf_id); self.observers.remove(&token.observer_id); @@ -275,13 +337,11 @@ impl TrackedUnionFind { if new_start > self.log_start { let delete = (new_start - self.log_start).try_into().unwrap(); drop(self.log.drain(0..delete)); - println!("dropped {delete} entries"); self.log_start = new_start; } } else { self.log_start = self.log_end(); self.log.clear(); - println!("dropped all entries"); } } fn observer_rel_pos(&self, token: &ObserverToken) -> usize { @@ -307,6 +367,7 @@ impl TrackedUnionFind { .change_priority(&token.observer_id, Reverse(abs_pos)); self.truncate_log(); } + #[allow(clippy::type_complexity)] fn change_slices( &self, token: &ObserverToken, @@ -331,6 +392,17 @@ impl TrackedUnionFind { } } } + /// Calls the provided function `f` with the content of the token's private log and clears the log. + /// + /// Because the log is not necessarily contiguous in memory, `f` may be called multiple times. + /// + /// The slice argument to `f` is guaranteed to be non-empty. + /// To allow looking up representatives `f` is also provided with a shared reference to the `UnionFind`. + /// + /// The method assumes that you will immediately process any `Renumbering` operations in the log + /// and will update the token's generation field. + /// + /// Returns `true` iff `f` has been called at least once. pub fn drain_changes_with_fn( &mut self, token: &mut ObserverToken, @@ -351,6 +423,11 @@ impl TrackedUnionFind { false } } + /// Returns a draining iterator that returns and deletes entries from the token's private log. + /// + /// Dropping this iterator will clear any unread entries, call `stop` if this is undesirable. + /// + /// You must not leak the returned iterator. Otherwise log entries may be observed multiple times and appear duplicated. pub fn drain_changes<'a>( &mut self, token: &'a mut ObserverToken, @@ -365,21 +442,37 @@ impl TrackedUnionFind { } } +/// A draining iterator. +/// +/// Since this is a lending iterator, it does not implement the standard `Iterator` trait, +/// but its `map` and `cloned` methods will create a standard iterator. pub struct DrainChanges<'a, 'b, Atom, Elem> { tuf: &'a mut TrackedUnionFind, token: &'b mut ObserverToken, rel_pos: usize, } +/// A draining iterator that has been mapped. pub struct DrainChangesMap<'a, 'b, Atom, Elem, F> { inner: DrainChanges<'a, 'b, Atom, Elem>, f: F, } impl<'a, 'b, Atom, Elem> DrainChanges<'a, 'b, Atom, Elem> { + /// Returns a reference to the first entry in the token's private log, without deleting it. + /// + /// Returns `None` if the log is empty. pub fn peek(&mut self) -> Option<&Change> { self.tuf.log.get(self.rel_pos) } + /// Returns a reference to the first entry in the token's private log. The entry will be deleted after its use. + /// + /// If this returns a `Renumbering`, it is assumed that you will process it and the token's generation number will be updated. + /// + /// Returns `None` if the log is empty. If `next` returned `None`, it will never return any more entries (the iterator is fused). + /// + /// As reflected by the lifetimes, the API only guarantees that the returned reference until the next call of any method of this iterator. + /// (In practice, deletion is more lazy and happens on drop). #[allow(clippy::should_implement_trait)] pub fn next(&mut self) -> Option<&Change> { let ret = self.tuf.log.get(self.rel_pos); @@ -390,14 +483,20 @@ impl<'a, 'b, Atom, Elem> DrainChanges<'a, 'b, Atom, Elem> { } ret } + /// Drops the iterator but without deleting unread entries. pub fn stop(self) { self.tuf.observer_set_rel_pos(self.token, self.rel_pos); let _ = ManuallyDrop::new(self); } + /// Returns `(n, Some(n))` where `n` is the number of unread entries. + /// + /// This method is designed to be compatible with the standard iterator method of the same name. pub fn size_hint(&self) -> (usize, Option) { let count = self.tuf.log.len() - self.rel_pos; (count, Some(count)) } + /// Creates a new iterator by lazily calling `f` on every change. + #[must_use] pub fn map(self, f: F) -> DrainChangesMap<'a, 'b, Atom, Elem, F> where F: FnMut(&Change) -> B, @@ -407,6 +506,9 @@ impl<'a, 'b, Atom, Elem> DrainChanges<'a, 'b, Atom, Elem> { } impl<'a, 'b, Atom: Clone, Elem: Clone> DrainChanges<'a, 'b, Atom, Elem> { + /// Create a standard iterator by cloning every entry. + #[must_use] + #[allow(clippy::type_complexity)] pub fn cloned( self, ) -> DrainChangesMap<'a, 'b, Atom, Elem, fn(&Change) -> Change> { @@ -416,6 +518,9 @@ impl<'a, 'b, Atom: Clone, Elem: Clone> DrainChanges<'a, 'b, Atom, Elem> { impl Drop for DrainChanges<'_, '_, Atom, Elem> { fn drop(&mut self) { + // mark any renumberings as seen + while self.next().is_some() { + } self.tuf .observer_set_rel_pos(self.token, self.tuf.log.len()); } @@ -435,33 +540,12 @@ where } } -#[test] -fn test() { - use imctk_lit::{Lit, Var}; - let l = |n| Var::from_index(n).as_lit(); - let mut tuf = TrackedUnionFind::::new(); - let mut token = tuf.start_observing(); - tuf.union([l(3), !l(4)]); - tuf.union([l(8), l(7)]); - let mut token2 = tuf.start_observing(); - tuf.union([l(4), l(5)]); - for change in tuf.drain_changes(&mut token).cloned() { - println!("{change:?}"); - } - println!("---"); - tuf.union([!l(5), l(6)]); - tuf.make_repr(l(4).var()); - let renumber: IdVec> = - IdVec::from_vec(vec![Some(l(0)), None, None, Some(l(1)), Some(!l(1)), Some(!l(1)), Some(l(1)), Some(l(2)), Some(l(2))]); - let reverse = Renumbering::get_reverse(&renumber, &tuf.union_find); - dbg!(&renumber, &reverse); - tuf.renumber(renumber, reverse); - tuf.union([l(0), l(1)]); - let mut iter = tuf.drain_changes(&mut token); - println!("{:?}", iter.next()); - iter.stop(); - println!("---"); - for change in tuf.drain_changes(&mut token2).cloned() { - println!("{change:?}"); - } +impl ExactSizeIterator for DrainChangesMap<'_, '_, Atom, Elem, F> where + F: FnMut(&Change) -> B +{ +} + +impl FusedIterator for DrainChangesMap<'_, '_, Atom, Elem, F> where + F: FnMut(&Change) -> B +{ } diff --git a/dsf/src/union_find.rs b/dsf/src/union_find.rs index 5036432..ebc54ee 100644 --- a/dsf/src/union_find.rs +++ b/dsf/src/union_find.rs @@ -1,41 +1,53 @@ -#![allow(missing_docs)] -use std::sync::atomic::Ordering; - +//! `UnionFind` efficiently tracks equivalences between variables. +use crate::Element; use atomic::Atomic; -use bytemuck::NoUninit; -use imctk_ids::{id_vec::IdVec, Id}; -use imctk_lit::{Lit, Var}; - -pub trait Element { - fn from_atom(atom: Atom) -> Self; - fn atom(self) -> Atom; - fn apply_pol_of(self, other: Self) -> Self; -} - -impl Element for T { - fn from_atom(atom: T) -> Self { - atom - } - fn atom(self) -> T { - self - } - fn apply_pol_of(self, _other: T) -> Self { - self - } -} +use imctk_ids::{id_vec::IdVec, Id, IdRange}; +use std::sync::atomic::Ordering; -impl Element for Lit { - fn from_atom(atom: Var) -> Self { - atom.as_lit() - } - fn atom(self) -> Var { - self.var() - } - fn apply_pol_of(self, other: Self) -> Self { - self ^ other.pol() - } -} +#[cfg(test)] +#[path = "tests/test_union_find.rs"] +mod test_union_find; +/// `UnionFind` efficiently tracks equivalences between variables. +/// +/// Given "elements" of type `Elem`, this structure keeps track of any known equivalences between these elements. +/// Equivalences are assumed to be *transitive*, i.e. if `x = y` and `y = z`, then `x = z` is assumed to be true, and in fact, +/// automatically discovered by this structure. +/// +/// Unlike a standard union find data structure, this version also keeps track of the *polarity* of elements. +/// For example, you might always have pairs of elements `+x` and `-x` that are exact opposites of each other. +/// If equivalences between `+x` and `-y` are then discovered, the structure also understands that `-x` and `+y` are similarly equivalent. +/// +/// An element without its polarity is called an *atom*. +/// The [`Element`](Element) trait is required on `Elem` so that this structure can relate elements, atoms and polarities. +/// +/// For each set of elements that are equivalent up to a polarity change, this structure keeps track of a *representative*. +/// Each element starts out as its own representative, and if two elements are declared equivalent, the representative of one becomes the representative of both. +/// The method `find` returns the representative for any element. +/// +/// To declare two elements as equivalent, use the `union` or `union_full` methods, see their documentation for details on their use. +/// +/// NB: Since this structure stores atoms in an `IdVec`, the atoms used should ideally be a contiguous set starting at `Atom::MIN_ID`. +/// +/// ## Example ## +/// ``` +/// use imctk_lit::{Var, Lit}; +/// use dsf::UnionFind; +/// +/// let mut union_find: UnionFind = UnionFind::new(); +/// let lit = |n| Var::from_index(n).as_lit(); +/// +/// assert_eq!(union_find.find(lit(4)), lit(4)); +/// +/// union_find.union([lit(3), lit(4)]); +/// assert_eq!(union_find.find(lit(4)), lit(3)); +/// +/// union_find.union([lit(1), !lit(2)]); +/// union_find.union([lit(2), lit(3)]); +/// assert_eq!(union_find.find(lit(1)), lit(1)); +/// assert_eq!(union_find.find(lit(4)), !lit(1)); +/// +/// ``` pub struct UnionFind { parent: IdVec>, } @@ -48,7 +60,7 @@ impl Default for UnionFind { } } -impl Clone for UnionFind { +impl Clone for UnionFind { fn clone(&self) -> Self { let new_parent = self .parent @@ -62,36 +74,63 @@ impl Clone for UnionFind { } } -impl + NoUninit> UnionFind { +impl UnionFind { + /// Constructs an empty `UnionFind`. + /// + /// The returned struct obeys `find(a) == a` for all `a`. pub fn new() -> Self { UnionFind::default() } +} + +impl> UnionFind { + /// Constructs an empty `UnionFind`, with room for `capacity` elements. + pub fn with_capacity(capacity: usize) -> Self { + UnionFind { + parent: IdVec::from_vec(Vec::with_capacity(capacity)), + } + } + /// Clears all equivalences, but retains any allocated memory. + pub fn clear(&mut self) { + self.parent.clear(); + } fn read_parent(&self, atom: Atom) -> Elem { - self.parent - .get(atom) - .map(|p| p.load(Ordering::Relaxed)) - .unwrap_or(Elem::from_atom(atom)) + if let Some(parent_cell) = self.parent.get(atom) { + // This load is allowed to reorder with stores from `update_parent`, see there for details. + parent_cell.load(Ordering::Relaxed) + } else { + Elem::from_atom(atom) + } } + // Important: Only semantically trivial changes are allowed using this method!! + // Specifically, update_parent(atom, parent) should only be called if `parent` is already an ancestor of `atom` + // Otherwise, concurrent calls to `read_parent` (which are explicitly allowed!) could return incorrect results. fn update_parent(&self, atom: Atom, parent: Elem) { - let Some(parent_ref) = self.parent.get(atom) else { + if let Some(parent_cell) = self.parent.get(atom) { + parent_cell.store(parent, Ordering::Relaxed); + } else { + // can only get here if the precondition or a data structure invariant is violated panic!("shouldn't happen: update_parent called with out of bounds argument"); - }; - parent_ref.store(parent, Ordering::Relaxed); + } } + // Unlike `update_parent`, this is safe for arbitrary updates, since it requires &mut self. fn write_parent(&mut self, atom: Atom, parent: Elem) { - if let Some(parent_ref) = self.parent.get(atom) { - parent_ref.store(parent, Ordering::Relaxed); + if let Some(parent_cell) = self.parent.get(atom) { + parent_cell.store(parent, Ordering::Relaxed); } else { debug_assert!(self.parent.next_unused_key() <= atom); while self.parent.next_unused_key() < atom { - self.parent - .push(Atomic::new(Elem::from_atom(self.parent.next_unused_key()))); + let next_elem = Elem::from_atom(self.parent.next_unused_key()); + self.parent.push(Atomic::new(next_elem)); } self.parent.push(Atomic::new(parent)); } } fn find_root(&self, mut elem: Elem) -> Elem { loop { + // If we interleave with a call to `update_parent`, the parent may change + // under our feet, but it's okay because in that case we get an ancestor instead, + // which just skips some iterations of the loop! let parent = self.read_parent(elem.atom()).apply_pol_of(elem); if elem == parent { return elem; @@ -100,220 +139,103 @@ impl + NoUninit> UnionFind { elem = parent; } } + // Worst-case `find` performance is linear. To keep amortised time complexity logarithmic, + // we memoise the result of `find_root` by calling `update_parent` on every element + // we traversed. fn update_root(&self, mut elem: Elem, root: Elem) { + // Loop invariant: `root` is the representative of `elem`. loop { + // Like in `find_root`, this may interleave with `update_root` calls, and we may skip some steps, + // which is okay because the other thread will do the updates instead. let parent = self.read_parent(elem.atom()).apply_pol_of(elem); if parent == root { break; } + // By the loop invariant, this just sets `elem`'s parent to its representative, + // which satisfies the precondition for `update_parent`. Further if two threads end up + // here simultaneously, they will both set to the same representative, + // therefore the change is idempotent. self.update_parent(elem.atom(), root.apply_pol_of(elem)); elem = parent; } } - pub fn find(&self, lit: Elem) -> Elem { - let root = self.find_root(lit); - self.update_root(lit, root); + /// Returns the representative for an element. Elements are equivalent iff they have the same representative. + /// + /// Elements `a` and `b` are equivalent up to a polarity change iff they obey `find(a) = find(b) ^ p` for some polarity `p`. + /// + /// This operation is guaranteed to return `elem` itself for arguments `elem >= lowest_unused_atom()`. + /// + /// The amortised time complexity of this operation is **O**(log N). + pub fn find(&self, elem: Elem) -> Elem { + let root = self.find_root(elem); + self.update_root(elem, root); root } - pub fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { - let [a, b] = lits; + /// Declares two elements to be equivalent. The new representative of both is the representative of the first element. + /// + /// If the elements are already equivalent or cannot be made equivalent (are equivalent up to a sign change), + /// the operation returns `false` without making any changes. Otherwise it returns `true`. + /// + /// In both cases it also returns the original representatives of both arguments. + /// + /// The amortised time complexity of this operation is **O**(log N). + pub fn union_full(&mut self, elems: [Elem; 2]) -> (bool, [Elem; 2]) { + let [a, b] = elems; let ra = self.find(a); let rb = self.find(b); if ra.atom() == rb.atom() { (false, [ra, rb]) } else { + // The first write is only needed to ensure that the parent table actually contains `a` + // and is a no-op otherwise. + self.write_parent(ra.atom(), Elem::from_atom(ra.atom())); self.write_parent(rb.atom(), ra.apply_pol_of(rb)); (true, [ra, rb]) } } - pub fn union(&mut self, lits: [Elem; 2]) -> bool { - self.union_full(lits).0 + /// Declares two elements to be equivalent. The new representative of both is the representative of the first element. + /// + /// If the elements are already equivalent or cannot be made equivalent (are equivalent up to polarity), + /// the operation returns `false` without making any changes. Otherwise it returns `true`. + /// + /// The amortised time complexity of this operation is **O**(log N). + pub fn union(&mut self, elems: [Elem; 2]) -> bool { + self.union_full(elems).0 } + /// Sets `atom` to be its own representative, and updates other representatives to preserve all existing equivalences. + /// + /// The amortised time complexity of this operation is **O**(log N). pub fn make_repr(&mut self, atom: Atom) -> Elem { let root = self.find(Elem::from_atom(atom)); self.write_parent(atom, Elem::from_atom(atom)); self.write_parent(root.atom(), Elem::from_atom(atom).apply_pol_of(root)); root } + /// Returns the lowest `Atom` value for which no equivalences are known. + /// + /// It is guaranteed that `find(a) == a` if `a >= lowest_unused_atom`, but the converse may not hold. pub fn lowest_unused_atom(&self) -> Atom { self.parent.next_unused_key() } -} - -#[cfg(test)] -#[allow(dead_code)] -mod tests { - use super::*; - use imctk_ids::id_set_seq::IdSetSeq; - use rand::prelude::*; - use std::collections::{HashSet, VecDeque}; - - #[derive(Default)] - struct CheckedUnionFind { - dut: UnionFind, - equivs: IdSetSeq, - } - - impl + NoUninit> UnionFind { - fn debug_print_tree( - children: &IdVec>, - atom: Atom, - prefix: &str, - self_char: &str, - further_char: &str, - pol: bool, - ) { - println!( - "{prefix}{self_char}{}{:?}", - if pol { "!" } else { "" }, - atom - ); - let my_children = children.get(atom).unwrap(); - for (index, &child) in my_children.iter().enumerate() { - let last = index == my_children.len() - 1; - let self_char = if last { "└" } else { "├" }; - let next_further_char = if last { " " } else { "│" }; - Self::debug_print_tree( - children, - child.atom(), - &(prefix.to_string() + further_char), - self_char, - next_further_char, - pol ^ (child != Elem::from_atom(child.atom())), - ); - } - } - fn debug_print(&self) { - let mut children: IdVec> = Default::default(); - for atom in self.parent.keys() { - let parent = self.read_parent(atom); - children.grow_for_key(atom); - if atom != parent.atom() { - children - .grow_for_key(parent.atom()) - .push(Elem::from_atom(atom).apply_pol_of(parent)); - } else { - assert!(Elem::from_atom(atom) == parent); - } - } - for atom in self.parent.keys() { - if atom == self.read_parent(atom).atom() { - Self::debug_print_tree(&children, atom, "", "", " ", false); - } - } - } - } - #[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] - enum VarRel { - Equiv, - AntiEquiv, - NotEquiv, - } - - impl + NoUninit> CheckedUnionFind { - fn new() -> Self { - CheckedUnionFind { - dut: Default::default(), - equivs: Default::default(), - } - } - fn ref_equal(&mut self, start: Elem, goal: Elem) -> VarRel { - let mut seen: HashSet = Default::default(); - let mut queue: VecDeque = [start].into(); - while let Some(place) = queue.pop_front() { - if place.atom() == goal.atom() { - if place == goal { - return VarRel::Equiv; - } else { - return VarRel::AntiEquiv; - } - } - seen.insert(place.atom()); - for &next in self.equivs.grow_for(place.atom()).iter() { - if !seen.contains(&next.atom()) { - queue.push_back(next.apply_pol_of(place)); - } - } - } - VarRel::NotEquiv - } - fn find(&mut self, lit: Elem) -> Elem { - let out = self.dut.find(lit); - assert!(self.ref_equal(lit, out) == VarRel::Equiv); - out - } - fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { - let (ok, [ra, rb]) = self.dut.union_full(lits); - assert_eq!(self.ref_equal(lits[0], ra), VarRel::Equiv); - assert_eq!(self.ref_equal(lits[1], rb), VarRel::Equiv); - assert_eq!(ok, self.ref_equal(lits[0], lits[1]) == VarRel::NotEquiv); - assert_eq!(self.dut.find_root(lits[0]), ra); - if ok { - assert_eq!(self.dut.find_root(lits[1]), ra); - self.equivs - .grow_for(lits[0].atom()) - .insert(lits[1].apply_pol_of(lits[0])); - self.equivs - .grow_for(lits[1].atom()) - .insert(lits[0].apply_pol_of(lits[1])); - } else { - assert_eq!(self.dut.find_root(lits[1]).atom(), ra.atom()); - } - (ok, [ra, rb]) - } - fn union(&mut self, lits: [Elem; 2]) -> bool { - self.union_full(lits).0 - } - fn make_repr(&mut self, lit: Atom) { - self.dut.make_repr(lit); - assert_eq!( - self.dut.find_root(Elem::from_atom(lit)), - Elem::from_atom(lit) - ); - self.check(); - } - fn check(&mut self) { - for atom in self.dut.parent.keys() { - let parent = self.dut.read_parent(atom); - assert_eq!(self.ref_equal(Elem::from_atom(atom), parent), VarRel::Equiv); - let root = self.dut.find_root(Elem::from_atom(atom)); - for &child in self.equivs.grow_for(atom).iter() { - assert_eq!(root, self.dut.find_root(child)); - } - } - } + /// Returns an iterator that yields all tracked atoms and their representatives. + pub fn iter(&self) -> impl '_ + Iterator { + IdRange::from(Atom::MIN_ID..self.lowest_unused_atom()) + .iter() + .map(|atom| (atom, self.find(Elem::from_atom(atom)))) } +} - #[test] - fn test() { - let mut u: CheckedUnionFind = CheckedUnionFind::new(); - let mut rng = rand_pcg::Pcg64::seed_from_u64(25); - let max_var = 2000; - for i in 0..2000 { - match rng.gen_range(0..10) { - 0..=4 => { - let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); - let b = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); - let result = u.union_full([a, b]); - println!("union({a}, {b}) = {result:?}"); - } - 5..=7 => { - let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); - let result = u.find(a); - println!("find({a}) = {result}"); - } - 8 => { - u.check(); - } - 9 => { - let a = Var::from_index(rng.gen_range(0..=max_var)); - u.make_repr(a); - println!("make_repr({a})"); - } - _ => {} +impl> std::fmt::Debug for UnionFind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // prints non-trivial sets of equivalent elements, always printing the representative first + let mut sets = std::collections::HashMap::>::new(); + for (atom, repr) in self.iter() { + if Elem::from_atom(atom) != repr { + sets.entry(repr.atom()) + .or_insert_with(|| vec![Elem::from_atom(repr.atom())]) + .push(Elem::from_atom(atom).apply_pol_of(repr)); } } - u.check(); - //u.dut.debug_print(); + f.debug_set().entries(sets.values()).finish() } } diff --git a/ids/src/id_alloc.rs b/ids/src/id_alloc.rs new file mode 100644 index 0000000..630ed8f --- /dev/null +++ b/ids/src/id_alloc.rs @@ -0,0 +1,54 @@ +//! An allocator for IDs that allows concurrent access from multiple threads. +use std::sync::atomic::Ordering::Relaxed; +use std::{marker::PhantomData, sync::atomic::AtomicUsize}; + +use crate::{Id, IdRange}; + +/// An allocator for IDs that allows concurrent access from multiple threads. +pub struct IdAlloc { + counter: AtomicUsize, + _phantom: PhantomData, +} + +impl Default for IdAlloc { + fn default() -> Self { + Self::new() + } +} + +/// `IdAllocError` indicates that there are not enough IDs remaining. +#[derive(Clone, Copy, Debug)] +pub struct IdAllocError; + +impl IdAlloc { + /// Constructs a new ID allocator. + pub const fn new() -> Self { + Self { + counter: AtomicUsize::new(0), + _phantom: PhantomData, + } + } + fn alloc_indices(&self, n: usize) -> Result { + self.counter + .fetch_update(Relaxed, Relaxed, |current_id| { + current_id + .checked_add(n) + .filter(|&index| index <= T::MAX_ID_INDEX.saturating_add(1)) + }) + .map_err(|_| IdAllocError) + } + /// Allocates a single ID. + pub fn alloc(&self) -> Result { + self.alloc_indices(1).map(|index| { + // SAFETY: the precondition was checked by `alloc_indices` + unsafe { T::from_id_index_unchecked(index) } + }) + } + /// Allocates a contiguous range of the specified size. + pub fn alloc_range(&self, n: usize) -> Result, IdAllocError> { + self.alloc_indices(n).map(|start| { + // SAFETY: the precondition was checked by `alloc_indices` + unsafe { IdRange::from_index_range_unchecked(start..start + n) } + }) + } +} diff --git a/ids/src/lib.rs b/ids/src/lib.rs index a65669b..da26bd5 100644 --- a/ids/src/lib.rs +++ b/ids/src/lib.rs @@ -16,6 +16,7 @@ mod id; mod id_range; pub mod id_vec; pub mod indexed_id_vec; +pub mod id_alloc; pub mod id_set_seq; @@ -32,6 +33,8 @@ pub use id::{ConstIdFromIdIndex, GenericId, Id, Id16, Id32, Id64, Id8, IdSize}; pub use id_range::IdRange; +pub use id_alloc::IdAlloc; + // re-export this so that others can use it without depending on bytemuck explicitly // in particular needed for #[derive(Id)] pub use bytemuck::NoUninit; \ No newline at end of file From 75cbe5a07d2a78fd03bb204bd4d8c8b727e21896 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Wed, 16 Oct 2024 10:53:22 +0100 Subject: [PATCH 5/7] dsf: tests for tracked_union_find, docs and other minor improvements. --- dsf/src/tests/test_tracked_union_find.rs | 307 +++++++++++++++++++++-- dsf/src/tests/test_union_find.rs | 4 +- dsf/src/tracked_union_find.rs | 41 ++- dsf/src/union_find.rs | 2 +- 4 files changed, 315 insertions(+), 39 deletions(-) diff --git a/dsf/src/tests/test_tracked_union_find.rs b/dsf/src/tests/test_tracked_union_find.rs index 24f8d5c..087c5ab 100644 --- a/dsf/src/tests/test_tracked_union_find.rs +++ b/dsf/src/tests/test_tracked_union_find.rs @@ -1,32 +1,287 @@ +#![allow(missing_docs, dead_code)] use super::*; use imctk_lit::{Lit, Var}; +use rand::prelude::*; +use std::collections::HashMap; + +fn change_eq(c1: &Change, c2: &Change) -> bool { + match (c1, c2) { + ( + Change::Union { + new_repr: a1, + merged_repr: b1, + }, + Change::Union { + new_repr: a2, + merged_repr: b2, + }, + ) => a1 == a2 && b1 == b2, + ( + Change::MakeRepr { + new_repr: a1, + old_repr: b1, + }, + Change::MakeRepr { + new_repr: a2, + old_repr: b2, + }, + ) => a1 == a2 && b1 == b2, + (Change::Renumber(r1), Change::Renumber(r2)) => Arc::as_ptr(r1) == Arc::as_ptr(r2), + _ => false, + } +} + +struct CTUnionFind { + dut: TrackedUnionFind, + uf: UnionFind, + logs: HashMap>>, +} + +impl> CTUnionFind { + fn new() -> Self { + CTUnionFind { + dut: TrackedUnionFind::new(), + uf: UnionFind::new(), + logs: HashMap::new(), + } + } + fn find(&mut self, e: Elem) -> Elem { + let dut_result = self.dut.find(e); + let uf_result = self.uf.find(e); + assert_eq!(dut_result, uf_result); + uf_result + } + fn union_full(&mut self, elems: [Elem; 2]) -> (bool, [Elem; 2]) { + let (uf_ok, [uf_ra, uf_rb]) = self.uf.union_full(elems); + if uf_ok { + let new_repr = uf_ra.atom(); + let merged_repr = uf_rb.apply_pol_of(uf_ra); + for vec in self.logs.values_mut() { + vec.push(Change::Union { + new_repr, + merged_repr, + }); + } + } + let (dut_ok, [dut_ra, dut_rb]) = self.dut.union_full(elems); + assert_eq!((dut_ok, [dut_ra, dut_rb]), (uf_ok, [uf_ra, uf_rb])); + (uf_ok, [uf_ra, uf_rb]) + } + fn make_repr(&mut self, new_repr: Atom) -> Elem { + let uf_old_repr = self.uf.make_repr(new_repr); + if uf_old_repr.atom() != new_repr { + for vec in self.logs.values_mut() { + vec.push(Change::MakeRepr { + new_repr, + old_repr: uf_old_repr, + }); + } + } else { + assert!(uf_old_repr == Elem::from_atom(new_repr)); + } + let dut_old_repr = self.dut.make_repr(new_repr); + assert_eq!(dut_old_repr, uf_old_repr); + uf_old_repr + } + fn start_observing(&mut self) -> ObserverToken { + let token = self.dut.start_observing(); + assert!(self.logs.insert(token.observer_id, Vec::new()).is_none()); + token + } + fn clone_token(&mut self, token: &ObserverToken) -> ObserverToken { + let new_token = self.dut.clone_token(token); + let cloned_log = self.logs.get(&token.observer_id).unwrap().clone(); + assert!(self + .logs + .insert(new_token.observer_id, cloned_log) + .is_none()); + new_token + } + fn stop_observing(&mut self, token: ObserverToken) { + assert!(self.logs.remove(&token.observer_id).is_some()); + self.dut.stop_observing(token); + } + fn drain_changes_with_fn( + &mut self, + token: &mut ObserverToken, + mut f: impl FnMut(&[Change]), + ) { + let log = self.logs.get_mut(&token.observer_id).unwrap(); + let mut log_iter = log.drain(..); + self.dut.drain_changes_with_fn(token, |changes, _| { + f(changes); + for c1 in changes.iter() { + let Some(c2) = log_iter.next() else { + panic!("not enough changes"); + }; + assert!(change_eq(c1, &c2)); + } + }); + assert!(log_iter.next().is_none()); + } + fn drain_some_changes( + &mut self, + token: &mut ObserverToken, + calculate_count: impl FnOnce(usize) -> usize, + ) -> usize { + let log = self.logs.get_mut(&token.observer_id).unwrap(); + let count = calculate_count(log.len()); + let log_iter = log.drain(..count); + let mut gen = token.generation; + let mut dut_iter = self.dut.drain_changes(token); + for c2 in log_iter { + let Some(c1) = dut_iter.next() else { + panic!("not enough changes"); + }; + assert!(change_eq(c1, &c2)); + if let Change::Renumber(renumbering) = c1 { + assert_eq!(renumbering.old_generation, gen); + gen = renumbering.new_generation; + } + } + dut_iter.stop(); + assert_eq!(token.generation, gen); + count + } + fn repr_reduction(&self) -> (IdVec>, IdVec) { + let mut forward: IdVec> = IdVec::default(); + let mut reverse: IdVec = IdVec::default(); + for (atom, repr) in self.uf.iter() { + let new_repr = *forward.grow_for_key(repr.atom()).get_or_insert_with(|| { + let new_repr = reverse.push(Elem::from_atom(repr.atom())).0; + Elem::from_atom(new_repr) + }); + forward + .grow_for_key(atom) + .replace(new_repr.apply_pol_of(repr)); + } + (forward, reverse) + } + fn renumber(&mut self) -> Arc> { + let (forward, reverse) = self.repr_reduction(); + let renumbering = self.dut.renumber(forward, reverse); + for vec in self.logs.values_mut() { + vec.push(Change::Renumber(renumbering.clone())); + } + self.uf = UnionFind::new(); + renumbering + } +} + +macro_rules! weighted_choose { + ($rng:expr, $($name:ident: $weight:expr => $body:expr),+) => { + { + enum Branches { $( $name, )* } + let weights = [$((Branches::$name, $weight)),+]; + match weights.choose_weighted($rng, |x| x.1).unwrap().0 { + $(Branches::$name => $body),* + } + } + } +} #[test] -fn test() { - let l = |n| Var::from_index(n).as_lit(); - let mut tuf = TrackedUnionFind::::new(); - let mut token = tuf.start_observing(); - tuf.union([l(3), !l(4)]); - tuf.union([l(8), l(7)]); - let mut token2 = tuf.start_observing(); - tuf.union([l(4), l(5)]); - for change in tuf.drain_changes(&mut token).cloned() { - println!("{change:?}"); - } - println!("---"); - tuf.union([!l(5), l(6)]); - tuf.make_repr(l(4).var()); - let renumber: IdVec> = - IdVec::from_vec(vec![Some(l(0)), None, None, Some(l(1)), Some(!l(1)), Some(!l(1)), Some(l(1)), Some(l(2)), Some(l(2))]); - let reverse = Renumbering::get_reverse(&renumber, &tuf.union_find); - dbg!(&renumber, &reverse); - tuf.renumber(renumber, reverse); - tuf.union([l(0), l(1)]); - let mut iter = tuf.drain_changes(&mut token); - println!("{:?}", iter.next()); - iter.stop(); - println!("---"); - for change in tuf.drain_changes(&mut token2).cloned() { - println!("{change:?}"); +fn test_suite() { + let mut u: CTUnionFind = CTUnionFind::new(); + let mut rng = rand_pcg::Pcg64::seed_from_u64(25); + let mut active_tokens: Vec = Vec::new(); + let max_var = 2000; + let verbosity = 2; + for _ in 0..2000 { + weighted_choose! {&mut rng, + Union: 8.0 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let b = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.union_full([a, b]); + if verbosity > 1 { + println!("union({a}, {b}) = {result:?}"); + } + }, + Find: 1.0 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.find(a); + if verbosity > 1 { + println!("find({a}) = {result}"); + } + }, + MakeRepr: 2.0 => { + let a = Var::from_index(rng.gen_range(0..=max_var)); + u.make_repr(a); + if verbosity > 1 { + println!("make_repr({a})"); + } + }, + StartObserving: 1.0 => { + let token = u.start_observing(); + if verbosity > 0 { + println!("start observing {token:?}"); + } + active_tokens.push(token); + }, + StopObserving: 1.0 => { + if let Some(index) = (0..active_tokens.len()).choose(&mut rng) { + let token = active_tokens.swap_remove(index); + if verbosity > 0 { + println!("stop observing {token:?}"); + } + u.stop_observing(token); + } + }, + CloneToken: 1.0 => { + if let Some(token) = active_tokens.iter().choose(&mut rng) { + let new_token = u.clone_token(token); + if verbosity > 0 { + println!("clone {token:?} -> {new_token:?}"); + } + active_tokens.push(new_token); + } + }, + DrainAllChanges: 2.0 => { + if let Some(token) = active_tokens.iter_mut().choose(&mut rng) { + let mut count = 0; + let mut gen = token.generation; + u.drain_changes_with_fn(token, |changes| { + for change in changes { + if let Change::Renumber(renumbering) = change { + assert_eq!(renumbering.old_generation, gen); + gen = renumbering.new_generation; + } + count += 1; + } + }); + assert_eq!(gen, token.generation); + if verbosity > 0 { + println!("drained all ({count}) changes from {token:?}"); + } + } + }, + DrainSomeChanges: 2.0 => { + if let Some(token) = active_tokens.iter_mut().choose(&mut rng) { + let count = u.drain_some_changes(token, |total| (0..=total).choose(&mut rng).unwrap()); + if verbosity > 0 { + println!("drained {count} changes from {token:?}"); + } + } + }, + Renumber: 0.25 => { + let renumbering = u.renumber(); + if verbosity > 0 { + println!("renumbered all variables, now on generation {} ({} old variables, {} new variables)", + u.dut.generation.0, + renumbering.forward.len(), + renumbering.reverse.len() + ); + } + } + } } } + +#[test] +#[should_panic] +fn test_token_error() { + let mut tuf1: TrackedUnionFind = Default::default(); + let mut tuf2: TrackedUnionFind = Default::default(); + let mut token = tuf1.start_observing(); + tuf2.drain_changes(&mut token); +} \ No newline at end of file diff --git a/dsf/src/tests/test_union_find.rs b/dsf/src/tests/test_union_find.rs index 142a245..2727723 100644 --- a/dsf/src/tests/test_union_find.rs +++ b/dsf/src/tests/test_union_find.rs @@ -143,11 +143,11 @@ impl> CheckedUnionFind { } #[test] -fn test() { +fn test_suite() { let mut u: CheckedUnionFind = CheckedUnionFind::new(); let mut rng = rand_pcg::Pcg64::seed_from_u64(25); let max_var = 2000; - for i in 0..2000 { + for _ in 0..2000 { match rng.gen_range(0..10) { 0..=4 => { let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); diff --git a/dsf/src/tracked_union_find.rs b/dsf/src/tracked_union_find.rs index 64bbfe5..03317cd 100644 --- a/dsf/src/tracked_union_find.rs +++ b/dsf/src/tracked_union_find.rs @@ -24,7 +24,6 @@ pub struct Generation(u64); pub struct ObserverId(Id64); /// An `ObserverToken` represents an observer of a `TrackedUnionFind`. -#[derive(Debug)] pub struct ObserverToken { tuf_id: TrackedUnionFindId, generation: Generation, @@ -60,6 +59,16 @@ impl ObserverToken { } } +impl std::fmt::Debug for ObserverToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ObserverToken") + .field("tuf_id", &self.tuf_id.0) + .field("generation", &self.generation.0) + .field("observer_id", &self.observer_id.0) + .finish() + } +} + /// `Renumbering` represents a renumbering of all variables in `TrackedUnionFind`. /// /// A renumbering stores a forward and a reverse mapping, and the old and new generation IDs. @@ -196,12 +205,24 @@ impl std::fmt::Debug for Change { /// A `TrackedUnionFind` augments a [`UnionFind`] structure with change tracking. pub struct TrackedUnionFind { + /// Globally unique ID. tuf_id: TrackedUnionFindId, union_find: UnionFind, + /// Log of changes with new changes appended to the end. + /// Indices into this log are relative positions. + /// The start of this log has relative position 0 and absolute position `log_start`. log: VecDeque>, observer_id_alloc: IdAlloc, + /// This stores absolute positions of each observers in the log, + /// and also allows retrieving the minimum absolute position of any observer. + /// + /// `truncate_log` will reset that minimum position to be at relative position 0. observers: PriorityQueue>, + /// Offset between absolute and relative positions in the log. + /// + /// absolute position = relative position + `log_start` log_start: u64, + /// Incremented on every renumbering. generation: Generation, } @@ -274,16 +295,17 @@ impl> TrackedUnionFind { /// This resets the `UnionFind` to the trivial state (`find(a) == a` for all `a`) and increments the generation ID. /// /// This method will panic in debug mode if said preconditions are not met. - pub fn renumber(&mut self, forward: IdVec>, reverse: IdVec) { + pub fn renumber(&mut self, forward: IdVec>, reverse: IdVec) -> Arc> { let old_generation = self.generation; let new_generation = Generation(old_generation.0 + 1); self.generation = new_generation; - let renumbering = Renumbering::new(forward, reverse, old_generation, new_generation); + let renumbering = Arc::new(Renumbering::new(forward, reverse, old_generation, new_generation)); debug_assert!(renumbering.is_repr_reduction(&self.union_find)); if !self.observers.is_empty() { - self.log.push_back(Change::Renumber(Arc::new(renumbering))); + self.log.push_back(Change::Renumber(renumbering.clone())); } self.union_find = UnionFind::new(); + renumbering } } @@ -401,7 +423,7 @@ impl TrackedUnionFind { /// /// The method assumes that you will immediately process any `Renumbering` operations in the log /// and will update the token's generation field. - /// + /// /// Returns `true` iff `f` has been called at least once. pub fn drain_changes_with_fn( &mut self, @@ -460,15 +482,15 @@ pub struct DrainChangesMap<'a, 'b, Atom, Elem, F> { impl<'a, 'b, Atom, Elem> DrainChanges<'a, 'b, Atom, Elem> { /// Returns a reference to the first entry in the token's private log, without deleting it. - /// + /// /// Returns `None` if the log is empty. pub fn peek(&mut self) -> Option<&Change> { self.tuf.log.get(self.rel_pos) } /// Returns a reference to the first entry in the token's private log. The entry will be deleted after its use. - /// + /// /// If this returns a `Renumbering`, it is assumed that you will process it and the token's generation number will be updated. - /// + /// /// Returns `None` if the log is empty. If `next` returned `None`, it will never return any more entries (the iterator is fused). /// /// As reflected by the lifetimes, the API only guarantees that the returned reference until the next call of any method of this iterator. @@ -519,8 +541,7 @@ impl<'a, 'b, Atom: Clone, Elem: Clone> DrainChanges<'a, 'b, Atom, Elem> { impl Drop for DrainChanges<'_, '_, Atom, Elem> { fn drop(&mut self) { // mark any renumberings as seen - while self.next().is_some() { - } + while self.next().is_some() {} self.tuf .observer_set_rel_pos(self.token, self.tuf.log.len()); } diff --git a/dsf/src/union_find.rs b/dsf/src/union_find.rs index ebc54ee..504886b 100644 --- a/dsf/src/union_find.rs +++ b/dsf/src/union_find.rs @@ -238,4 +238,4 @@ impl> std::fmt::Debug for UnionFind Date: Tue, 29 Oct 2024 17:44:50 +0000 Subject: [PATCH 6/7] cargo fmt --- dsf/src/element.rs | 12 ++++++------ dsf/src/tests/test_tracked_union_find.rs | 2 +- dsf/src/tests/test_union_find.rs | 2 +- dsf/src/tracked_union_find.rs | 13 +++++++++++-- dsf/src/union_find.rs | 14 +++++++------- 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/dsf/src/element.rs b/dsf/src/element.rs index 7bb6809..d224e15 100644 --- a/dsf/src/element.rs +++ b/dsf/src/element.rs @@ -3,21 +3,21 @@ use imctk_lit::{Lit, Var}; /// A trait for "elements" that can be split into an "atom" and a "polarity". -/// +/// /// This lets code generically manipulate variables and literals, and further serves to abstract over their concrete representation. -/// +/// /// The two most common case this trait is used for are: /// 1) The element and the atom are both `Var`. In this case there is only one polarity and the trait implementation is trivial. /// 2) The element is `Lit` and the atom is `Var`. Here there are two polarities (`+` and `-`) to keep track of. -/// +/// /// Mathematically, implementing this trait signifies that elements can be written as pairs `(a, p)` with an atom `a` and a polarity `p`. /// The polarities are assumed to form a group `(P, *, 1)`. The trait operations then correspond to: /// 1) `from_atom(a) = (a, 1)` /// 2) `atom((a, p)) = a` /// 3) `apply_pol_of((a, p), (b, q)) = (a, p * q)` -/// +/// /// Currently, code assumes that `P` is either trivial or isomorphic to `Z_2`. -/// +/// /// Code using this trait may assume the following axioms to hold: /// 1) `from_atom(atom(x)) == x` /// 2) `apply_pol_of(atom(x), x) == x` @@ -54,4 +54,4 @@ impl Element for Lit { fn apply_pol_of(self, other: Self) -> Self { self ^ other.pol() } -} \ No newline at end of file +} diff --git a/dsf/src/tests/test_tracked_union_find.rs b/dsf/src/tests/test_tracked_union_find.rs index 087c5ab..33ba01d 100644 --- a/dsf/src/tests/test_tracked_union_find.rs +++ b/dsf/src/tests/test_tracked_union_find.rs @@ -284,4 +284,4 @@ fn test_token_error() { let mut tuf2: TrackedUnionFind = Default::default(); let mut token = tuf1.start_observing(); tuf2.drain_changes(&mut token); -} \ No newline at end of file +} diff --git a/dsf/src/tests/test_union_find.rs b/dsf/src/tests/test_union_find.rs index 2727723..5f8a53e 100644 --- a/dsf/src/tests/test_union_find.rs +++ b/dsf/src/tests/test_union_find.rs @@ -1,8 +1,8 @@ #![allow(dead_code, missing_docs)] use super::*; -use imctk_lit::{Var, Lit}; use imctk_ids::id_set_seq::IdSetSeq; +use imctk_lit::{Lit, Var}; use rand::prelude::*; use std::collections::{HashSet, VecDeque}; diff --git a/dsf/src/tracked_union_find.rs b/dsf/src/tracked_union_find.rs index 03317cd..343f74a 100644 --- a/dsf/src/tracked_union_find.rs +++ b/dsf/src/tracked_union_find.rs @@ -295,11 +295,20 @@ impl> TrackedUnionFind { /// This resets the `UnionFind` to the trivial state (`find(a) == a` for all `a`) and increments the generation ID. /// /// This method will panic in debug mode if said preconditions are not met. - pub fn renumber(&mut self, forward: IdVec>, reverse: IdVec) -> Arc> { + pub fn renumber( + &mut self, + forward: IdVec>, + reverse: IdVec, + ) -> Arc> { let old_generation = self.generation; let new_generation = Generation(old_generation.0 + 1); self.generation = new_generation; - let renumbering = Arc::new(Renumbering::new(forward, reverse, old_generation, new_generation)); + let renumbering = Arc::new(Renumbering::new( + forward, + reverse, + old_generation, + new_generation, + )); debug_assert!(renumbering.is_repr_reduction(&self.union_find)); if !self.observers.is_empty() { self.log.push_back(Change::Renumber(renumbering.clone())); diff --git a/dsf/src/union_find.rs b/dsf/src/union_find.rs index 504886b..1279b09 100644 --- a/dsf/src/union_find.rs +++ b/dsf/src/union_find.rs @@ -28,25 +28,25 @@ mod test_union_find; /// To declare two elements as equivalent, use the `union` or `union_full` methods, see their documentation for details on their use. /// /// NB: Since this structure stores atoms in an `IdVec`, the atoms used should ideally be a contiguous set starting at `Atom::MIN_ID`. -/// +/// /// ## Example ## /// ``` /// use imctk_lit::{Var, Lit}; /// use dsf::UnionFind; -/// +/// /// let mut union_find: UnionFind = UnionFind::new(); /// let lit = |n| Var::from_index(n).as_lit(); -/// +/// /// assert_eq!(union_find.find(lit(4)), lit(4)); -/// +/// /// union_find.union([lit(3), lit(4)]); /// assert_eq!(union_find.find(lit(4)), lit(3)); -/// +/// /// union_find.union([lit(1), !lit(2)]); /// union_find.union([lit(2), lit(3)]); /// assert_eq!(union_find.find(lit(1)), lit(1)); /// assert_eq!(union_find.find(lit(4)), !lit(1)); -/// +/// /// ``` pub struct UnionFind { parent: IdVec>, @@ -238,4 +238,4 @@ impl> std::fmt::Debug for UnionFind Date: Tue, 29 Oct 2024 18:08:23 +0000 Subject: [PATCH 7/7] rename dsf to union_find --- Cargo.lock | 26 +++++++++---------- Cargo.toml | 2 +- {dsf => union_find}/Cargo.toml | 2 +- {dsf => union_find}/src/element.rs | 0 {dsf => union_find}/src/lib.rs | 0 .../src/tests/test_tracked_union_find.rs | 0 .../src/tests/test_union_find.rs | 0 {dsf => union_find}/src/tracked_union_find.rs | 0 {dsf => union_find}/src/union_find.rs | 2 +- 9 files changed, 16 insertions(+), 16 deletions(-) rename {dsf => union_find}/Cargo.toml (91%) rename {dsf => union_find}/src/element.rs (100%) rename {dsf => union_find}/src/lib.rs (100%) rename {dsf => union_find}/src/tests/test_tracked_union_find.rs (100%) rename {dsf => union_find}/src/tests/test_union_find.rs (100%) rename {dsf => union_find}/src/tracked_union_find.rs (100%) rename {dsf => union_find}/src/union_find.rs (99%) diff --git a/Cargo.lock b/Cargo.lock index 0da409e..964004e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -252,18 +252,6 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" -[[package]] -name = "dsf" -version = "0.1.0" -dependencies = [ - "atomic", - "imctk-ids", - "imctk-lit", - "priority-queue", - "rand", - "rand_pcg", -] - [[package]] name = "encode_unicode" version = "0.3.6" @@ -469,7 +457,7 @@ name = "imctk-ids" version = "0.1.0" dependencies = [ "bytemuck", - "hashbrown", + "hashbrown 0.14.5", "imctk-derive", "imctk-transparent", "rand", @@ -542,6 +530,18 @@ dependencies = [ "zwohash", ] +[[package]] +name = "imctk_union_find" +version = "0.1.0" +dependencies = [ + "atomic", + "imctk-ids", + "imctk-lit", + "priority-queue", + "rand", + "rand_pcg", +] + [[package]] name = "indenter" version = "0.3.3" diff --git a/Cargo.toml b/Cargo.toml index 6ae2fc0..0c5247d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ members = [ "imctk", "lit", "stable_set", - "dsf", + "union_find", # comment to force multi-line layout ] diff --git a/dsf/Cargo.toml b/union_find/Cargo.toml similarity index 91% rename from dsf/Cargo.toml rename to union_find/Cargo.toml index f4fd701..dfc7004 100644 --- a/dsf/Cargo.toml +++ b/union_find/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "dsf" +name = "imctk_union_find" version = "0.1.0" edition = "2021" diff --git a/dsf/src/element.rs b/union_find/src/element.rs similarity index 100% rename from dsf/src/element.rs rename to union_find/src/element.rs diff --git a/dsf/src/lib.rs b/union_find/src/lib.rs similarity index 100% rename from dsf/src/lib.rs rename to union_find/src/lib.rs diff --git a/dsf/src/tests/test_tracked_union_find.rs b/union_find/src/tests/test_tracked_union_find.rs similarity index 100% rename from dsf/src/tests/test_tracked_union_find.rs rename to union_find/src/tests/test_tracked_union_find.rs diff --git a/dsf/src/tests/test_union_find.rs b/union_find/src/tests/test_union_find.rs similarity index 100% rename from dsf/src/tests/test_union_find.rs rename to union_find/src/tests/test_union_find.rs diff --git a/dsf/src/tracked_union_find.rs b/union_find/src/tracked_union_find.rs similarity index 100% rename from dsf/src/tracked_union_find.rs rename to union_find/src/tracked_union_find.rs diff --git a/dsf/src/union_find.rs b/union_find/src/union_find.rs similarity index 99% rename from dsf/src/union_find.rs rename to union_find/src/union_find.rs index 1279b09..8473d68 100644 --- a/dsf/src/union_find.rs +++ b/union_find/src/union_find.rs @@ -32,7 +32,7 @@ mod test_union_find; /// ## Example ## /// ``` /// use imctk_lit::{Var, Lit}; -/// use dsf::UnionFind; +/// use imctk_union_find::UnionFind; /// /// let mut union_find: UnionFind = UnionFind::new(); /// let lit = |n| Var::from_index(n).as_lit();