|
1 | 1 | use core::{
|
2 | 2 | convert::AsMut,
|
3 | 3 | default::Default,
|
4 |
| - num::{NonZeroU128, NonZeroU32, NonZeroUsize}, |
| 4 | + num::{NonZeroU128, NonZeroU32, NonZeroU64, NonZeroUsize}, |
5 | 5 | ptr::copy_nonoverlapping,
|
6 | 6 | };
|
7 | 7 |
|
@@ -41,6 +41,7 @@ pub trait SeedableRng<M: MathsCore>: RngCore<M> {
|
41 | 41 | const INC: u64 = 11_634_580_027_462_260_723_u64;
|
42 | 42 |
|
43 | 43 | let mut seed = Self::Seed::default();
|
| 44 | + |
44 | 45 | for chunk in seed.as_mut().chunks_mut(4) {
|
45 | 46 | // We advance the state first (to get away from the input value,
|
46 | 47 | // in case it has low Hamming Weight).
|
@@ -96,51 +97,112 @@ pub trait RngSampler<M: MathsCore>: RngCore<M> {
|
96 | 97 | #[inline]
|
97 | 98 | #[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
|
98 | 99 | fn sample_index(&mut self, length: NonZeroUsize) -> usize {
|
99 |
| - // attributes on expressions are experimental |
100 |
| - // see https://github.com/rust-lang/rust/issues/15701 |
101 |
| - #[allow( |
102 |
| - clippy::cast_precision_loss, |
103 |
| - clippy::cast_possible_truncation, |
104 |
| - clippy::cast_sign_loss |
105 |
| - )] |
106 |
| - let index = |
107 |
| - M::floor(self.sample_uniform_closed_open().get() * (length.get() as f64)) as usize; |
108 |
| - // Safety in case of f64 rounding errors |
109 |
| - index.min(length.get() - 1) |
| 100 | + #[cfg(target_pointer_width = "32")] |
| 101 | + #[allow(clippy::cast_possible_truncation)] |
| 102 | + { |
| 103 | + self.sample_index_u32(unsafe { NonZeroU32::new_unchecked(length.get() as u32) }) |
| 104 | + as usize |
| 105 | + } |
| 106 | + #[cfg(target_pointer_width = "64")] |
| 107 | + #[allow(clippy::cast_possible_truncation)] |
| 108 | + { |
| 109 | + self.sample_index_u64(unsafe { NonZeroU64::new_unchecked(length.get() as u64) }) |
| 110 | + as usize |
| 111 | + } |
110 | 112 | }
|
111 | 113 |
|
112 | 114 | #[must_use]
|
113 | 115 | #[inline]
|
114 | 116 | #[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
|
115 | 117 | fn sample_index_u32(&mut self, length: NonZeroU32) -> u32 {
|
116 |
| - // attributes on expressions are experimental |
117 |
| - // see https://github.com/rust-lang/rust/issues/15701 |
118 |
| - #[allow( |
119 |
| - clippy::cast_precision_loss, |
120 |
| - clippy::cast_possible_truncation, |
121 |
| - clippy::cast_sign_loss |
122 |
| - )] |
123 |
| - let index = |
124 |
| - M::floor(self.sample_uniform_closed_open().get() * f64::from(length.get())) as u32; |
125 |
| - // Safety in case of f64 rounding errors |
126 |
| - index.min(length.get() - 1) |
| 118 | + // TODO: Check if delegation to `sample_index_u64` is faster |
| 119 | + |
| 120 | + // Adapted from: |
| 121 | + // https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single |
| 122 | + |
| 123 | + const LOWER_MASK: u64 = !0 >> 32; |
| 124 | + |
| 125 | + // Conservative approximation of the acceptance zone |
| 126 | + let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1); |
| 127 | + |
| 128 | + loop { |
| 129 | + let raw = self.sample_u64(); |
| 130 | + |
| 131 | + let sample_check_lo = (raw & LOWER_MASK) * u64::from(length.get()); |
| 132 | + |
| 133 | + #[allow(clippy::cast_possible_truncation)] |
| 134 | + if (sample_check_lo as u32) <= acceptance_zone { |
| 135 | + return (sample_check_lo >> 32) as u32; |
| 136 | + } |
| 137 | + |
| 138 | + let sample_check_hi = (raw >> 32) * u64::from(length.get()); |
| 139 | + |
| 140 | + #[allow(clippy::cast_possible_truncation)] |
| 141 | + if (sample_check_hi as u32) <= acceptance_zone { |
| 142 | + return (sample_check_hi >> 32) as u32; |
| 143 | + } |
| 144 | + } |
| 145 | + } |
| 146 | + |
| 147 | + #[must_use] |
| 148 | + #[inline] |
| 149 | + #[debug_ensures(ret < length.get(), "samples U(0, length - 1)")] |
| 150 | + fn sample_index_u64(&mut self, length: NonZeroU64) -> u64 { |
| 151 | + // Adapted from: |
| 152 | + // https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single |
| 153 | + |
| 154 | + // Conservative approximation of the acceptance zone |
| 155 | + let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1); |
| 156 | + |
| 157 | + loop { |
| 158 | + let raw = self.sample_u64(); |
| 159 | + |
| 160 | + let sample_check = u128::from(raw) * u128::from(length.get()); |
| 161 | + |
| 162 | + #[allow(clippy::cast_possible_truncation)] |
| 163 | + if (sample_check as u64) <= acceptance_zone { |
| 164 | + return (sample_check >> 64) as u64; |
| 165 | + } |
| 166 | + } |
127 | 167 | }
|
128 | 168 |
|
129 | 169 | #[must_use]
|
130 | 170 | #[inline]
|
131 | 171 | #[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
|
132 | 172 | fn sample_index_u128(&mut self, length: NonZeroU128) -> u128 {
|
133 |
| - // attributes on expressions are experimental |
134 |
| - // see https://github.com/rust-lang/rust/issues/15701 |
135 |
| - #[allow( |
136 |
| - clippy::cast_precision_loss, |
137 |
| - clippy::cast_possible_truncation, |
138 |
| - clippy::cast_sign_loss |
139 |
| - )] |
140 |
| - let index = |
141 |
| - M::floor(self.sample_uniform_closed_open().get() * (length.get() as f64)) as u128; |
142 |
| - // Safety in case of f64 rounding errors |
143 |
| - index.min(length.get() - 1) |
| 173 | + // Adapted from: |
| 174 | + // https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single |
| 175 | + |
| 176 | + const LOWER_MASK: u128 = !0 >> 64; |
| 177 | + |
| 178 | + // Conservative approximation of the acceptance zone |
| 179 | + let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1); |
| 180 | + |
| 181 | + loop { |
| 182 | + let raw_hi = u128::from(self.sample_u64()); |
| 183 | + let raw_lo = u128::from(self.sample_u64()); |
| 184 | + |
| 185 | + // 256-bit multiplication (hi, lo) = (raw_hi, raw_lo) * length |
| 186 | + let mut low = raw_lo * (length.get() & LOWER_MASK); |
| 187 | + let mut t = low >> 64; |
| 188 | + low &= LOWER_MASK; |
| 189 | + t += raw_hi * (length.get() & LOWER_MASK); |
| 190 | + low += (t & LOWER_MASK) << 64; |
| 191 | + let mut high = t >> 64; |
| 192 | + t = low >> 64; |
| 193 | + low &= LOWER_MASK; |
| 194 | + t += (length.get() >> 64) * raw_lo; |
| 195 | + low += (t & LOWER_MASK) << 64; |
| 196 | + high += t >> 64; |
| 197 | + high += raw_hi * (length.get() >> 64); |
| 198 | + |
| 199 | + let sample = high; |
| 200 | + let check = low; |
| 201 | + |
| 202 | + if check <= acceptance_zone { |
| 203 | + return sample; |
| 204 | + } |
| 205 | + } |
144 | 206 | }
|
145 | 207 |
|
146 | 208 | #[must_use]
|
|
0 commit comments