Skip to content

Commit cbb807d

Browse files
committed
(ml5717) Untested impl of more robust index sampling
1 parent a2aef7a commit cbb807d

File tree

3 files changed

+104
-36
lines changed
  • necsim
    • core/src/cogs
    • impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic

3 files changed

+104
-36
lines changed

necsim/core/src/cogs/rng.rs

+96-34
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use core::{
22
convert::AsMut,
33
default::Default,
4-
num::{NonZeroU128, NonZeroU32, NonZeroUsize},
4+
num::{NonZeroU128, NonZeroU32, NonZeroU64, NonZeroUsize},
55
ptr::copy_nonoverlapping,
66
};
77

@@ -41,6 +41,7 @@ pub trait SeedableRng<M: MathsCore>: RngCore<M> {
4141
const INC: u64 = 11_634_580_027_462_260_723_u64;
4242

4343
let mut seed = Self::Seed::default();
44+
4445
for chunk in seed.as_mut().chunks_mut(4) {
4546
// We advance the state first (to get away from the input value,
4647
// in case it has low Hamming Weight).
@@ -96,51 +97,112 @@ pub trait RngSampler<M: MathsCore>: RngCore<M> {
9697
#[inline]
9798
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
9899
fn sample_index(&mut self, length: NonZeroUsize) -> usize {
99-
// attributes on expressions are experimental
100-
// see https://github.com/rust-lang/rust/issues/15701
101-
#[allow(
102-
clippy::cast_precision_loss,
103-
clippy::cast_possible_truncation,
104-
clippy::cast_sign_loss
105-
)]
106-
let index =
107-
M::floor(self.sample_uniform_closed_open().get() * (length.get() as f64)) as usize;
108-
// Safety in case of f64 rounding errors
109-
index.min(length.get() - 1)
100+
#[cfg(target_pointer_width = "32")]
101+
#[allow(clippy::cast_possible_truncation)]
102+
{
103+
self.sample_index_u32(unsafe { NonZeroU32::new_unchecked(length.get() as u32) })
104+
as usize
105+
}
106+
#[cfg(target_pointer_width = "64")]
107+
#[allow(clippy::cast_possible_truncation)]
108+
{
109+
self.sample_index_u64(unsafe { NonZeroU64::new_unchecked(length.get() as u64) })
110+
as usize
111+
}
110112
}
111113

112114
#[must_use]
113115
#[inline]
114116
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
115117
fn sample_index_u32(&mut self, length: NonZeroU32) -> u32 {
116-
// attributes on expressions are experimental
117-
// see https://github.com/rust-lang/rust/issues/15701
118-
#[allow(
119-
clippy::cast_precision_loss,
120-
clippy::cast_possible_truncation,
121-
clippy::cast_sign_loss
122-
)]
123-
let index =
124-
M::floor(self.sample_uniform_closed_open().get() * f64::from(length.get())) as u32;
125-
// Safety in case of f64 rounding errors
126-
index.min(length.get() - 1)
118+
// TODO: Check if delegation to `sample_index_u64` is faster
119+
120+
// Adapted from:
121+
// https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single
122+
123+
const LOWER_MASK: u64 = !0 >> 32;
124+
125+
// Conservative approximation of the acceptance zone
126+
let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);
127+
128+
loop {
129+
let raw = self.sample_u64();
130+
131+
let sample_check_lo = (raw & LOWER_MASK) * u64::from(length.get());
132+
133+
#[allow(clippy::cast_possible_truncation)]
134+
if (sample_check_lo as u32) <= acceptance_zone {
135+
return (sample_check_lo >> 32) as u32;
136+
}
137+
138+
let sample_check_hi = (raw >> 32) * u64::from(length.get());
139+
140+
#[allow(clippy::cast_possible_truncation)]
141+
if (sample_check_hi as u32) <= acceptance_zone {
142+
return (sample_check_hi >> 32) as u32;
143+
}
144+
}
145+
}
146+
147+
#[must_use]
148+
#[inline]
149+
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
150+
fn sample_index_u64(&mut self, length: NonZeroU64) -> u64 {
151+
// Adapted from:
152+
// https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single
153+
154+
// Conservative approximation of the acceptance zone
155+
let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);
156+
157+
loop {
158+
let raw = self.sample_u64();
159+
160+
let sample_check = u128::from(raw) * u128::from(length.get());
161+
162+
#[allow(clippy::cast_possible_truncation)]
163+
if (sample_check as u64) <= acceptance_zone {
164+
return (sample_check >> 64) as u64;
165+
}
166+
}
127167
}
128168

