From e66f420260331dd598ca826bcb56257ed3dcd157 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Fri, 11 Oct 2024 17:49:35 +0100 Subject: [PATCH] add dsf crate --- Cargo.lock | 34 +++ Cargo.toml | 1 + dsf/Cargo.toml | 20 ++ dsf/src/lib.rs | 2 + dsf/src/tracked_union_find.rs | 474 ++++++++++++++++++++++++++++++++++ dsf/src/union_find.rs | 319 +++++++++++++++++++++++ ids/src/id_vec.rs | 2 +- lit/Cargo.toml | 1 + lit/src/lit.rs | 2 + lit/src/var.rs | 2 + 10 files changed, 856 insertions(+), 1 deletion(-) create mode 100644 dsf/Cargo.toml create mode 100644 dsf/src/lib.rs create mode 100644 dsf/src/tracked_union_find.rs create mode 100644 dsf/src/union_find.rs diff --git a/Cargo.lock b/Cargo.lock index 4a8b6f7..75eb901 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,6 +75,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "atomic" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d818003e740b63afc82337e3160717f4f63078720a810b7b903e70a5d1d2994" +dependencies = [ + "bytemuck", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -229,6 +238,19 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "dsf" +version = "0.1.0" +dependencies = [ + "atomic", + "bytemuck", + "imctk-ids", + "imctk-lit", + "priority-queue", + "rand", + "rand_pcg", +] + [[package]] name = "encode_unicode" version = "0.3.6" @@ -470,6 +492,7 @@ dependencies = [ name = "imctk-lit" version = "0.1.0" dependencies = [ + "bytemuck", "flussab-aiger", "hashbrown 0.14.5", "imctk-derive", @@ -670,6 +693,17 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "priority-queue" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714c75db297bc88a63783ffc6ab9f830698a6705aa0201416931759ef4c8183d" +dependencies = [ + "autocfg", + "equivalent", + "indexmap", +] + [[package]] name = "proc-macro-crate" version = "3.2.0" diff --git a/Cargo.toml b/Cargo.toml index 1fa671e..6ae2fc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "imctk", "lit", "stable_set", + "dsf", # comment to force multi-line layout ] diff --git a/dsf/Cargo.toml b/dsf/Cargo.toml new file mode 100644 index 0000000..bee1698 --- /dev/null +++ b/dsf/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "dsf" +version = "0.1.0" +edition = "2021" + +publish = false + +[dependencies] +imctk-ids = { version = "0.1.0", path = "../ids" } +imctk-lit = { version = "0.1.0", path = "../lit" } +bytemuck = "*" +atomic = "0.6" +priority-queue = "*" + +[dev-dependencies] +rand = "*" +rand_pcg = "*" + +[lints] +workspace = true diff --git a/dsf/src/lib.rs b/dsf/src/lib.rs new file mode 100644 index 0000000..0505e00 --- /dev/null +++ b/dsf/src/lib.rs @@ -0,0 +1,2 @@ +pub mod union_find; +pub mod tracked_union_find; \ No newline at end of file diff --git a/dsf/src/tracked_union_find.rs b/dsf/src/tracked_union_find.rs new file mode 100644 index 0000000..0f985d4 --- /dev/null +++ b/dsf/src/tracked_union_find.rs @@ -0,0 +1,474 @@ +#![allow(missing_docs)] +#![allow(clippy::type_complexity)] +use std::{cmp::Reverse, collections::VecDeque, mem::ManuallyDrop, sync::Arc}; + +use atomic::Atomic; +use bytemuck::NoUninit; +use imctk_ids::{id_vec::IdVec, Id, Id64}; +use priority_queue::PriorityQueue; + +use crate::union_find::{Element, UnionFind}; + +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct TrackedUnionFindId(Id64); +// SAFETY: trust me bro +unsafe impl NoUninit for TrackedUnionFindId {} +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct Generation(u64); +#[derive(Id, Debug)] +#[repr(transparent)] +pub struct ObserverId(Id64); +// SAFETY: trust me bro +unsafe impl NoUninit for ObserverId {} + +#[derive(Debug)] +pub struct ObserverToken { + tuf_id: TrackedUnionFindId, + generation: Generation, + observer_id: ObserverId, +} + +pub struct Renumbering { + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, +} + +impl std::fmt::Debug for Renumbering { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Renumbering") + .field("forward", &self.forward) + .field("reverse", &self.reverse) + .field("old_generation", &self.old_generation) + .field("new_generation", &self.new_generation) + .finish() + } +} + +impl + NoUninit> Renumbering { + pub fn get_reverse(forward: &IdVec>, union_find: &UnionFind) -> IdVec { + let mut reverse: IdVec> = IdVec::default(); + for (old, &new_opt) in forward { + if let Some(new) = new_opt { + reverse + .grow_for_key(new.atom()) + .replace(union_find.find(Elem::from_atom(old)).apply_pol_of(new)); + } + } + IdVec::from_vec(reverse.iter().map(|x| x.1.unwrap()).collect()) + } + pub fn is_inverse(forward: &IdVec>, reverse: &IdVec) -> bool { + reverse.iter().all(|(new, &old)| { + if let Some(&Some(e)) = forward.get(old.atom()) { + Elem::from_atom(new) == e.apply_pol_of(old) + } else { + false + } + }) + } + pub fn new_unchecked( + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, + ) -> Self { + Renumbering { + forward, + reverse, + old_generation, + new_generation, + } + } + pub fn new( + forward: IdVec>, + reverse: IdVec, + old_generation: Generation, + new_generation: Generation, + ) -> Self { + debug_assert!(new_generation > old_generation); + debug_assert!(Self::is_inverse(&forward, &reverse)); + Self::new_unchecked(forward, reverse, old_generation, new_generation) + } + pub fn old_to_new(&self, old: Elem) -> Option { + self.forward + .get(old.atom()) + .copied() + .flatten() + .map(|e| e.apply_pol_of(old)) + } + pub fn new_to_old(&self, new: Elem) -> Option { + self.reverse.get(new.atom()).map(|&e| e.apply_pol_of(new)) + } + pub fn is_repr_reduction(&self, union_find: &UnionFind) -> bool { + union_find.lowest_unused_atom() <= self.forward.next_unused_key() + && self.forward.iter().all(|(old, &new)| { + let repr = union_find.find(Elem::from_atom(old)); + let repr_new = self.old_to_new(repr); + repr_new == new + }) + } +} + +#[derive(Clone)] +pub enum Change { + Union { new_repr: Atom, merged_repr: Elem }, + MakeRepr { new_repr: Atom, old_repr: Elem }, + Renumber(Arc>), +} + +impl std::fmt::Debug for Change { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Union { + new_repr, + merged_repr, + } => f + .debug_struct("Union") + .field("new_repr", new_repr) + .field("merged_repr", merged_repr) + .finish(), + Self::MakeRepr { new_repr, old_repr } => f + .debug_struct("MakeRepr") + .field("new_repr", new_repr) + .field("old_repr", old_repr) + .finish(), + Self::Renumber(arg0) => f.debug_tuple("Renumber").field(arg0).finish(), + } + } +} + +pub struct TrackedUnionFind { + tuf_id: TrackedUnionFindId, + union_find: UnionFind, + log: VecDeque>, + observer_id_alloc: IdAlloc, + observers: PriorityQueue>, + log_start: u64, + generation: Generation, +} + +pub struct IdAlloc { + counter: Atomic, +} + +impl Default for IdAlloc { + fn default() -> Self { + Self::new() + } +} + +impl IdAlloc { + const fn new() -> Self { + Self { + counter: Atomic::new(T::MIN_ID), + } + } + pub fn alloc_block(&self, n: usize) -> T { + use atomic::Ordering::Relaxed; + debug_assert!(n > 0); + let mut current_id = self.counter.load(Relaxed); + loop { + let new_id = current_id + .id_index() + .checked_add(n) + .and_then(T::try_from_id_index) + .expect("not enough IDs remaining"); + match self + .counter + .compare_exchange_weak(current_id, new_id, Relaxed, Relaxed) + { + Ok(_) => return current_id, + Err(id) => current_id = id, + } + } + } + pub fn alloc(&self) -> T { + self.alloc_block(1) + } +} + +static TUF_ID_ALLOC: IdAlloc = IdAlloc::new(); + +impl Default for TrackedUnionFind { + fn default() -> Self { + Self { + tuf_id: TUF_ID_ALLOC.alloc(), + union_find: Default::default(), + log: Default::default(), + observer_id_alloc: Default::default(), + observers: Default::default(), + log_start: 0, + generation: Generation(0), + } + } +} + +impl + NoUninit> TrackedUnionFind { + pub fn find(&self, elem: Elem) -> Elem { + self.union_find.find(elem) + } + pub fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { + let (ok, roots) = self.union_find.union_full(lits); + if ok && !self.observers.is_empty() { + let new_repr = roots[0].atom(); + let merged_repr = roots[1].apply_pol_of(roots[0]); + self.log.push_back(Change::Union { + new_repr, + merged_repr, + }) + } + (ok, roots) + } + pub fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + pub fn make_repr(&mut self, new_repr: Atom) -> Elem { + let old_repr = self.union_find.make_repr(new_repr); + if old_repr.atom() != new_repr && !self.observers.is_empty() { + self.log.push_back(Change::MakeRepr { new_repr, old_repr }); + } + old_repr + } + pub fn renumber(&mut self, forward: IdVec>, reverse: IdVec) { + let old_generation = self.generation; + let new_generation = Generation(old_generation.0 + 1); + self.generation = new_generation; + let renumbering = Renumbering::new(forward, reverse, old_generation, new_generation); + debug_assert!(renumbering.is_repr_reduction(&self.union_find)); + if !self.observers.is_empty() { + self.log.push_back(Change::Renumber(Arc::new(renumbering))); + } + self.union_find = UnionFind::new(); + } +} + +impl TrackedUnionFind { + pub fn new() -> Self { + Self::default() + } + fn log_end(&self) -> u64 { + self.log_start + self.log.len() as u64 + } + pub fn start_observing(&mut self) -> ObserverToken { + let observer_id = self.observer_id_alloc.alloc(); + self.observers.push(observer_id, Reverse(self.log_end())); + ObserverToken { + tuf_id: self.tuf_id, + generation: self.generation, + observer_id, + } + } + pub fn clone_token(&mut self, token: &ObserverToken) -> ObserverToken { + assert!(token.tuf_id == self.tuf_id); + let new_observer_id = self.observer_id_alloc.alloc(); + let pos = *self.observers.get_priority(&token.observer_id).unwrap(); + self.observers.push(new_observer_id, pos); + ObserverToken { + tuf_id: self.tuf_id, + generation: token.generation, + observer_id: new_observer_id, + } + } + pub fn stop_observing(&mut self, token: ObserverToken) { + assert!(token.tuf_id == self.tuf_id); + self.observers.remove(&token.observer_id); + self.truncate_log(); + } + fn truncate_log(&mut self) { + if let Some((_, &Reverse(new_start))) = self.observers.peek() { + if new_start > self.log_start { + let delete = (new_start - self.log_start).try_into().unwrap(); + drop(self.log.drain(0..delete)); + println!("dropped {delete} entries"); + self.log_start = new_start; + } + } else { + self.log_start = self.log_end(); + self.log.clear(); + println!("dropped all entries"); + } + } + fn observer_rel_pos(&self, token: &ObserverToken) -> usize { + assert!(token.tuf_id == self.tuf_id); + let abs_pos = self.observers.get_priority(&token.observer_id).unwrap().0; + debug_assert!(abs_pos >= self.log_start); + (abs_pos - self.log_start).try_into().unwrap() + } + fn observer_inc_pos(&mut self, token: &ObserverToken, by: u64) { + let (min, max) = (self.log_start, self.log_end()); + self.observers + .change_priority_by(&token.observer_id, |pos| { + let new = pos.0 + by; + debug_assert!(new >= min && new <= max); + *pos = Reverse(new); + }); + self.truncate_log(); + } + fn observer_set_rel_pos(&mut self, token: &ObserverToken, rel_pos: usize) { + debug_assert!(rel_pos <= self.log.len()); + let abs_pos = self.log_start + rel_pos as u64; + self.observers + .change_priority(&token.observer_id, Reverse(abs_pos)); + self.truncate_log(); + } + fn change_slices( + &self, + token: &ObserverToken, + ) -> ( + &[Change], + &[Change], + &UnionFind, + ) { + let rel_pos = self.observer_rel_pos(token); + let (first, second) = self.log.as_slices(); + if rel_pos >= first.len() { + (&second[rel_pos - first.len()..], &[], &self.union_find) + } else { + (&first[rel_pos..], second, &self.union_find) + } + } + fn observer_has_seen(&self, token: &mut ObserverToken, changes: &[Change]) { + for change in changes { + if let Change::Renumber(renumbering) = change { + debug_assert!(token.generation == renumbering.old_generation); + token.generation = renumbering.new_generation; + } + } + } + pub fn drain_changes_with_fn( + &mut self, + token: &mut ObserverToken, + mut f: impl FnMut(&[Change], &UnionFind), + ) -> bool { + let (first, second, union_find) = self.change_slices(token); + if !first.is_empty() { + f(first, union_find); + self.observer_has_seen(token, first); + if !second.is_empty() { + f(second, union_find); + self.observer_has_seen(token, second); + } + let drained = first.len() + second.len(); + self.observer_inc_pos(token, drained as u64); + true + } else { + false + } + } + pub fn drain_changes<'a>( + &mut self, + token: &'a mut ObserverToken, + ) -> DrainChanges<'_, 'a, Atom, Elem> { + assert!(token.tuf_id == self.tuf_id); + let rel_pos = self.observer_rel_pos(token); + DrainChanges { + tuf: self, + token, + rel_pos, + } + } +} + +pub struct DrainChanges<'a, 'b, Atom, Elem> { + tuf: &'a mut TrackedUnionFind, + token: &'b mut ObserverToken, + rel_pos: usize, +} + +pub struct DrainChangesMap<'a, 'b, Atom, Elem, F> { + inner: DrainChanges<'a, 'b, Atom, Elem>, + f: F, +} + +impl<'a, 'b, Atom, Elem> DrainChanges<'a, 'b, Atom, Elem> { + pub fn peek(&mut self) -> Option<&Change> { + self.tuf.log.get(self.rel_pos) + } + #[allow(clippy::should_implement_trait)] + pub fn next(&mut self) -> Option<&Change> { + let ret = self.tuf.log.get(self.rel_pos); + if let Some(change) = ret { + self.tuf + .observer_has_seen(self.token, std::slice::from_ref(change)); + self.rel_pos += 1; + } + ret + } + pub fn stop(self) { + self.tuf.observer_set_rel_pos(self.token, self.rel_pos); + let _ = ManuallyDrop::new(self); + } + pub fn size_hint(&self) -> (usize, Option) { + let count = self.tuf.log.len() - self.rel_pos; + (count, Some(count)) + } + pub fn map(self, f: F) -> DrainChangesMap<'a, 'b, Atom, Elem, F> + where + F: FnMut(&Change) -> B, + { + DrainChangesMap { inner: self, f } + } +} + +impl<'a, 'b, Atom: Clone, Elem: Clone> DrainChanges<'a, 'b, Atom, Elem> { + pub fn cloned( + self, + ) -> DrainChangesMap<'a, 'b, Atom, Elem, fn(&Change) -> Change> { + self.map(|x| x.clone()) + } +} + +impl Drop for DrainChanges<'_, '_, Atom, Elem> { + fn drop(&mut self) { + self.tuf + .observer_set_rel_pos(self.token, self.tuf.log.len()); + } +} + +impl Iterator for DrainChangesMap<'_, '_, Atom, Elem, F> +where + F: FnMut(&Change) -> B, +{ + type Item = B; + + fn next(&mut self) -> Option { + self.inner.next().map(&mut self.f) + } + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +#[test] +fn test() { + use imctk_lit::{Lit, Var}; + let l = |n| Var::from_index(n).as_lit(); + let mut tuf = TrackedUnionFind::::new(); + let mut token = tuf.start_observing(); + tuf.union([l(3), !l(4)]); + tuf.union([l(8), l(7)]); + let mut token2 = tuf.start_observing(); + tuf.union([l(4), l(5)]); + for change in tuf.drain_changes(&mut token).cloned() { + println!("{change:?}"); + } + println!("---"); + tuf.union([!l(5), l(6)]); + tuf.make_repr(l(4).var()); + let renumber: IdVec> = + IdVec::from_vec(vec![Some(l(0)), None, None, Some(l(1)), Some(!l(1)), Some(!l(1)), Some(l(1)), Some(l(2)), Some(l(2))]); + let reverse = Renumbering::get_reverse(&renumber, &tuf.union_find); + dbg!(&renumber, &reverse); + tuf.renumber(renumber, reverse); + tuf.union([l(0), l(1)]); + let mut iter = tuf.drain_changes(&mut token); + println!("{:?}", iter.next()); + iter.stop(); + println!("---"); + for change in tuf.drain_changes(&mut token2).cloned() { + println!("{change:?}"); + } +} diff --git a/dsf/src/union_find.rs b/dsf/src/union_find.rs new file mode 100644 index 0000000..5036432 --- /dev/null +++ b/dsf/src/union_find.rs @@ -0,0 +1,319 @@ +#![allow(missing_docs)] +use std::sync::atomic::Ordering; + +use atomic::Atomic; +use bytemuck::NoUninit; +use imctk_ids::{id_vec::IdVec, Id}; +use imctk_lit::{Lit, Var}; + +pub trait Element { + fn from_atom(atom: Atom) -> Self; + fn atom(self) -> Atom; + fn apply_pol_of(self, other: Self) -> Self; +} + +impl Element for T { + fn from_atom(atom: T) -> Self { + atom + } + fn atom(self) -> T { + self + } + fn apply_pol_of(self, _other: T) -> Self { + self + } +} + +impl Element for Lit { + fn from_atom(atom: Var) -> Self { + atom.as_lit() + } + fn atom(self) -> Var { + self.var() + } + fn apply_pol_of(self, other: Self) -> Self { + self ^ other.pol() + } +} + +pub struct UnionFind { + parent: IdVec>, +} + +impl Default for UnionFind { + fn default() -> Self { + UnionFind { + parent: Default::default(), + } + } +} + +impl Clone for UnionFind { + fn clone(&self) -> Self { + let new_parent = self + .parent + .values() + .iter() + .map(|p| Atomic::new(p.load(Ordering::Relaxed))) + .collect(); + Self { + parent: IdVec::from_vec(new_parent), + } + } +} + +impl + NoUninit> UnionFind { + pub fn new() -> Self { + UnionFind::default() + } + fn read_parent(&self, atom: Atom) -> Elem { + self.parent + .get(atom) + .map(|p| p.load(Ordering::Relaxed)) + .unwrap_or(Elem::from_atom(atom)) + } + fn update_parent(&self, atom: Atom, parent: Elem) { + let Some(parent_ref) = self.parent.get(atom) else { + panic!("shouldn't happen: update_parent called with out of bounds argument"); + }; + parent_ref.store(parent, Ordering::Relaxed); + } + fn write_parent(&mut self, atom: Atom, parent: Elem) { + if let Some(parent_ref) = self.parent.get(atom) { + parent_ref.store(parent, Ordering::Relaxed); + } else { + debug_assert!(self.parent.next_unused_key() <= atom); + while self.parent.next_unused_key() < atom { + self.parent + .push(Atomic::new(Elem::from_atom(self.parent.next_unused_key()))); + } + self.parent.push(Atomic::new(parent)); + } + } + fn find_root(&self, mut elem: Elem) -> Elem { + loop { + let parent = self.read_parent(elem.atom()).apply_pol_of(elem); + if elem == parent { + return elem; + } + debug_assert!(elem.atom() != parent.atom()); + elem = parent; + } + } + fn update_root(&self, mut elem: Elem, root: Elem) { + loop { + let parent = self.read_parent(elem.atom()).apply_pol_of(elem); + if parent == root { + break; + } + self.update_parent(elem.atom(), root.apply_pol_of(elem)); + elem = parent; + } + } + pub fn find(&self, lit: Elem) -> Elem { + let root = self.find_root(lit); + self.update_root(lit, root); + root + } + pub fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { + let [a, b] = lits; + let ra = self.find(a); + let rb = self.find(b); + if ra.atom() == rb.atom() { + (false, [ra, rb]) + } else { + self.write_parent(rb.atom(), ra.apply_pol_of(rb)); + (true, [ra, rb]) + } + } + pub fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + pub fn make_repr(&mut self, atom: Atom) -> Elem { + let root = self.find(Elem::from_atom(atom)); + self.write_parent(atom, Elem::from_atom(atom)); + self.write_parent(root.atom(), Elem::from_atom(atom).apply_pol_of(root)); + root + } + pub fn lowest_unused_atom(&self) -> Atom { + self.parent.next_unused_key() + } +} + +#[cfg(test)] +#[allow(dead_code)] +mod tests { + use super::*; + use imctk_ids::id_set_seq::IdSetSeq; + use rand::prelude::*; + use std::collections::{HashSet, VecDeque}; + + #[derive(Default)] + struct CheckedUnionFind { + dut: UnionFind, + equivs: IdSetSeq, + } + + impl + NoUninit> UnionFind { + fn debug_print_tree( + children: &IdVec>, + atom: Atom, + prefix: &str, + self_char: &str, + further_char: &str, + pol: bool, + ) { + println!( + "{prefix}{self_char}{}{:?}", + if pol { "!" } else { "" }, + atom + ); + let my_children = children.get(atom).unwrap(); + for (index, &child) in my_children.iter().enumerate() { + let last = index == my_children.len() - 1; + let self_char = if last { "└" } else { "├" }; + let next_further_char = if last { " " } else { "│" }; + Self::debug_print_tree( + children, + child.atom(), + &(prefix.to_string() + further_char), + self_char, + next_further_char, + pol ^ (child != Elem::from_atom(child.atom())), + ); + } + } + fn debug_print(&self) { + let mut children: IdVec> = Default::default(); + for atom in self.parent.keys() { + let parent = self.read_parent(atom); + children.grow_for_key(atom); + if atom != parent.atom() { + children + .grow_for_key(parent.atom()) + .push(Elem::from_atom(atom).apply_pol_of(parent)); + } else { + assert!(Elem::from_atom(atom) == parent); + } + } + for atom in self.parent.keys() { + if atom == self.read_parent(atom).atom() { + Self::debug_print_tree(&children, atom, "", "", " ", false); + } + } + } + } + #[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] + enum VarRel { + Equiv, + AntiEquiv, + NotEquiv, + } + + impl + NoUninit> CheckedUnionFind { + fn new() -> Self { + CheckedUnionFind { + dut: Default::default(), + equivs: Default::default(), + } + } + fn ref_equal(&mut self, start: Elem, goal: Elem) -> VarRel { + let mut seen: HashSet = Default::default(); + let mut queue: VecDeque = [start].into(); + while let Some(place) = queue.pop_front() { + if place.atom() == goal.atom() { + if place == goal { + return VarRel::Equiv; + } else { + return VarRel::AntiEquiv; + } + } + seen.insert(place.atom()); + for &next in self.equivs.grow_for(place.atom()).iter() { + if !seen.contains(&next.atom()) { + queue.push_back(next.apply_pol_of(place)); + } + } + } + VarRel::NotEquiv + } + fn find(&mut self, lit: Elem) -> Elem { + let out = self.dut.find(lit); + assert!(self.ref_equal(lit, out) == VarRel::Equiv); + out + } + fn union_full(&mut self, lits: [Elem; 2]) -> (bool, [Elem; 2]) { + let (ok, [ra, rb]) = self.dut.union_full(lits); + assert_eq!(self.ref_equal(lits[0], ra), VarRel::Equiv); + assert_eq!(self.ref_equal(lits[1], rb), VarRel::Equiv); + assert_eq!(ok, self.ref_equal(lits[0], lits[1]) == VarRel::NotEquiv); + assert_eq!(self.dut.find_root(lits[0]), ra); + if ok { + assert_eq!(self.dut.find_root(lits[1]), ra); + self.equivs + .grow_for(lits[0].atom()) + .insert(lits[1].apply_pol_of(lits[0])); + self.equivs + .grow_for(lits[1].atom()) + .insert(lits[0].apply_pol_of(lits[1])); + } else { + assert_eq!(self.dut.find_root(lits[1]).atom(), ra.atom()); + } + (ok, [ra, rb]) + } + fn union(&mut self, lits: [Elem; 2]) -> bool { + self.union_full(lits).0 + } + fn make_repr(&mut self, lit: Atom) { + self.dut.make_repr(lit); + assert_eq!( + self.dut.find_root(Elem::from_atom(lit)), + Elem::from_atom(lit) + ); + self.check(); + } + fn check(&mut self) { + for atom in self.dut.parent.keys() { + let parent = self.dut.read_parent(atom); + assert_eq!(self.ref_equal(Elem::from_atom(atom), parent), VarRel::Equiv); + let root = self.dut.find_root(Elem::from_atom(atom)); + for &child in self.equivs.grow_for(atom).iter() { + assert_eq!(root, self.dut.find_root(child)); + } + } + } + } + + #[test] + fn test() { + let mut u: CheckedUnionFind = CheckedUnionFind::new(); + let mut rng = rand_pcg::Pcg64::seed_from_u64(25); + let max_var = 2000; + for i in 0..2000 { + match rng.gen_range(0..10) { + 0..=4 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let b = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.union_full([a, b]); + println!("union({a}, {b}) = {result:?}"); + } + 5..=7 => { + let a = Lit::from_code(rng.gen_range(0..=2 * max_var + 1)); + let result = u.find(a); + println!("find({a}) = {result}"); + } + 8 => { + u.check(); + } + 9 => { + let a = Var::from_index(rng.gen_range(0..=max_var)); + u.make_repr(a); + println!("make_repr({a})"); + } + _ => {} + } + } + u.check(); + //u.dut.debug_print(); + } +} diff --git a/ids/src/id_vec.rs b/ids/src/id_vec.rs index ff04a19..9c40049 100644 --- a/ids/src/id_vec.rs +++ b/ids/src/id_vec.rs @@ -312,7 +312,7 @@ impl Clone for IdVec { } } -impl Default for IdVec { +impl Default for IdVec { #[inline(always)] fn default() -> Self { Self { diff --git a/lit/Cargo.toml b/lit/Cargo.toml index 925485a..1b29ee4 100644 --- a/lit/Cargo.toml +++ b/lit/Cargo.toml @@ -19,3 +19,4 @@ imctk-util = { version = "0.1.0", path = "../util" } log = "0.4.21" table_seq = { version = "0.1.0", path = "../table_seq" } zwohash = "0.1.2" +bytemuck = "*" \ No newline at end of file diff --git a/lit/src/lit.rs b/lit/src/lit.rs index 409869c..0aa1a31 100644 --- a/lit/src/lit.rs +++ b/lit/src/lit.rs @@ -24,6 +24,8 @@ use super::{pol::Pol, var::Var}; #[derive(Id, SubtypeCast, NewtypeCast)] pub struct Lit(Id32); +unsafe impl bytemuck::NoUninit for Lit {} + /// Ensure that there is an even number of literals #[allow(clippy::assertions_on_constants)] const _: () = { diff --git a/lit/src/var.rs b/lit/src/var.rs index ce9fc3a..473ebc4 100644 --- a/lit/src/var.rs +++ b/lit/src/var.rs @@ -9,6 +9,8 @@ use super::{lit::Lit, pol::Pol}; #[derive(Id, SubtypeCast, NewtypeCast)] pub struct Var(GenericId<{ Lit::MAX_ID_INDEX / 2 }, ::BaseId>); +unsafe impl bytemuck::NoUninit for Var {} + impl std::fmt::Debug for Var { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f)