Skip to content

Commit 901b630

Browse files
authored
ReferenceSampleTree Random Access (#983)
* Adding `operator[]` to `ReferenceSampleTree`. * Allows quick access to any reference sample bit within the tree at a given absolute index. * Updating `src/stim/util_top/reference_sample_tree.test.cc` to check bit positions, as applicable.
1 parent e59ffcc commit 901b630

File tree

3 files changed

+144
-11
lines changed

3 files changed

+144
-11
lines changed

src/stim/util_top/reference_sample_tree.cc

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
#include "stim/util_top/reference_sample_tree.h"
22

3+
#if defined(_WIN32)
4+
#include <intrin.h>
5+
#pragma intrinsic(_umul128)
6+
#endif
7+
38
using namespace stim;
49

510
bool ReferenceSampleTree::empty() const {
@@ -149,12 +154,23 @@ ReferenceSampleTree ReferenceSampleTree::simplified() const {
149154
return result;
150155
}
151156

152-
size_t ReferenceSampleTree::size() const {
153-
size_t result = prefix_bits.size();
157+
uint64_t ReferenceSampleTree::size() const {
158+
uint64_t result = prefix_bits.size();
154159
for (const auto &child : suffix_children) {
155160
result += child.size();
156161
}
157-
return result * repetitions;
162+
#if defined(__GNUC__) || defined(__clang__)
163+
bool overflow = __builtin_mul_overflow(result, repetitions, &result);
164+
assert(!overflow);
165+
#elif defined(_WIN64)
166+
uint64_t overflow;
167+
result = _umul128(result, repetitions, &overflow);
168+
assert(overflow == 0);
169+
#else
170+
assert((repetitions == 0) || (result <= (UINT64_MAX / repetitions)));
171+
result *= repetitions;
172+
#endif
173+
return result;
158174
}
159175

160176
void ReferenceSampleTree::decompress_into(std::vector<bool> &output) const {
@@ -189,6 +205,57 @@ bool ReferenceSampleTree::operator!=(const ReferenceSampleTree &other) const {
189205
return !(*this == other);
190206
}
191207

208+
bool ReferenceSampleTree::operator[](uint64_t index) const {
209+
uint64_t current_absolute_index = 0;
210+
bool result;
211+
bool value_found = try_get_bit_value(index, current_absolute_index, result);
212+
assert(value_found);
213+
return result;
214+
}
215+
216+
bool ReferenceSampleTree::try_get_bit_value(
217+
uint64_t desired_absolute_index, uint64_t &current_absolute_index, bool &bit_value) const {
218+
// Run through once to allow shallow accesses (without unnecessary full iteration of the tree).
219+
const uint64_t current_relative_starting_index = current_absolute_index;
220+
{
221+
const uint64_t relative_index = desired_absolute_index - current_absolute_index;
222+
if (relative_index < prefix_bits.size()) {
223+
bit_value = prefix_bits[relative_index];
224+
return true;
225+
}
226+
current_absolute_index += prefix_bits.size();
227+
for (const ReferenceSampleTree &child : suffix_children) {
228+
if (child.try_get_bit_value(desired_absolute_index, current_absolute_index, bit_value)) {
229+
return true;
230+
}
231+
}
232+
}
233+
// After the first full iteration, extrapolate the repetition size and skip to the proper iteration.
234+
const uint64_t single_iteration_size = current_absolute_index - current_relative_starting_index;
235+
const uint64_t skip_to_iteration_count =
236+
(desired_absolute_index - current_relative_starting_index) / single_iteration_size;
237+
if (skip_to_iteration_count < repetitions) {
238+
// If the desired index is in this part of the tree, skip forward to the appropriate iteration.
239+
current_absolute_index += single_iteration_size * (skip_to_iteration_count - 1);
240+
// Do the final iteration to find the value.
241+
const uint64_t relative_index = desired_absolute_index - current_absolute_index;
242+
if (relative_index < prefix_bits.size()) {
243+
bit_value = prefix_bits[relative_index];
244+
return true;
245+
}
246+
current_absolute_index += prefix_bits.size();
247+
for (const ReferenceSampleTree &child : suffix_children) {
248+
if (child.try_get_bit_value(desired_absolute_index, current_absolute_index, bit_value)) {
249+
return true;
250+
}
251+
}
252+
} else {
253+
// Advance past this node for parent to continue.
254+
current_absolute_index += single_iteration_size * (repetitions - 1);
255+
}
256+
return false;
257+
}
258+
192259
std::ostream &stim::operator<<(std::ostream &out, const ReferenceSampleTree &v) {
193260
out << v.repetitions << "*";
194261
out << "(";

src/stim/util_top/reference_sample_tree.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@ struct ReferenceSampleTree {
2424
bool operator==(const ReferenceSampleTree &other) const;
2525
/// Checks if two trees are not exactly the same, including structure (not just uncompressed contents).
2626
bool operator!=(const ReferenceSampleTree &other) const;
27+
/// Returns the bit value for a given absolute index.
28+
bool operator[](uint64_t index) const;
2729
/// Returns a simple description of the tree's structure, like "5*('101'+6*('11'))".
2830
std::string str() const;
2931

3032
/// Determines whether the tree contains any bits at all.
3133
bool empty() const;
3234
/// Computes the total size of the uncompressed bits represented by the tree.
33-
size_t size() const;
35+
uint64_t size() const;
3436

3537
/// Writes the contents of the tree into the given output vector.
3638
void decompress_into(std::vector<bool> &output) const;
@@ -45,6 +47,8 @@ struct ReferenceSampleTree {
4547
private:
4648
/// Helper method for `simplified`.
4749
void flatten_and_simplify_into(std::vector<ReferenceSampleTree> &out) const;
50+
/// Helper method for `operator[]`.
51+
bool try_get_bit_value(uint64_t desired_absolute_index, uint64_t & current_absolute_index, bool & bit_value) const;
4852
};
4953
std::ostream &operator<<(std::ostream &out, const ReferenceSampleTree &v);
5054

src/stim/util_top/reference_sample_tree.test.cc

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ void expect_tree_matches_normal_reference_sample_of(const ReferenceSampleTree &t
1515
}
1616
auto expected = TableauSimulator<MAX_BITWORD_WIDTH>::reference_sample_circuit(circuit);
1717
EXPECT_EQ(actual, expected);
18+
for (size_t index = 0; index < decompressed.size(); ++index) {
19+
ASSERT_EQ(tree[index], decompressed[index]) << "index: " << index;
20+
}
1821
}
1922

2023
TEST(ReferenceSampleTree, equality) {
@@ -102,20 +105,24 @@ TEST(ReferenceSampleTree, simplified) {
102105

103106
TEST(ReferenceSampleTree, decompress_into) {
104107
std::vector<bool> result;
105-
ReferenceSampleTree{
108+
ReferenceSampleTree tree_under_test{
106109
.prefix_bits = {1, 1, 0, 1},
107110
.suffix_children = {ReferenceSampleTree{
108111
.prefix_bits = {1},
109112
.suffix_children = {},
110113
.repetitions = 5,
111114
}},
112115
.repetitions = 2,
116+
};
117+
tree_under_test.decompress_into(result);
118+
std::vector<bool> expected = std::vector<bool>{1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1};
119+
ASSERT_EQ(result, expected);
120+
for (size_t index = 0; index < expected.size(); ++index) {
121+
ASSERT_EQ(tree_under_test[index], expected[index]) << "index: " << index;
113122
}
114-
.decompress_into(result);
115-
ASSERT_EQ(result, (std::vector<bool>{1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1}));
116123

117124
result.clear();
118-
ReferenceSampleTree{
125+
tree_under_test = ReferenceSampleTree{
119126
.prefix_bits = {1, 1, 0, 1},
120127
.suffix_children =
121128
{
@@ -131,10 +138,14 @@ TEST(ReferenceSampleTree, decompress_into) {
131138
},
132139
},
133140
.repetitions = 1,
141+
};
142+
tree_under_test.decompress_into(result);
143+
expected =
144+
std::vector<bool>{1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0};
145+
ASSERT_EQ(result, expected);
146+
for (size_t index = 0; index < expected.size(); ++index) {
147+
ASSERT_EQ(tree_under_test[index], expected[index]) << "index: " << index;
134148
}
135-
.decompress_into(result);
136-
ASSERT_EQ(result, (std::vector<bool>{1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0,
137-
1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0}));
138149
}
139150

140151
TEST(ReferenceSampleTree, simple_circuit) {
@@ -292,3 +303,54 @@ TEST(ReferenceSampleTree, surface_code_with_pauli_vs_normal_reference_sample) {
292303
ASSERT_EQ(ref.size(), circuit.count_measurements());
293304
expect_tree_matches_normal_reference_sample_of(ref, circuit);
294305
}
306+
307+
TEST(ReferenceSampleTree, random_access_large_tree) {
308+
ReferenceSampleTree tree_under_test{
309+
.prefix_bits = {1, 1, 0, 1},
310+
.suffix_children =
311+
{ReferenceSampleTree{
312+
.prefix_bits = {1, 0, 1},
313+
.suffix_children = {},
314+
.repetitions = 60'000'000,
315+
},
316+
ReferenceSampleTree{
317+
.prefix_bits = {0, 0, 0, 0, 0, 0, 1},
318+
.suffix_children = {ReferenceSampleTree{
319+
.prefix_bits = {1, 1, 1, 0, 0, 1},
320+
.suffix_children = {},
321+
.repetitions = 42,
322+
}},
323+
.repetitions = 2'000'000'000,
324+
},
325+
ReferenceSampleTree{
326+
.prefix_bits = {0, 0, 0, 0, 1},
327+
.suffix_children = {},
328+
.repetitions = 999'000'000,
329+
}},
330+
.repetitions = 1'234'000,
331+
};
332+
333+
std::vector<bool> expected_beginning = std::vector<bool>{1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1};
334+
for (size_t index = 0; index < expected_beginning.size(); ++index) {
335+
ASSERT_EQ(tree_under_test[index], expected_beginning[index]) << "index: " << index;
336+
}
337+
338+
uint64_t whole_tree_size = tree_under_test.size();
339+
ASSERT_EQ(
340+
whole_tree_size,
341+
1'234'000ULL * (4 + (60'000'000ULL * 3) + (2'000'000'000ULL * (7 + (42 * 6))) + (999'000'000ULL * 5)));
342+
343+
std::vector<bool> expected_ending =
344+
std::vector<bool>{0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1};
345+
for (uint64_t index = 0; index < expected_ending.size(); ++index) {
346+
ASSERT_EQ(tree_under_test[whole_tree_size - expected_ending.size() + index], expected_ending[index])
347+
<< "index: " << index;
348+
}
349+
// Same thing on previous outer iteration.
350+
uint64_t one_outer_iter_before_ending = whole_tree_size - (whole_tree_size / 1'234'000ULL);
351+
for (uint64_t index = 0; index < expected_ending.size(); ++index) {
352+
ASSERT_EQ(
353+
tree_under_test[one_outer_iter_before_ending - expected_ending.size() + index], expected_ending[index])
354+
<< "index: " << index;
355+
}
356+
}

0 commit comments

Comments
 (0)