129169
#[must_use]
130170
#[inline]
131171
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
132172
fn sample_index_u128(&mut self, length: NonZeroU128) -> u128 {
133-
// attributes on expressions are experimental
134-
// see https://github.com/rust-lang/rust/issues/15701
135-
#[allow(
136-
clippy::cast_precision_loss,
137-
clippy::cast_possible_truncation,
138-
clippy::cast_sign_loss
139-
)]
140-
let index =
141-
M::floor(self.sample_uniform_closed_open().get() * (length.get() as f64)) as u128;
142-
// Safety in case of f64 rounding errors
143-
index.min(length.get() - 1)
173+
// Adapted from:
174+
// https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single
175+
176+
const LOWER_MASK: u128 = !0 >> 64;
177+
178+
// Conservative approximation of the acceptance zone
179+
let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);
180+
181+
loop {
182+
let raw_hi = u128::from(self.sample_u64());
183+
let raw_lo = u128::from(self.sample_u64());
184+
185+
// 256-bit multiplication (hi, lo) = (raw_hi, raw_lo) * length
186+
let mut low = raw_lo * (length.get() & LOWER_MASK);
187+
let mut t = low >> 64;
188+
low &= LOWER_MASK;
189+
t += raw_hi * (length.get() & LOWER_MASK);
190+
low += (t & LOWER_MASK) << 64;
191+
let mut high = t >> 64;
192+
t = low >> 64;
193+
low &= LOWER_MASK;
194+
t += (length.get() >> 64) * raw_lo;
195+
low += (t & LOWER_MASK) << 64;
196+
high += t >> 64;
197+
high += raw_hi * (length.get() >> 64);
198+
199+
let sample = high;
200+
let check = low;
201+
202+
if check <= acceptance_zone {
203+
return sample;
204+
}
205+
}
144206
}
145207

146208
#[must_use]

necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/indexed/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use alloc::{vec, vec::Vec};
22
use core::{
33
cmp::Ordering,
4+
convert::TryFrom,
45
fmt,
56
hash::Hash,
6-
num::{NonZeroU128, NonZeroUsize},
7+
num::{NonZeroU128, NonZeroU64, NonZeroUsize},
78
};
89
use fnv::FnvBuildHasher;
910

@@ -191,6 +192,8 @@ impl<E: Eq + Hash + Clone> DynamicAliasMethodIndexedSampler<E> {
191192
if let Some(total_weight) = NonZeroU128::new(self.total_weight) {
192193
let cdf_sample = if let [_group] = &self.groups[..] {
193194
0_u128
195+
} else if let Ok(total_weight) = NonZeroU64::try_from(total_weight) {
196+
u128::from(rng.sample_index_u64(total_weight))
194197
} else {
195198
rng.sample_index_u128(total_weight)
196199
};

necsim/impls/no-std/src/cogs/active_lineage_sampler/alias/dynamic/stack/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use alloc::{vec, vec::Vec};
22
use core::{
33
cmp::Ordering,
4+
convert::TryFrom,
45
fmt,
56
hash::Hash,
6-
num::{NonZeroU128, NonZeroUsize},
7+
num::{NonZeroU128, NonZeroU64, NonZeroUsize},
78
};
89

910
use necsim_core::cogs::{MathsCore, RngCore, RngSampler};
@@ -125,6 +126,8 @@ impl<E: Eq + Hash + Clone> DynamicAliasMethodStackSampler<E> {
125126
if let Some(total_weight) = NonZeroU128::new(self.total_weight) {
126127
let cdf_sample = if let [_group] = &self.groups[..] {
127128
0_u128
129+
} else if let Ok(total_weight) = NonZeroU64::try_from(total_weight) {
130+
u128::from(rng.sample_index_u64(total_weight))
128131
} else {
129132
rng.sample_index_u128(total_weight)
130133
};

0 commit comments

Comments
 (0)