|
1 | 1 | use crate::{
|
2 | 2 | serialization::{deserialize_long_array, serialize_long_array},
|
3 |
| - utils::{less_than_or_equal_to_unsafe, range_check_optimized, Endianness, PackerTarget}, |
| 3 | + utils::{ |
| 4 | + less_than_or_equal_to_unsafe, less_than_unsafe, range_check_optimized, Endianness, |
| 5 | + PackerTarget, |
| 6 | + }, |
4 | 7 | };
|
5 | 8 | use anyhow::{anyhow, Result};
|
6 | 9 | use plonky2::{
|
@@ -600,6 +603,91 @@ where
|
600 | 603 | pub fn last(&self) -> T {
|
601 | 604 | self.arr[SIZE - 1]
|
602 | 605 | }
|
| 606 | + |
| 607 | + /// This function allows you to search a larger [`Array`] by representing it as a number of |
| 608 | + /// smaller [`Array`]s with size [`RANDOM_ACCESS_SIZE`], padding the final smaller array where required. |
| 609 | + pub fn random_access_large_array<F: RichField + Extendable<D>, const D: usize>( |
| 610 | + &self, |
| 611 | + b: &mut CircuitBuilder<F, D>, |
| 612 | + at: Target, |
| 613 | + ) -> T { |
| 614 | + // We will split the array into smaller arrays of size 64, padding the last array with zeroes if required |
| 615 | + let padded_size = (SIZE - 1) / RANDOM_ACCESS_SIZE + 1; |
| 616 | + |
| 617 | + // Create an array of `Array`s |
| 618 | + let arrays: Vec<Array<T, RANDOM_ACCESS_SIZE>> = (0..padded_size) |
| 619 | + .map(|i| Array { |
| 620 | + arr: create_array(|j| { |
| 621 | + let index = 64 * i + j; |
| 622 | + if index < self.arr.len() { |
| 623 | + self.arr[index] |
| 624 | + } else { |
| 625 | + T::from_target(b.zero()) |
| 626 | + } |
| 627 | + }), |
| 628 | + }) |
| 629 | + .collect(); |
| 630 | + |
| 631 | + // We need to express `at` in base 64, we are also assuming that the initial array was smaller than 64^2 = 4096 which we enforce with a range check. |
| 632 | + // We also check that `at` is smaller that the size of the array. |
| 633 | + let array_size = b.constant(F::from_noncanonical_u64(SIZE as u64)); |
| 634 | + let less_than_check = less_than_unsafe(b, at, array_size, 12); |
| 635 | + let true_target = b._true(); |
| 636 | + b.connect(less_than_check.target, true_target.target); |
| 637 | + b.range_check(at, 12); |
| 638 | + let (low_bits, high_bits) = b.split_low_high(at, 6, 12); |
| 639 | + |
| 640 | + // Search each of the smaller arrays for the target at `low_bits` |
| 641 | + let first_search = arrays |
| 642 | + .into_iter() |
| 643 | + .map(|array| { |
| 644 | + b.random_access( |
| 645 | + low_bits, |
| 646 | + array |
| 647 | + .arr |
| 648 | + .iter() |
| 649 | + .map(Targetable::to_target) |
| 650 | + .collect::<Vec<Target>>(), |
| 651 | + ) |
| 652 | + }) |
| 653 | + .collect::<Vec<Target>>(); |
| 654 | + |
| 655 | + // Serach the result for the Target at `high_bits` |
| 656 | + T::from_target(b.random_access(high_bits, first_search)) |
| 657 | + } |
| 658 | + |
| 659 | + /// Returns [`self[at..at+SUB_SIZE]`]. |
| 660 | + /// This is more expensive than [`Self::extract_array`] due to using [`Self::random_access_large_array`] |
| 661 | + /// instead of [`Self::value_at`]. This function enforces that the values extracted are within the array. |
| 662 | + pub fn extract_array_large< |
| 663 | + F: RichField + Extendable<D>, |
| 664 | + const D: usize, |
| 665 | + const SUB_SIZE: usize, |
| 666 | + >( |
| 667 | + &self, |
| 668 | + b: &mut CircuitBuilder<F, D>, |
| 669 | + at: Target, |
| 670 | + ) -> Array<T, SUB_SIZE> { |
| 671 | + let m = b.constant(F::from_canonical_usize(SUB_SIZE)); |
| 672 | + let array_len = b.constant(F::from_canonical_usize(SIZE)); |
| 673 | + let upper_bound = b.add(at, m); |
| 674 | + let num_bits_size = SIZE.ilog2() + 1; |
| 675 | + |
| 676 | + let lt = less_than_or_equal_to_unsafe(b, upper_bound, array_len, num_bits_size as usize); |
| 677 | + |
| 678 | + let t = b._true(); |
| 679 | + b.connect(t.target, lt.target); |
| 680 | + |
| 681 | + Array::<T, SUB_SIZE> { |
| 682 | + arr: core::array::from_fn(|i| { |
| 683 | + let i_target = b.constant(F::from_canonical_usize(i)); |
| 684 | + let i_plus_n_target = b.add(at, i_target); |
| 685 | + |
| 686 | + // out_val = arr[((i+n)<=n+M) * (i+n)] |
| 687 | + self.random_access_large_array(b, i_plus_n_target) |
| 688 | + }), |
| 689 | + } |
| 690 | + } |
603 | 691 | }
|
604 | 692 | /// Returns the size of the array in 32-bit units, rounded up.
|
605 | 693 | #[allow(non_snake_case)]
|
@@ -815,6 +903,51 @@ mod test {
|
815 | 903 | run_circuit::<F, D, C, _>(ValueAtCircuit { arr, idx, exp });
|
816 | 904 | }
|
817 | 905 |
|
| 906 | + #[test] |
| 907 | + fn test_random_access_large_array() { |
| 908 | + const SIZE: usize = 512; |
| 909 | + #[derive(Clone, Debug)] |
| 910 | + struct ValueAtCircuit { |
| 911 | + arr: [u8; SIZE], |
| 912 | + idx: usize, |
| 913 | + exp: u8, |
| 914 | + } |
| 915 | + impl<F, const D: usize> UserCircuit<F, D> for ValueAtCircuit |
| 916 | + where |
| 917 | + F: RichField + Extendable<D>, |
| 918 | + { |
| 919 | + type Wires = (Array<Target, SIZE>, Target, Target); |
| 920 | + fn build(c: &mut CircuitBuilder<F, D>) -> Self::Wires { |
| 921 | + let array = Array::<Target, SIZE>::new(c); |
| 922 | + let exp_value = c.add_virtual_target(); |
| 923 | + let index = c.add_virtual_target(); |
| 924 | + let extracted = array.random_access_large_array(c, index); |
| 925 | + c.connect(exp_value, extracted); |
| 926 | + (array, index, exp_value) |
| 927 | + } |
| 928 | + fn prove(&self, pw: &mut PartialWitness<F>, wires: &Self::Wires) { |
| 929 | + wires |
| 930 | + .0 |
| 931 | + .assign(pw, &create_array(|i| F::from_canonical_u8(self.arr[i]))); |
| 932 | + pw.set_target(wires.1, F::from_canonical_usize(self.idx)); |
| 933 | + pw.set_target(wires.2, F::from_canonical_u8(self.exp)); |
| 934 | + } |
| 935 | + } |
| 936 | + let mut rng = thread_rng(); |
| 937 | + let mut arr = [0u8; SIZE]; |
| 938 | + rng.fill(&mut arr[..]); |
| 939 | + let idx: usize = rng.gen_range(0..SIZE); |
| 940 | + let exp = arr[idx]; |
| 941 | + run_circuit::<F, D, C, _>(ValueAtCircuit { arr, idx, exp }); |
| 942 | + |
| 943 | + // Now we check that it fails when the index is too large |
| 944 | + let idx = SIZE; |
| 945 | + let result = std::panic::catch_unwind(|| { |
| 946 | + run_circuit::<F, D, C, _>(ValueAtCircuit { arr, idx, exp }) |
| 947 | + }); |
| 948 | + assert!(result.is_err()); |
| 949 | + } |
| 950 | + |
818 | 951 | #[test]
|
819 | 952 | fn test_extract_array() {
|
820 | 953 | const SIZE: usize = 80;
|
@@ -858,6 +991,56 @@ mod test {
|
858 | 991 | run_circuit::<F, D, C, _>(ExtractArrayCircuit { arr, idx, exp });
|
859 | 992 | }
|
860 | 993 |
|
| 994 | + #[test] |
| 995 | + fn test_extract_array_large() { |
| 996 | + const SIZE: usize = 512; |
| 997 | + const SUBSIZE: usize = 40; |
| 998 | + #[derive(Clone, Debug)] |
| 999 | + struct ExtractArrayCircuit { |
| 1000 | + arr: [u8; SIZE], |
| 1001 | + idx: usize, |
| 1002 | + exp: [u8; SUBSIZE], |
| 1003 | + } |
| 1004 | + impl<F, const D: usize> UserCircuit<F, D> for ExtractArrayCircuit |
| 1005 | + where |
| 1006 | + F: RichField + Extendable<D>, |
| 1007 | + { |
| 1008 | + type Wires = (Array<Target, SIZE>, Target, Array<Target, SUBSIZE>); |
| 1009 | + fn build(c: &mut CircuitBuilder<F, D>) -> Self::Wires { |
| 1010 | + let array = Array::<Target, SIZE>::new(c); |
| 1011 | + let index = c.add_virtual_target(); |
| 1012 | + let expected = Array::<Target, SUBSIZE>::new(c); |
| 1013 | + let extracted = array.extract_array_large::<_, _, SUBSIZE>(c, index); |
| 1014 | + let are_equal = expected.equals(c, &extracted); |
| 1015 | + let tru = c._true(); |
| 1016 | + c.connect(are_equal.target, tru.target); |
| 1017 | + (array, index, expected) |
| 1018 | + } |
| 1019 | + fn prove(&self, pw: &mut PartialWitness<F>, wires: &Self::Wires) { |
| 1020 | + wires |
| 1021 | + .0 |
| 1022 | + .assign(pw, &create_array(|i| F::from_canonical_u8(self.arr[i]))); |
| 1023 | + pw.set_target(wires.1, F::from_canonical_usize(self.idx)); |
| 1024 | + wires |
| 1025 | + .2 |
| 1026 | + .assign(pw, &create_array(|i| F::from_canonical_u8(self.exp[i]))); |
| 1027 | + } |
| 1028 | + } |
| 1029 | + let mut rng = thread_rng(); |
| 1030 | + let mut arr = [0u8; SIZE]; |
| 1031 | + rng.fill(&mut arr[..]); |
| 1032 | + let idx: usize = rng.gen_range(0..(SIZE - SUBSIZE)); |
| 1033 | + let exp = create_array(|i| arr[idx + i]); |
| 1034 | + run_circuit::<F, D, C, _>(ExtractArrayCircuit { arr, idx, exp }); |
| 1035 | + |
| 1036 | + // It should panic if we try to extract an array where some of the indices fall outside of (0..SIZE) |
| 1037 | + let idx = SIZE; |
| 1038 | + let result = std::panic::catch_unwind(|| { |
| 1039 | + run_circuit::<F, D, C, _>(ExtractArrayCircuit { arr, idx, exp }) |
| 1040 | + }); |
| 1041 | + assert!(result.is_err()); |
| 1042 | + } |
| 1043 | + |
861 | 1044 | #[test]
|
862 | 1045 | fn test_contains_subarray() {
|
863 | 1046 | #[derive(Clone, Debug)]
|
|
0 commit comments