Skip to content

Commit 7ced8c7

Browse files
committed
Partition refinement data structure
1 parent cb76898 commit 7ced8c7

File tree

4 files changed

+335
-0
lines changed

4 files changed

+335
-0
lines changed

Cargo.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

util/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ version = "0.1.0"
44
edition = "2021"
55

66
[dependencies]
7+
imctk-ids = { version = "0.1.0", path = "../ids" }

util/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
#[macro_use]
77
pub mod give_take;
88

9+
pub mod partition_refinement;
910
pub mod unordered_pair;
1011
pub mod vec_sink;

util/src/partition_refinement.rs

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
#![allow(missing_docs)] // TODO document
2+
3+
use std::cmp::Ordering;
4+
5+
use imctk_ids::{id_vec::IdVec, Id, Id32, IdRange};
6+
7+
#[derive(Id, Debug)]
8+
#[repr(transparent)]
9+
pub struct ClassId(Id32);
10+
11+
#[derive(Id, Debug)]
12+
#[repr(transparent)]
13+
pub struct MemberId(Id32);
14+
15+
#[derive(Default)]
16+
pub struct PartitionRefinement {
17+
class_from_member: IdVec<MemberId, ClassId>,
18+
member_from_class: IdVec<ClassId, MemberId>,
19+
}
20+
21+
impl PartitionRefinement {
22+
pub fn new(len: usize) -> Self {
23+
let member_from_class = IdVec::from_vec(IdRange::from_index_range(0..len).iter().collect());
24+
let mut class_from_member = IdVec::from_vec(vec![ClassId::MIN_ID; len]);
25+
if len > 0 {
26+
class_from_member[MemberId::MIN_ID] = ClassId::from_id_index(len - 1);
27+
}
28+
29+
Self {
30+
class_from_member,
31+
member_from_class,
32+
}
33+
}
34+
35+
pub fn class_of_member(&self, member: MemberId) -> ClassId {
36+
self.canonical_class(self.class_from_member[member])
37+
}
38+
39+
fn canonical_class(&self, class: ClassId) -> ClassId {
40+
let repr_member = self.member_from_class[class];
41+
let repr_class = self.class_from_member[repr_member];
42+
class.min(repr_class)
43+
}
44+
45+
pub fn is_class(&self, class: ClassId) -> bool {
46+
self.canonical_class(class) == class
47+
}
48+
49+
fn unchecked_class_len(&self, class: ClassId) -> usize {
50+
let end = self.class_from_member[self.member_from_class[class]];
51+
end.id_index() - class.id_index() + 1
52+
}
53+
54+
pub fn class_len(&self, class: ClassId) -> usize {
55+
if !self.is_class(class) {
56+
return 0;
57+
}
58+
self.unchecked_class_len(class)
59+
}
60+
61+
pub fn members_in_class(&self, class: ClassId) -> &[MemberId] {
62+
let len = self.class_len(class);
63+
&self.member_from_class.values()[class.id_index()..][..len]
64+
}
65+
66+
pub fn partition_class(
67+
&mut self,
68+
class: ClassId,
69+
mut predicate: impl FnMut(MemberId) -> bool,
70+
) -> Option<ClassId> {
71+
let len = self.class_len(class);
72+
if len < 2 {
73+
return None;
74+
}
75+
76+
// Temporarily remove the length marker while we're permuting class members
77+
self.class_from_member[self.member_from_class[class]] = class;
78+
79+
let start = class.id_index();
80+
let end = start + len;
81+
82+
let mut left = start;
83+
let mut right = end - 1;
84+
85+
let split_at = loop {
86+
while left <= right && predicate(self.member_from_class.values()[left]) {
87+
left += 1
88+
}
89+
while left < right && !predicate(self.member_from_class.values()[right]) {
90+
right -= 1
91+
}
92+
if left >= right {
93+
break left;
94+
}
95+
self.member_from_class.values_mut().swap(left, right);
96+
};
97+
98+
let end_marker = ClassId::from_id_index(end - 1);
99+
100+
if split_at == start || split_at == end {
101+
self.class_from_member[self.member_from_class[class]] = end_marker;
102+
return None;
103+
}
104+
105+
let new_end_marker = ClassId::from_id_index(split_at - 1);
106+
let new_class = ClassId::from_id_index(split_at);
107+
108+
self.class_from_member[self.member_from_class[class]] = new_end_marker;
109+
110+
self.class_from_member[self.member_from_class.values()[split_at]] = end_marker;
111+
for i in split_at + 1..end {
112+
self.class_from_member[self.member_from_class.values()[i]] = new_class;
113+
}
114+
115+
Some(new_class)
116+
}
117+
118+
pub fn permute_class<R>(
119+
&mut self,
120+
class: ClassId,
121+
permute: impl FnOnce(&mut [MemberId]) -> R,
122+
) -> R {
123+
let len = self.class_len(class);
124+
if len == 0 {
125+
return permute(&mut []);
126+
}
127+
128+
let start = class.id_index();
129+
let end = start + len;
130+
131+
let end_marker = ClassId::from_id_index(end - 1);
132+
133+
// Temporarily remove the length marker while we're permuting class members
134+
self.class_from_member[self.member_from_class[class]] = class;
135+
136+
let result = permute(&mut self.member_from_class.values_mut()[start..end]);
137+
138+
self.class_from_member[self.member_from_class[class]] = end_marker;
139+
140+
result
141+
}
142+
143+
pub fn split_at(&mut self, class: ClassId, keep: usize) -> Option<ClassId> {
144+
if keep == 0 {
145+
return None;
146+
}
147+
let len = self.class_len(class);
148+
if keep >= len {
149+
return None;
150+
}
151+
152+
let start = class.id_index();
153+
let end = start + len;
154+
155+
let split_at = start + keep;
156+
157+
let end_marker = ClassId::from_id_index(end - 1);
158+
let new_end_marker = ClassId::from_id_index(split_at - 1);
159+
let new_class = ClassId::from_id_index(split_at);
160+
161+
self.class_from_member[self.member_from_class[class]] = new_end_marker;
162+
163+
self.class_from_member[self.member_from_class.values()[split_at]] = end_marker;
164+
for i in split_at + 1..end {
165+
self.class_from_member[self.member_from_class.values()[i]] = new_class;
166+
}
167+
168+
Some(new_class)
169+
}
170+
171+
pub fn multiway_split_by(
172+
&mut self,
173+
class: ClassId,
174+
mut split_between: impl FnMut(MemberId, MemberId) -> bool,
175+
mut new_class: impl FnMut(&mut Self, ClassId),
176+
) {
177+
let mut len = self.class_len(class);
178+
let start = class.id_index();
179+
180+
while len >= 2 {
181+
let mut pos = len - 1;
182+
183+
loop {
184+
if split_between(
185+
self.member_from_class.values()[start + pos - 1],
186+
self.member_from_class.values()[start + pos],
187+
) {
188+
let new = self.split_at(class, pos).unwrap();
189+
new_class(self, new);
190+
len = pos;
191+
break;
192+
}
193+
pos -= 1;
194+
if pos == 0 {
195+
return;
196+
}
197+
}
198+
}
199+
}
200+
201+
pub fn multiway_unstable_sort_and_split_by(
202+
&mut self,
203+
class: ClassId,
204+
mut compare: impl FnMut(MemberId, MemberId) -> Ordering,
205+
new_class: impl FnMut(&mut Self, ClassId),
206+
) {
207+
self.permute_class(class, |members| {
208+
members.sort_unstable_by(|&a, &b| compare(a, b));
209+
});
210+
self.multiway_split_by(class, |a, b| compare(a, b).is_ne(), new_class)
211+
}
212+
213+
pub fn members(&mut self) -> &[MemberId] {
214+
self.member_from_class.values()
215+
}
216+
}
217+
218+
#[cfg(test)]
219+
mod tests {
220+
221+
use std::mem::swap;
222+
223+
use imctk_ids::id_index_set::IdIndexSet;
224+
225+
use super::*;
226+
227+
#[test]
228+
pub fn test_partition_class() {
229+
// This isn't a good way to compute primes, but not a bad way to test the partition
230+
// refinement data structure.
231+
let mut partition = PartitionRefinement::new(100);
232+
233+
let mut current_class = ClassId::MIN_ID;
234+
235+
let mut primes = vec![];
236+
237+
current_class = partition
238+
.partition_class(current_class, |member| member.id_index() < 2)
239+
.expect("split class");
240+
241+
assert!(partition
242+
.members_in_class(current_class)
243+
.iter()
244+
.all(|&member| member.id_index() >= 2));
245+
246+
for _ in 0..40 {
247+
let target = *partition
248+
.members_in_class(current_class)
249+
.iter()
250+
.min()
251+
.expect("non-empty class");
252+
253+
primes.push(target.id_index());
254+
255+
let Some(refined_class) = partition.partition_class(current_class, |member| {
256+
member.id_index() % target.id_index() == 0
257+
}) else {
258+
break;
259+
};
260+
261+
current_class = refined_class
262+
}
263+
264+
assert_eq!(
265+
primes,
266+
[
267+
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79,
268+
83, 89, 97
269+
]
270+
);
271+
}
272+
273+
#[test]
274+
pub fn test_multiway_partition_class() {
275+
let mut strings: IdIndexSet<MemberId, &'static str> = Default::default();
276+
let common = [
277+
"the", "be", "to", "of", "and", "a", "in", "that", "have", "I", "it", "for", "not",
278+
"on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
279+
"they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
280+
"there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
281+
"go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know", "take",
282+
"people", "into", "year", "your", "good", "some", "could", "them", "see", "other",
283+
"than", "then", "now", "look", "only", "come", "its", "over", "think", "also", "back",
284+
"after", "use", "two", "how", "our", "work", "first", "well", "way", "even", "new",
285+
"want", "because", "any", "these", "give", "day", "most", "us",
286+
];
287+
288+
for s in common {
289+
strings.insert(s);
290+
}
291+
292+
let mut partition = PartitionRefinement::new(common.len());
293+
294+
let mut classes = vec![ClassId::MIN_ID];
295+
let mut new_classes = vec![];
296+
297+
let mut offset = 0;
298+
299+
while !classes.is_empty() {
300+
for class in classes.drain(..) {
301+
partition.multiway_unstable_sort_and_split_by(
302+
class,
303+
|a, b| {
304+
strings[a]
305+
.as_bytes()
306+
.get(offset)
307+
.cmp(&strings[b].as_bytes().get(offset))
308+
},
309+
|partition, new_class| {
310+
if partition.class_len(new_class) >= 2 {
311+
new_classes.push(new_class);
312+
}
313+
},
314+
);
315+
if partition.class_len(class) >= 2 {
316+
new_classes.push(class);
317+
}
318+
}
319+
320+
offset += 1;
321+
322+
swap(&mut classes, &mut new_classes);
323+
}
324+
325+
for pair in partition.members().windows(2) {
326+
let [a, b] = *pair else { unreachable!() };
327+
assert!(strings[a] < strings[b]);
328+
}
329+
}
330+
}

0 commit comments

Comments
 (0)