diff --git a/Cargo.lock b/Cargo.lock index bb19b59b0..5d598284a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,9 +74,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" [[package]] name = "alloy" @@ -101,11 +101,11 @@ dependencies = [ [[package]] name = "alloy-chains" -version = "0.1.38" +version = "0.1.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "156bfc5dcd52ef9a5f33381701fa03310317e14c65093a9430d3e3557b08dcd3" +checksum = "18c5c520273946ecf715c0010b4e3503d7eba9893cd9ce6b7fff5654c4a3c470" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "num_enum", "strum", ] @@ -117,7 +117,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae09ffd7c29062431dd86061deefe4e3c6f07fa0d674930095f8dcedb0baf02c" dependencies = [ "alloy-eips", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rlp", "alloy-serde", "auto_impl", @@ -136,37 +136,37 @@ dependencies = [ "alloy-json-abi", "alloy-network", "alloy-network-primitives", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-provider", "alloy-rpc-types-eth", "alloy-sol-types", "alloy-transport", "futures", "futures-util", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] name = "alloy-core" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8316d83e590f4163b221b8180008f302bda5cf5451202855cdd323e588849c" +checksum = "c3d14d531c99995de71558e8e2206c27d709559ee8e5a0452b965ea82405a013" dependencies = [ "alloy-dyn-abi", "alloy-json-abi", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rlp", "alloy-sol-types", ] [[package]] name = "alloy-dyn-abi" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef2364c782a245cf8725ea6dbfca5f530162702b5d685992ea03ce64529136cc" +checksum = "80759b3f57b3b20fa7cd8fef6479930fc95461b58ff8adea6e87e618449c8a1d" dependencies = [ "alloy-json-abi", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-sol-type-parser", "alloy-sol-types", "const-hex", @@ -182,18 +182,18 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0069cf0642457f87a01a014f6dc29d5d893cd4fd8fddf0c3cdfad1bb3ebafc41" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rlp", "serde", ] [[package]] name = "alloy-eip7702" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6cee6a35793f3db8a5ffe60e86c695f321d081a567211245f503e8c498fce8" +checksum = "4c986539255fb839d1533c128e190e557e52ff652c9ef62939e233a81dd93f7e" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rlp", "derive_more 1.0.0", "serde", @@ -207,7 +207,7 @@ checksum = "5b6aa3961694b30ba53d41006131a2fca3bdab22e4c344e46db2c639e7c2dfdd" dependencies = [ "alloy-eip2930", "alloy-eip7702", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rlp", "alloy-serde", "c-kzg", @@ -223,18 +223,18 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e53f7877ded3921d18a0a9556d55bedf84535567198c9edab2aa23106da91855" dependencies = [ - "alloy-primitives 
0.8.12", + "alloy-primitives 0.8.14", "alloy-serde", "serde", ] [[package]] name = "alloy-json-abi" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84c506bf264110fa7e90d9924f742f40ef53c6572ea56a0b0bd714a567ed389" +checksum = "ac4b22b3e51cac09fd2adfcc73b55f447b4df669f983c13f7894ec82b607c63f" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-sol-type-parser", "serde", "serde_json", @@ -246,11 +246,11 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3694b7e480728c0b3e228384f223937f14c10caef5a4c766021190fc8f283d35" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-sol-types", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tracing", ] @@ -264,7 +264,7 @@ dependencies = [ "alloy-eips", "alloy-json-rpc", "alloy-network-primitives", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rpc-types-eth", "alloy-serde", "alloy-signer", @@ -274,7 +274,7 @@ dependencies = [ "futures-utils-wasm", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] @@ -285,7 +285,7 @@ checksum = "df9f3e281005943944d15ee8491534a1c7b3cbf7a7de26f8c433b842b93eb5f9" dependencies = [ "alloy-consensus", "alloy-eips", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-serde", "serde", ] @@ -297,12 +297,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9805d126f24be459b958973c0569c73e1aadd27d4535eee82b2b6764aa03616" dependencies = [ "alloy-genesis", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "k256", "rand", "serde_json", "tempfile", - "thiserror 1.0.64", + "thiserror 1.0.69", "tracing", "url", ] @@ -326,9 +326,9 @@ dependencies = [ [[package]] name = "alloy-primitives" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fce5dbd6a4f118eecc4719eaa9c7ffc31c315e6c5ccde3642db927802312425" +checksum = "9db948902dfbae96a73c2fbf1f7abec62af034ab883e4c777c3fd29702bd6e2c" dependencies = [ "alloy-rlp", "bytes", @@ -337,7 +337,7 @@ dependencies = [ "derive_more 1.0.0", "foldhash", "getrandom 0.2.15", - "hashbrown 0.15.0", + "hashbrown 0.15.2", "hex-literal", "indexmap 2.6.0", "itoa", @@ -367,7 +367,7 @@ dependencies = [ "alloy-network", "alloy-network-primitives", "alloy-node-bindings", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rpc-client", "alloy-rpc-types-anvil", "alloy-rpc-types-eth", @@ -384,11 +384,11 @@ dependencies = [ "lru", "parking_lot", "pin-project", - "reqwest 0.12.8", + "reqwest 0.12.9", "schnellru", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tokio", "tracing", "url", @@ -424,12 +424,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "374dbe0dc3abdc2c964f36b3d3edf9cdb3db29d16bda34aa123f03d810bec1dd" dependencies = [ "alloy-json-rpc", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-transport", "alloy-transport-http", "futures", "pin-project", - "reqwest 0.12.8", + "reqwest 0.12.9", "serde", "serde_json", "tokio", @@ -446,7 +446,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c74832aa474b670309c20fffc2a869fa141edab7c79ff7963fad0a08de60bae1" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rpc-types-eth", "alloy-serde", "serde", @@ -458,7 +458,7 @@ version = "0.6.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "5ca97963132f78ddfc60e43a017348e6d52eea983925c23652f5b330e8e02291" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rpc-types-eth", "alloy-serde", "serde", @@ -473,7 +473,7 @@ dependencies = [ "alloy-consensus", "alloy-eips", "alloy-network-primitives", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-rlp", "alloy-serde", "alloy-sol-types", @@ -489,7 +489,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dfa4a7ccf15b2492bb68088692481fd6b2604ccbee1d0d6c44c21427ae4df83" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "serde", "serde_json", ] @@ -500,12 +500,12 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e10aec39d60dc27edcac447302c7803d2371946fb737245320a05b78eb2fafd" dependencies = [ - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "async-trait", "auto_impl", "elliptic-curve", "k256", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] @@ -516,19 +516,19 @@ checksum = "d8396f6dff60700bc1d215ee03d86ff56de268af96e2bf833a14d0bafcab9882" dependencies = [ "alloy-consensus", "alloy-network", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-signer", "async-trait", "k256", "rand", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] name = "alloy-sol-macro" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9343289b4a7461ed8bab8618504c995c049c082b70c7332efd7b32125633dc05" +checksum = "3bfd7853b65a2b4f49629ec975fee274faf6dff15ab8894c620943398ef283c0" dependencies = [ "alloy-sol-macro-expander", "alloy-sol-macro-input", @@ -540,9 +540,9 @@ dependencies = [ [[package]] name = "alloy-sol-macro-expander" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4222d70bec485ceccc5d8fd4f2909edd65b5d5e43d4aca0b5dcee65d519ae98f" +checksum = "82ec42f342d9a9261699f8078e57a7a4fda8aaa73c1a212ed3987080e6a9cd13" dependencies = [ "alloy-json-abi", "alloy-sol-macro-input", @@ -559,9 +559,9 @@ dependencies = [ [[package]] name = "alloy-sol-macro-input" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e17f2677369571b976e51ea1430eb41c3690d344fef567b840bfc0b01b6f83a" +checksum = "ed2c50e6a62ee2b4f7ab3c6d0366e5770a21cad426e109c2f40335a1b3aff3df" dependencies = [ "alloy-json-abi", "const-hex", @@ -576,9 +576,9 @@ dependencies = [ [[package]] name = "alloy-sol-type-parser" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa64d80ae58ffaafdff9d5d84f58d03775f66c84433916dc9a64ed16af5755da" +checksum = "ac17c6e89a50fb4a758012e4b409d9a0ba575228e69b539fe37d7a1bd507ca4a" dependencies = [ "serde", "winnow", @@ -586,12 +586,12 @@ dependencies = [ [[package]] name = "alloy-sol-types" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6520d427d4a8eb7aa803d852d7a52ceb0c519e784c292f64bb339e636918cf27" +checksum = "c9dc0fffe397aa17628160e16b89f704098bf3c9d74d5d369ebc239575936de5" dependencies = [ "alloy-json-abi", - "alloy-primitives 0.8.12", + "alloy-primitives 0.8.14", "alloy-sol-macro", "const-hex", "serde", @@ -609,7 +609,7 @@ dependencies = [ "futures-utils-wasm", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 
1.0.69", "tokio", "tower", "tracing", @@ -625,7 +625,7 @@ checksum = "5dc013132e34eeadaa0add7e74164c1503988bfba8bae885b32e0918ba85a8a6" dependencies = [ "alloy-json-rpc", "alloy-transport", - "reqwest 0.12.8", + "reqwest 0.12.9", "serde_json", "tower", "tracing", @@ -668,9 +668,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.15" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" dependencies = [ "anstyle", "anstyle-parse", @@ -683,43 +683,43 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anstyle-parse" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.4" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.89" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "ark-ff" @@ -991,9 +991,9 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bb8" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b10cf871f3ff2ce56432fddc2615ac7acc3aa22ca321f8fea800846fbb32f188" +checksum = "d89aabfae550a5c44b43ab941844ffcd2e993cb6900b342debf59e9ea74acdb8" dependencies = [ "async-trait", "futures-util", @@ -1124,9 +1124,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" dependencies = [ "serde", ] @@ -1184,9 +1184,9 @@ dependencies = [ [[package]] name = "cargo-platform" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc" +checksum = "e35af189006b9c0f00a064685c727031e3ed2d8020f7ba284d78cc2671bd36ea" dependencies = [ 
"serde", ] @@ -1202,14 +1202,14 @@ dependencies = [ "semver 1.0.23", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] name = "cc" -version = "1.1.30" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" +checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc" dependencies = [ "jobserver", "libc", @@ -1247,9 +1247,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -1257,9 +1257,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstream", "anstyle", @@ -1281,9 +1281,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "coins-bip32" @@ -1298,7 +1298,7 @@ dependencies = [ "k256", "serde", "sha2", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] @@ -1314,7 +1314,7 @@ dependencies = [ "pbkdf2 0.12.2", "rand", "sha2", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] @@ -1334,14 +1334,14 @@ dependencies = [ "serde_derive", "sha2", "sha3", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] name = "colorchoice" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "colored" @@ -1368,9 +1368,9 @@ dependencies = [ [[package]] name = "const-hex" -version = "1.13.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0121754e84117e65f9d90648ee6aa4882a6e63110307ab73967a4c5e7e69e586" +checksum = "4b0485bab839b018a8f1723fc5391819fea5f8f0f32288ef8a735fd096b6160c" dependencies = [ "cfg-if", "cpufeatures", @@ -1435,9 +1435,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -1506,9 +1506,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" dependencies = [ "csv-core", "itoa", @@ -1728,7 +1728,7 @@ dependencies = [ "fuzzy-matcher", "shell-words", "tempfile", - "thiserror 1.0.64", + "thiserror 
1.0.69", "zeroize", ] @@ -1795,6 +1795,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "dotenv" version = "0.15.0" @@ -1888,9 +1899,9 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "encoding_rs" -version = "0.8.34" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" dependencies = [ "cfg-if", ] @@ -2001,12 +2012,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2027,7 +2038,7 @@ dependencies = [ "serde_json", "sha2", "sha3", - "thiserror 1.0.64", + "thiserror 1.0.69", "uuid 0.8.2", ] @@ -2057,7 +2068,7 @@ dependencies = [ "serde", "serde_json", "sha3", - "thiserror 1.0.64", + "thiserror 1.0.69", "uint", ] @@ -2161,7 +2172,7 @@ dependencies = [ "pin-project", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] @@ -2180,7 +2191,7 @@ dependencies = [ "pin-project", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] @@ -2285,7 +2296,7 @@ dependencies = [ "strum", "syn 2.0.89", "tempfile", - "thiserror 1.0.64", + "thiserror 1.0.69", "tiny-keccak", "unicode-xid", ] @@ -2315,7 +2326,7 @@ dependencies = [ "strum", "syn 2.0.89", "tempfile", - "thiserror 1.0.64", + "thiserror 1.0.69", "tiny-keccak", "unicode-xid", ] @@ -2331,7 +2342,7 @@ dependencies = [ "semver 1.0.23", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tracing", ] @@ -2347,7 +2358,7 @@ dependencies = [ "semver 1.0.23", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tracing", ] @@ -2369,7 +2380,7 @@ dependencies = [ "reqwest 0.11.27", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tokio", "tracing", "tracing-futures", @@ -2396,7 +2407,7 @@ dependencies = [ "reqwest 0.11.27", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tokio", "tracing", "tracing-futures", @@ -2427,7 +2438,7 @@ dependencies = [ "reqwest 0.11.27", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tokio", "tokio-tungstenite", "tracing", @@ -2464,7 +2475,7 @@ dependencies = [ "reqwest 0.11.27", "serde", "serde_json", - "thiserror 1.0.64", + "thiserror 1.0.69", "tokio", "tokio-tungstenite", "tracing", @@ -2490,7 +2501,7 @@ dependencies = [ "ethers-core 2.0.13", "rand", "sha2", - "thiserror 1.0.64", + "thiserror 1.0.69", "tracing", ] @@ -2509,7 +2520,7 @@ dependencies = [ "ethers-core 2.0.14", "rand", "sha2", - "thiserror 1.0.64", + "thiserror 1.0.69", "tracing", ] @@ -2536,7 +2547,7 @@ dependencies = [ "serde_json", "solang-parser", "svm-rs", - "thiserror 1.0.64", + "thiserror 1.0.69", "tiny-keccak", "tokio", "tracing", @@ -2568,7 +2579,7 @@ dependencies = [ "serde_json", "solang-parser", "svm-rs", - 
"thiserror 1.0.64", + "thiserror 1.0.69", "tiny-keccak", "tokio", "tracing", @@ -2594,9 +2605,9 @@ checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" [[package]] name = "fastrand" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" [[package]] name = "fastrlp" @@ -2673,9 +2684,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", @@ -2966,6 +2977,7 @@ dependencies = [ "itertools 0.13.0", "log", "mp2_common", + "mp2_test", "plonky2", "plonky2x", "rand", @@ -2973,7 +2985,7 @@ dependencies = [ "revm", "serde", "serde_json", - "serial_test 3.1.1", + "serial_test 3.2.0", "sha2", "verifiable-db", ] @@ -3034,9 +3046,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.0" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ "allocator-api2", "equivalent", @@ -3195,9 +3207,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -3219,9 +3231,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ "bytes", "futures-channel", @@ -3244,7 +3256,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "rustls", "tokio", "tokio-rustls", @@ -3257,7 +3269,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper 0.14.30", + "hyper 0.14.31", "native-tls", "tokio", "tokio-native-tls", @@ -3271,7 +3283,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.1", "hyper-util", "native-tls", "tokio", @@ -3281,16 +3293,16 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 
1.5.1", "pin-project-lite", "socket2", "tokio", @@ -3321,6 +3333,124 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -3329,12 +3459,23 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + 
"smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] @@ -3366,13 +3507,13 @@ dependencies = [ [[package]] name = "impl-trait-for-tuples" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" +checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.89", ] [[package]] @@ -3399,7 +3540,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.15.0", + "hashbrown 0.15.2", "serde", ] @@ -3490,9 +3631,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "jammdb" @@ -3521,9 +3662,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.72" +version = "0.3.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +checksum = "fb15147158e79fd8b8afd0252522769c4f48725460b37338544d8379d94fc8f9" dependencies = [ "wasm-bindgen", ] @@ -3622,7 +3763,7 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553" dependencies = [ - "regex-automata 0.4.8", + "regex-automata 0.4.9", ] [[package]] @@ -3636,15 +3777,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.159" +version = "0.2.167" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libredox" @@ -3662,6 +3803,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" + [[package]] name = "lock_api" version = "0.4.12" @@ -3684,7 +3831,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.0", + "hashbrown 0.15.2", ] [[package]] @@ -3744,11 +3891,10 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi 0.3.9", "libc", "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", @@ -3765,7 +3911,7 @@ dependencies = [ "eth_trie", "ethereum-types", "ethers 2.0.13", - "hashbrown 0.15.0", + "hashbrown 0.15.2", "hex", "itertools 0.13.0", "log", @@ -3817,7 +3963,7 @@ dependencies = [ "envconfig", "eth_trie", "futures", - "hashbrown 0.15.0", + "hashbrown 0.15.2", "hex", "itertools 0.13.0", "jammdb", @@ -3837,7 +3983,7 @@ dependencies = [ "ryhope", "serde", "serde_json", - "serial_test 3.1.1", + "serial_test 3.2.0", "sqlparser", "test-log", "testfile", @@ -4056,9 +4202,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.66" +version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -4088,9 +4234,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.103" +version = "0.9.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" dependencies = [ "cc", "libc", @@ -4135,28 +4281,29 @@ dependencies = [ [[package]] name = "parity-scale-codec" -version = "3.6.12" +version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "306800abfa29c7f16596b5970a588435e3d5b3149683d00c12b699cc19f895ee" +checksum = "8be4817d39f3272f69c59fe05d0535ae6456c2dc2fa1ba02910296c7e0a5c590" dependencies = [ "arrayvec 0.7.6", "bitvec", "byte-slice-cast", "impl-trait-for-tuples", "parity-scale-codec-derive", + "rustversion", "serde", ] [[package]] name = "parity-scale-codec-derive" -version = "3.6.12" +version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d830939c76d294956402033aee57a6da7b438f2294eb94864c37b0569053a42c" +checksum = "8781a75c6205af67215f382092b6e0a4ff3734798523e69073d4bcd294ec767b" dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.89", ] [[package]] @@ -4267,7 +4414,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "879952a81a83930934cbf1786752d6dedc3b1f29e8f8fb2ad1d0a36f377cf442" dependencies = [ "memchr", - "thiserror 1.0.64", + "thiserror 1.0.69", "ucd-trie", ] @@ -4344,18 +4491,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf123a161dde1e524adf36f90bc5d8d3462824a9c43553ad07a8183161189ec" +checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4502d8515ca9f32f1fb543d987f63d95a14934883db45bdb48060b6b69257f8" +checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", @@ -4364,9 +4511,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -4556,7 +4703,7 @@ dependencies = [ "starkyx", "tokio", "tracing", - "uuid 1.10.0", + "uuid 1.11.0", ] [[package]] @@ -4636,9 +4783,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "prettyplease" -version = "0.2.22" +version = "0.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" +checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", "syn 2.0.89", @@ -4725,9 +4872,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.91" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "307e3004becf10f5a6e0d59d20f3cd28231b0e0827a96cd3e0ce6d14bc1e4bb3" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -4855,7 +5002,7 @@ dependencies = [ "plonky2_monolith", "rstest 0.23.0", "serde", - "serial_test 3.1.1", + "serial_test 3.2.0", ] [[package]] @@ -4875,18 +5022,18 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom 0.2.15", "libredox", - "thiserror 1.0.64", + "thiserror 1.0.69", ] [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "regex-syntax 0.8.5", ] @@ -4901,9 +5048,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -4942,7 +5089,7 @@ dependencies = [ "h2", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-rustls", "hyper-tls 0.5.0", "ipnet", @@ -4974,9 +5121,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.8" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", "bytes", @@ -4985,7 +5132,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.1", "hyper-tls 0.6.0", "hyper-util", "ipnet", @@ -5000,7 +5147,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tokio", "tokio-native-tls", "tower-service", @@ -5210,7 +5357,7 @@ dependencies = [ "rlp", "ruint-macro", "serde", - "thiserror 1.0.64", + "thiserror 1.0.69", "valuable", "zeroize", ] @@ -5259,9 +5406,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -5384,42 +5531,42 @@ dependencies = [ [[package]] name = "scale-info" -version = "2.11.3" +version = "2.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eca070c12893629e2cc820a9761bedf6ce1dcddc9852984d1dc734b8bd9bd024" +checksum = "346a3b32eba2640d17a9cb5927056b08f3de90f65b72fe09402c2ad07d684d0b" dependencies = [ "cfg-if", - "derive_more 0.99.18", + "derive_more 1.0.0", "parity-scale-codec", "scale-info-derive", ] [[package]] name = "scale-info-derive" -version = "2.11.3" +version = "2.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d35494501194174bda522a32605929eefc9ecf7e0a326c26db1fdd85881eb62" +checksum = "c6630024bf739e2179b91fb424b28898baf819414262c5d376677dbff1fe7ebf" dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.89", ] [[package]] name = "scc" -version = "2.2.2" +version = "2.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2c1f7fc6deb21665a9060dfc7d271be784669295a31babdcd4dd2c79ae8cbfb" +checksum = "66b202022bb57c049555430e11fc22fea12909276a80a4c3d368da36ac1d88ed" dependencies = [ "sdd", ] [[package]] name = "schannel" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -5498,9 +5645,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" dependencies = [ "core-foundation-sys", "libc", @@ -5526,9 +5673,9 @@ dependencies = [ [[package]] name = "semver-parser" -version = "0.10.2" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7" +checksum = "9900206b54a3527fdc7b8a938bffd94a568bac4f4aa8113b209df75a09c0dec2" dependencies = [ "pest", ] @@ -5547,18 +5694,18 @@ checksum = "cd0b0ec5f1c1ca621c432a25813d8d60c88abe6d3e08a3eb9cf37d97a0fe3d73" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", @@ -5567,9 +5714,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" 
dependencies = [ "itoa", "memchr", @@ -5681,16 +5828,16 @@ dependencies = [ [[package]] name = "serial_test" -version = "3.1.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d" +checksum = "1b258109f244e1d6891bf1053a55d63a5cd4f8f4c30cf9a1280989f80e7a1fa9" dependencies = [ "futures", "log", "once_cell", "parking_lot", "scc", - "serial_test_derive 3.1.1", + "serial_test_derive 3.2.0", ] [[package]] @@ -5706,9 +5853,9 @@ dependencies = [ [[package]] name = "serial_test_derive" -version = "3.1.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" +checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", @@ -5818,7 +5965,7 @@ checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" dependencies = [ "num-bigint 0.4.6", "num-traits", - "thiserror 1.0.64", + "thiserror 1.0.69", "time", ] @@ -5856,9 +6003,9 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", @@ -5874,7 +6021,7 @@ dependencies = [ "lalrpop", "lalrpop-util", "phf", - "thiserror 1.0.64", + "thiserror 1.0.69", "unicode-xid", ] @@ -5909,6 +6056,12 @@ dependencies = [ "log", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "starkyx" version = "0.1.0" @@ -6042,7 +6195,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "thiserror 1.0.64", + "thiserror 1.0.69", "url", "zip", ] @@ -6071,9 +6224,9 @@ dependencies = [ [[package]] name = "syn-solidity" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f76fe0a3e1476bdaa0775b9aec5b869ed9520c2b2fedfe9c6df3618f8ea6290b" +checksum = "da0523f59468a2696391f2a772edc089342aacd53c3caa2ac3264e598edf119b" dependencies = [ "paste", "proc-macro2", @@ -6089,13 +6242,24 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -6150,9 +6314,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if", "fastrand", @@ -6214,11 +6378,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl 1.0.64", + "thiserror-impl 1.0.69", ] [[package]] @@ -6232,9 +6396,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", @@ -6311,6 +6475,16 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -6328,9 +6502,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.40.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -6503,9 +6677,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -6514,9 +6688,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -6525,9 +6699,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", "valuable", @@ -6556,9 +6730,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ "matchers", "nu-ansi-term", @@ -6592,7 +6766,7 @@ dependencies = [ "rand", "rustls", "sha1", - "thiserror 1.0.64", + "thiserror 1.0.69", "url", "utf-8", ] @@ -6635,9 +6809,9 @@ checksum = 
"5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-normalization" @@ -6690,9 +6864,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.2" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", "idna", @@ -6705,6 +6879,18 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -6723,9 +6909,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "serde", ] @@ -6763,7 +6949,7 @@ dependencies = [ "recursion_framework", "ryhope", "serde", - "serial_test 3.1.1", + "serial_test 3.2.0", "tokio", ] @@ -6842,9 +7028,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" +checksum = "21d3b25c3ea1126a2ad5f4f9068483c2af1e64168f847abe863a526b8dbfe00b" dependencies = [ "cfg-if", "once_cell", @@ -6853,9 +7039,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" +checksum = "52857d4c32e496dc6537646b5b117081e71fd2ff06de792e3577a150627db283" dependencies = [ "bumpalo", "log", @@ -6868,21 +7054,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.45" +version = "0.4.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" +checksum = "951fe82312ed48443ac78b66fa43eded9999f738f6022e67aead7b708659e49a" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +checksum = "920b0ffe069571ebbfc9ddc0b36ba305ef65577c94b06262ed793716a1afd981" dependencies = [ "quote", 
"wasm-bindgen-macro-support", @@ -6890,9 +7077,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +checksum = "bf59002391099644be3524e23b781fa43d2be0c5aa0719a18c0731b9d195cab6" dependencies = [ "proc-macro2", "quote", @@ -6903,9 +7090,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +checksum = "e5047c5392700766601942795a436d7d2599af60dcc3cc1248c9120bfb0827b0" [[package]] name = "wasmtimer" @@ -6923,9 +7110,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.72" +version = "0.3.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +checksum = "476364ff87d0ae6bfb661053a9104ab312542658c3d8f963b7ace80b6f9b26b9" dependencies = [ "js-sys", "wasm-bindgen", @@ -7195,6 +7382,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "ws_stream_wasm" version = "0.7.4" @@ -7208,7 +7407,7 @@ dependencies = [ "pharos", "rustc_version 0.4.1", "send_wrapper 0.6.0", - "thiserror 1.0.64", + "thiserror 1.0.69", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", @@ -7229,6 +7428,30 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -7250,6 +7473,27 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "zerofrom" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", + "synstructure", +] + [[package]] name = "zeroize" version = "1.8.1" @@ -7270,6 +7514,28 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" 
+dependencies = [
+ "yoke",
+ "zerofrom",
+ "zerovec-derive",
+]
+
+[[package]]
+name = "zerovec-derive"
+version = "0.10.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.89",
+]
+
 [[package]]
 name = "zip"
 version = "0.6.6"
diff --git a/groth16-framework/Cargo.toml b/groth16-framework/Cargo.toml
index a8ce89abe..7922499e1 100644
--- a/groth16-framework/Cargo.toml
+++ b/groth16-framework/Cargo.toml
@@ -23,6 +23,7 @@ itertools.workspace = true
 rand.workspace = true
 serial_test.workspace = true
 sha2.workspace = true
+mp2_test = { path = "../mp2-test" }
 recursion_framework = { path = "../recursion-framework" }
 verifiable-db = { path = "../verifiable-db" }
 
diff --git a/groth16-framework/tests/common/context.rs b/groth16-framework/tests/common/context.rs
index ede68e899..ffb617c81 100644
--- a/groth16-framework/tests/common/context.rs
+++ b/groth16-framework/tests/common/context.rs
@@ -3,9 +3,11 @@
 use super::{NUM_PREPROCESSING_IO, NUM_QUERY_IO};
 use groth16_framework::{compile_and_generate_assets, utils::clone_circuit_data};
 use mp2_common::{C, D, F};
+use mp2_test::circuit::TestDummyCircuit;
 use recursion_framework::framework_testing::TestingRecursiveCircuits;
 use verifiable_db::{
     api::WrapCircuitParams,
+    query::pi_len,
     revelation::api::Parameters as RevelationParameters,
     test_utils::{
         INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS,
@@ -40,6 +42,8 @@ impl TestContext {
 
         // Generate a fake query circuit set.
         let query_circuits = TestingRecursiveCircuits::<F, C, D, NUM_QUERY_IO>::default();
+        let dummy_universal_circuit =
+            TestDummyCircuit::<{ pi_len::<MAX_NUM_ITEMS_PER_OUTPUT>() }>::build();
 
         // Create the revelation parameters.
         let revelation_params = RevelationParameters::<
@@ -52,7 +56,8 @@
             MAX_NUM_ITEMS_PER_OUTPUT,
             MAX_NUM_PLACEHOLDERS,
         >::build(
-            query_circuits.get_recursive_circuit_set(),
+            query_circuits.get_recursive_circuit_set(), // unused, so we provide a dummy one
+            dummy_universal_circuit.circuit_data().verifier_data(),
             preprocessing_circuits.get_recursive_circuit_set(),
             preprocessing_circuits
                 .verifier_data_for_input_proofs::<1>()
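Context for the hunk above: `RevelationParameters::build` now takes, alongside the (now unused) query circuit set, the verifier data of the universal query circuit. Since this test never proves against that circuit, any circuit exposing the expected number of public inputs will do. Below is a minimal sketch of such a fixed-arity dummy circuit in plonky2; it illustrates the pattern only, and `build_dummy_circuit` is a hypothetical stand-in, not the actual `mp2_test::circuit::TestDummyCircuit` source.

```rust
use plonky2::{
    field::goldilocks_field::GoldilocksField,
    plonk::{
        circuit_builder::CircuitBuilder,
        circuit_data::{CircuitConfig, CircuitData},
        config::PoseidonGoldilocksConfig,
    },
};

type F = GoldilocksField;
type C = PoseidonGoldilocksConfig;
const D: usize = 2;

/// A circuit that constrains nothing and merely exposes `NUM_IO` public
/// inputs, so that its verifier data has the shape the caller expects.
fn build_dummy_circuit<const NUM_IO: usize>() -> CircuitData<F, C, D> {
    let mut builder = CircuitBuilder::<F, D>::new(CircuitConfig::standard_recursion_config());
    let io = builder.add_virtual_targets(NUM_IO);
    builder.register_public_inputs(&io);
    builder.build::<C>()
}
```

`TestDummyCircuit::<{ pi_len::<MAX_NUM_ITEMS_PER_OUTPUT>() }>::build()` plays the same role in the test: the const-generic argument pins the public-input count to the universal query circuit's interface.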
let revelation_params = RevelationParameters::< @@ -52,7 +56,8 @@ impl TestContext { MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_PLACEHOLDERS, >::build( - query_circuits.get_recursive_circuit_set(), + query_circuits.get_recursive_circuit_set(), // unused, so we provide a dummy one + dummy_universal_circuit.circuit_data().verifier_data(), preprocessing_circuits.get_recursive_circuit_set(), preprocessing_circuits .verifier_data_for_input_proofs::<1>() diff --git a/mp2-common/src/eth.rs b/mp2-common/src/eth.rs index ba863d475..ee8eda75b 100644 --- a/mp2-common/src/eth.rs +++ b/mp2-common/src/eth.rs @@ -286,7 +286,7 @@ mod test { types::MAX_BLOCK_LEN, utils::{Endianness, Packer}, }; - use mp2_test::eth::{get_mainnet_url, get_sepolia_url}; + use mp2_test::eth::get_sepolia_url; #[tokio::test] #[ignore] @@ -426,39 +426,6 @@ mod test { Ok(()) } - #[tokio::test] - async fn test_pidgy_pinguin_mapping_slot() -> Result<()> { - // first pinguin holder https://dune.com/queries/2450476/4027653 - // holder: 0x188b264aa1456b869c3a92eeed32117ebb835f47 - // NFT id https://opensea.io/assets/ethereum/0xbd3531da5cf5857e7cfaa92426877b022e612cf8/1116 - let mapping_value = - Address::from_str("0x188B264AA1456B869C3a92eeeD32117EbB835f47").unwrap(); - let nft_id: u32 = 1116; - let mapping_key = left_pad32(&nft_id.to_be_bytes()); - let url = get_mainnet_url(); - let provider = ProviderBuilder::new().on_http(url.parse().unwrap()); - - // extracting from - // https://github.com/OpenZeppelin/openzeppelin-contracts/blob/master/contracts/token/ERC721/ERC721.sol - // assuming it's using ERC731Enumerable that inherits ERC721 - let mapping_slot = 2; - // pudgy pinguins - let pudgy_address = Address::from_str("0xBd3531dA5CF5857e7CfAA92426877b022e612cf8")?; - let query = ProofQuery::new_mapping_slot(pudgy_address, mapping_slot, mapping_key.to_vec()); - let res = query - .query_mpt_proof(&provider, BlockNumberOrTag::Latest) - .await?; - let raw_address = ProofQuery::verify_storage_proof(&res)?; - // the value is actually RLP encoded ! 
- let decoded_address: Vec = rlp::decode(&raw_address).unwrap(); - let leaf_node: Vec> = rlp::decode_list(res.storage_proof[0].proof.last().unwrap()); - println!("leaf_node[1].len() = {}", leaf_node[1].len()); - // this is read in the same order - let found_address = Address::from_slice(&decoded_address.into_iter().collect::>()); - assert_eq!(found_address, mapping_value); - Ok(()) - } - #[tokio::test] async fn test_kashish_contract_proof_query() -> Result<()> { // https://sepolia.etherscan.io/address/0xd6a2bFb7f76cAa64Dad0d13Ed8A9EFB73398F39E#code diff --git a/mp2-common/src/group_hashing/mod.rs b/mp2-common/src/group_hashing/mod.rs index 819eb7c2b..47a8822aa 100644 --- a/mp2-common/src/group_hashing/mod.rs +++ b/mp2-common/src/group_hashing/mod.rs @@ -120,6 +120,8 @@ impl ToTargets for QuinticExtensionTarget { } impl FromTargets for CurveTarget { + const NUM_TARGETS: usize = CURVE_TARGET_LEN; + fn from_targets(t: &[Target]) -> Self { assert!(t.len() >= CURVE_TARGET_LEN); let x = QuinticExtensionTarget(t[0..EXTENSION_DEGREE].try_into().unwrap()); diff --git a/mp2-common/src/keccak.rs b/mp2-common/src/keccak.rs index dedbafaf5..e29ba48a9 100644 --- a/mp2-common/src/keccak.rs +++ b/mp2-common/src/keccak.rs @@ -59,6 +59,8 @@ pub type OutputHash = Array; pub type OutputByteHash = Array; impl FromTargets for OutputHash { + const NUM_TARGETS: usize = PACKED_HASH_LEN; + fn from_targets(t: &[Target]) -> Self { OutputHash::from_array(array::from_fn(|i| U32Target(t[i]))) } diff --git a/mp2-common/src/u256.rs b/mp2-common/src/u256.rs index ca62f3eb1..ae6b79cff 100644 --- a/mp2-common/src/u256.rs +++ b/mp2-common/src/u256.rs @@ -500,11 +500,9 @@ impl, const D: usize> CircuitBuilderU256 left: &UInt256Target, right: &UInt256Target, ) -> BoolTarget { - // left <= right iff left - right requires a borrow or left - right == 0 - let (res, borrow) = self.sub_u256(left, right); - let less_than = BoolTarget::new_unsafe(borrow.0); - let is_eq = self.is_zero(&res); - self.or(less_than, is_eq) + // left <= right iff !(right < left)
+ let is_greater = self.is_less_than_u256(right, left); + self.not(is_greater) } fn is_zero(&mut self, target: &UInt256Target) -> BoolTarget { @@ -827,6 +825,7 @@ impl ToTargets for UInt256Target { } impl FromTargets for UInt256Target { + const NUM_TARGETS: usize = NUM_LIMBS; // Expects big endian limbs as the standard format for IO fn from_targets(t: &[Target]) -> Self { Self::new_from_be_target_limbs(&t[..NUM_LIMBS]).unwrap() diff --git a/mp2-common/src/utils.rs b/mp2-common/src/utils.rs index d3b786a9e..3a4193ecb 100644 --- a/mp2-common/src/utils.rs +++ b/mp2-common/src/utils.rs @@ -5,7 +5,7 @@ use anyhow::{anyhow, Result}; use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField}; +use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS}; use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -19,26 +19,25 @@ use sha3::Keccak256; use crate::array::Targetable; use crate::poseidon::{HashableField, H}; +use crate::serialization::circuit_data_serialization::SerializableRichField; use crate::{group_hashing::EXTENSION_DEGREE, types::HashOutput, ProofTuple}; -use crate::{D, F}; const TWO_POWER_8: usize = 256; const TWO_POWER_16: usize = 65536; const TWO_POWER_24: usize = 16777216; -#[allow(dead_code)] -trait ConnectSlice { - fn connect_slice(&mut self, a: &[Target], b: &[Target]); +// check that the closure $f actually panics, printing $msg as the error message if the function +// did not panic; this macro is employed in tests in place of #[should_panic] to ensure that a +// panic occurred in the expected function rather than in other parts of the test +#[macro_export] +macro_rules! 
check_panic { + ($f: expr, $msg: expr) => {{ + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe($f)); + assert!(result.is_err(), $msg); + }}; } -impl ConnectSlice for CircuitBuilder { - fn connect_slice(&mut self, a: &[Target], b: &[Target]) { - assert_eq!(a.len(), b.len()); - for (ai, bi) in a.iter().zip(b) { - self.connect(*ai, *bi); - } - } -} +pub use check_panic; pub fn verify_proof_tuple< F: RichField + Extendable, @@ -326,7 +325,7 @@ pub fn pack_and_compute_poseidon_target, const b.hash_n_to_hash_no_pad::(packed) } -pub trait SelectHashBuilder { +pub trait HashBuilder { /// Select `first_hash` or `second_hash` as output depending on the Boolean `cond` fn select_hash( &mut self, @@ -334,9 +333,12 @@ pub trait SelectHashBuilder { first_hash: &HashOutTarget, second_hash: &HashOutTarget, ) -> HashOutTarget; + + /// Determine whether `first_hash == second_hash` + fn hash_eq(&mut self, first_hash: &HashOutTarget, second_hash: &HashOutTarget) -> BoolTarget; } -impl, const D: usize> SelectHashBuilder for CircuitBuilder { +impl, const D: usize> HashBuilder for CircuitBuilder { fn select_hash( &mut self, cond: BoolTarget, @@ -352,6 +354,28 @@ impl, const D: usize> SelectHashBuilder for Circuit .collect_vec(), ) } + + fn hash_eq(&mut self, first_hash: &HashOutTarget, second_hash: &HashOutTarget) -> BoolTarget { + let _true = self._true(); + first_hash + .elements + .iter() + .zip(second_hash.elements.iter()) + .fold(_true, |acc, (first, second)| { + let is_eq = self.is_equal(*first, *second); + self.and(acc, is_eq) + }) + } +} + +pub trait SelectTarget { + /// Return `first` if `cond` is true, `second` otherwise + fn select, const D: usize>( + b: &mut CircuitBuilder, + cond: &BoolTarget, + first: &Self, + second: &Self, + ) -> Self; } pub trait ToFields { @@ -414,10 +438,16 @@ impl Fieldable for u64 { } pub trait FromTargets { + /// Number of targets necessary to instantiate `Self` + const NUM_TARGETS: usize; + + /// Number of targets in `t` must be at least `Self::NUM_TARGETS` fn from_targets(t: &[Target]) -> Self; } impl FromTargets for HashOutTarget { + const NUM_TARGETS: usize = NUM_HASH_OUT_ELTS; + fn from_targets(t: &[Target]) -> Self { HashOutTarget { elements: create_array(|i| t[i]), diff --git a/mp2-v1/src/query/batching_planner.rs b/mp2-v1/src/query/batching_planner.rs new file mode 100644 index 000000000..7ce4125c9 --- /dev/null +++ b/mp2-v1/src/query/batching_planner.rs @@ -0,0 +1,481 @@ +use anyhow::Result; +use std::{collections::BTreeSet, fmt::Debug, hash::Hash}; + +use alloy::primitives::U256; +use futures::{stream, StreamExt}; +use hashbrown::HashMap; +use itertools::Itertools; +use parsil::symbols::ContextProvider; +use ryhope::{ + storage::{updatetree::UpdateTree, WideLineage}, + Epoch, +}; +use serde::{Deserialize, Serialize}; +use verifiable_db::query::{ + api::{NodePath, RowInput, TreePathInputs}, + computational_hash_ids::ColumnIDs, + universal_circuit::universal_circuit_inputs::{ColumnCell, RowCells}, +}; + +use crate::{ + indexing::{ + block::{BlockPrimaryIndex, BlockTreeKey}, + index::IndexNode, + row::{RowPayload, RowTreeKey}, + }, + query::planner::TreeFetcher, +}; + +use super::planner::NonExistenceInput; + +async fn compute_input_for_row>>( + tree: &T, + row_key: &RowTreeKey, + index_value: BlockPrimaryIndex, + index_path: &TreePathInputs, + column_ids: &ColumnIDs, +) -> RowInput { + let row_path = tree + .compute_path(row_key, index_value as Epoch) + .await + .unwrap_or_else(|| panic!("node with key {:?} not found in cache", row_key)); + let 
path = NodePath::new(row_path, index_path.clone()); + let (_, row_payload) = tree + .fetch_ctx_and_payload_at(row_key, index_value as Epoch) + .await + .unwrap_or_else(|| panic!("node with key {:?} not found in cache", row_key)); + // build row cells + let primary_index_cell = ColumnCell::new(column_ids.primary_column(), U256::from(index_value)); + let secondary_index_cell = ColumnCell::new( + column_ids.secondary_column(), + row_payload.secondary_index_value(), + ); + let non_indexed_cells = column_ids + .non_indexed_columns() + .into_iter() + .filter_map(|id| { + row_payload + .cells + .find_by_column(id) + .map(|info| ColumnCell::new(id, info.value)) + }) + .collect::>(); + let row_cells = RowCells::new(primary_index_cell, secondary_index_cell, non_indexed_cells); + RowInput::new(&row_cells, &path) +} + +/// Given the subtree built from the rows satisfying the query ranges on primary and +/// secondary indexes, this method splits the rows into chunks of `CHUNK_SIZE` consecutive +/// rows, with all the rows in the same chunk being proven together in the same +/// circuit. The method also builds the `UpdateTree` that specifies how to recursively +/// aggregate all these chunks, using the chunk aggregation circuit. The `NUM_CHUNKS` +/// constant corresponds to the maximum number of chunks that can be aggregated by such +/// a circuit, and will thus correspond to the arity of the constructed `UpdateTree`. +/// The method requires the following inputs: +/// - `row_cache` : Wide lineage of rows tree nodes in the subtree built from the rows +/// satisfying the query ranges on primary and secondary indexes +/// - `index_cache` : Wide lineage of index tree nodes in the subtree built from the rows +/// satisfying the query ranges on primary and secondary indexes +/// - `column_ids` : Identifiers of the columns of the table, including primary and +/// secondary index columns +/// - `non_existence_inputs` : This set of data is employed to find the proper row to be +/// proven for a rows tree that contains no rows with a secondary index value lying +/// in the query range over the secondary index, which still needs to be proven for +/// completeness (i.e., proving that we are not skipping potentially matching rows +/// for the query); this data structure can be instantiated with its own `new` method +/// - `epoch` : Last epoch inserted in the index tree +pub async fn generate_chunks_and_update_tree< + const CHUNK_SIZE: usize, + const NUM_CHUNKS: usize, + C: ContextProvider, +>( + row_cache: WideLineage>, + index_cache: WideLineage>, + column_ids: &ColumnIDs, + non_existence_inputs: NonExistenceInput<'_, C>, + epoch: Epoch, +) -> Result<( + HashMap, Vec>, + UTForChunks, +)> { + let chunks = + generate_chunks::(row_cache, index_cache, column_ids, non_existence_inputs) + .await?; + Ok(UTForChunksBuilder { chunks }.build_update_tree_with_base_chunks(epoch)) +} + +async fn generate_chunks( + row_cache: WideLineage>, + index_cache: WideLineage>, + column_ids: &ColumnIDs, + non_existence_inputs: NonExistenceInput<'_, C>, +) -> Result>> { + let index_keys_by_epochs = index_cache.keys_by_epochs(); + assert_eq!(index_keys_by_epochs.len(), 1); + let row_keys_by_epochs = row_cache.keys_by_epochs(); + let current_epoch = *index_keys_by_epochs.keys().next().unwrap() as Epoch; + let sorted_index_values = index_keys_by_epochs[&current_epoch] + .iter() + .cloned() + .collect::>(); + + let prove_rows = async |index_value| { + let index_path = index_cache + .compute_path(&index_value, current_epoch) + .await
.unwrap_or_else(|| panic!("node with key {index_value} not found in index tree cache")); + let proven_rows = if let Some(matching_rows) = + row_keys_by_epochs.get(&(index_value as Epoch)) + { + let sorted_rows = matching_rows.iter().collect::>(); + stream::iter(sorted_rows.iter()) + .then(async |&row_key| { + compute_input_for_row(&row_cache, row_key, index_value, &index_path, column_ids) + .await + }) + .collect::>() + .await + } else { + let proven_node = non_existence_inputs + .find_row_node_for_non_existence(index_value) + .await + .unwrap_or_else(|_| { + panic!("node for non-existence not found for index value {index_value}") + }); + let row_input = compute_input_for_row( + non_existence_inputs.row_tree, + &proven_node, + index_value, + &index_path, + column_ids, + ) + .await; + vec![row_input] + }; + proven_rows + }; + + // TODO: This implementation causes an error in DQ: + // `implementation of `std::marker::Send` is not general enough` + /* + let chunks = stream::iter(sorted_index_values.into_iter()) + .then(prove_rows) + .concat() + .await + */ + let mut chunks = vec![]; + for index_value in sorted_index_values { + let chunk = prove_rows(index_value).await; + chunks.extend(chunk); + } + + let chunks = chunks + .chunks(CHUNK_SIZE) + .map(|chunk| chunk.to_vec()) + .collect_vec(); + + Ok(chunks) +} + +/// Key for nodes of the `UTForChunks` employed to +/// prove chunks of rows. +/// The key is composed of two integers: +/// - The `level` of the node in the `UpdateTree`, that is the number of +/// ancestor nodes between the node and the root of the tree +/// - The `position` of the node in the tree among the nodes with the same +/// `level`. The position is basically an identifier to uniquely identify +/// a node among all the nodes in the same level. It is computed recursively +/// from the position `parent_pos` of the parent node and the number of left +/// siblings `num_left` of the node as `parent_pos*ARITY + num_left` +/// +/// For instance, consider the following tree, with arity 3: +/// ```text +/// A +/// +/// B C D +/// +/// E F G H I +/// ``` +/// The nodes in this tree will be identified by the following keys: +/// ```text +/// (0,0) +/// +/// (1,0) (1,1) (1,2) +/// +/// (2,0) (2,1) (2,2) (2,3) (2,4) +/// ``` +#[derive( + Clone, Copy, Debug, Default, PartialEq, PartialOrd, Ord, Eq, Hash, Serialize, Deserialize, +)] +pub struct UTKey(pub (usize, usize)); + +impl UTKey { + /// Compute the key of the child node of `self` that has `num_left_siblings` + /// left siblings + fn children_key(&self, num_left_siblings: usize) -> Self { + let Self((parent_level, parent_pos)) = self; + Self((*parent_level + 1, *parent_pos * ARITY + num_left_siblings)) + } +} + +/// `UpdateTree` employed to prove chunks and aggregate chunks of rows +/// into a single proof. The tree is built employing a `ProvingTree` +/// as the skeleton tree, which determines the structure of the tree.
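+/// For instance (an illustrative example): with `NUM_CHUNKS = 3` and 5 chunks
+/// to aggregate, the skeleton places 3 chunks in one full subtree and the 2
+/// remaining chunks as direct children of the root, so the 5 chunk (leaf)
+/// proofs are aggregated with only 2 aggregation proofs:
+/// ```text
+///          (0,0)
+///
+///   (1,0)  (1,1)  (1,2)
+///
+/// (2,0) (2,1) (2,2)
+/// ```
+/// where `(1,1)`, `(1,2)` and the `(2,*)` nodes carry the chunk proofs, while
+/// `(1,0)` and the root `(0,0)` are aggregation proofs.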
+pub type UTForChunks = UpdateTree>; + +/// Data structure employed to build the `UTForChunks` for the set of chunks +#[derive(Clone, Debug)] +struct UTForChunksBuilder { + chunks: Vec>, +} + +/// Convenience trait, used just to implement the public methods to be exposed +/// for the `UTForChunks` type alias +pub trait UTForChunkProofs { + type K: Clone + Debug + Eq + PartialEq + Hash; + + /// Get the keys of the children nodes in the update tree + /// of the node with key `node_key` + fn get_children_keys(&self, node_key: &Self::K) -> Vec; +} + +impl UTForChunkProofs for UTForChunks { + type K = UTKey; + + fn get_children_keys(&self, node_key: &Self::K) -> Vec { + (0..NUM_CHUNKS) + .filter_map(|i| { + // first, compute the child key for the i-th potential child + let child_key = node_key.children_key(i); + // then, return the computed key only if the i-th child exists in the tree + self.node_from_key(&child_key).map(|_| child_key) + }) + .collect_vec() + } +} + +/// Tree employed as a skeleton to build the `UTForChunks`, which is +/// employed to prove and aggregate row chunks. Each node in the tree corresponds +/// to a proof being generated: +/// - Leaf nodes are associated with the proving of a single row chunk +/// - Internal nodes are associated with the proving of the aggregation of multiple row chunks, +/// and so the ARITY of the tree corresponds to the maximum number of chunks that can be +/// aggregated in a single proof +/// +/// Given the number of leaves `n`, which corresponds to the number of chunks to be aggregated, +/// the tree is built in such a way as to minimize the number of internal nodes, thereby +/// minimizing the number of proofs to be generated. The overall idea is: +/// - Place as many leaves as possible in `full` subtrees. A full subtree is defined as +/// a subtree containing `ARITY^exp` leaves, for an `exp >= 0`. In particular, it is +/// always possible to build at least one full subtree with `exp = ceil(log_{ARITY}(n))-1`. +/// Note that, depending on `n`, it might be possible to build from one up to `ARITY` full +/// subtrees, each containing `ARITY^exp` number of leaves +/// - If there are leaves that cannot be placed inside a full subtree, then by construction +/// at most `ARITY-1` full subtrees have been built and placed as child nodes of the root, +/// and so there are still `m >= 1` spots available among the children of the root. +/// So, up to `m-1` remaining leaves are placed as direct children of the root; if there +/// are more than `m-1` remaining leaves, they are placed in a subtree, built +/// recursively using the same logic, which is placed as a further child of the root +/// +/// More details on the algorithm to construct a tree can be found in the `build_subtree` +/// method +#[derive(Clone, Debug)] +struct ProvingTree { + // all the nodes of the tree, indexed by the key of the node + nodes: HashMap, ProvingTreeNode>, + // leaves of the tree, identified by their key. The leaves are inserted in + // this vector in order (i.e., from left to right in the tree) when building + // the tree.
+ // The position of a leaf in this vector is referred to as + // `leaf_index` + leaves: Vec>, +} + +/// Node of the proving tree, containing the keys of the parent node and +/// of the children +#[derive(Clone, Debug)] +struct ProvingTreeNode { + parent_key: Option>, + children_keys: Vec>, +} + +impl ProvingTree { + /// Build a new `ProvingTree` with `num_leaves` leaf nodes + fn new(num_leaves: usize) -> Self { + let mut tree = ProvingTree { + nodes: HashMap::new(), + leaves: vec![], + }; + if num_leaves > 0 { + // build a subtree for `num_leaves` + tree.build_subtree(num_leaves, None); + } + + tree + } + + /// Insert a node as a child node of the node with key `parent_node_key`. + /// The node is inserted as root if `parent_node_key` is `None` + fn insert_as_child_of(&mut self, parent_node_key: Option<&UTKey>) -> UTKey { + if let Some(parent_key) = parent_node_key { + // get parent node + let parent_node = self.nodes.get_mut(parent_key).unwrap_or_else(|| { + panic!( + "Providing a non-existing parent key for insertion: {:?}", + parent_key + ) + }); + // get number of existing children for the parent node, which is needed to compute + // the key of the child to be inserted + let num_children = parent_node.children_keys.len(); + let new_child_key = parent_key.children_key(num_children); + let child_node = ProvingTreeNode { + parent_key: Some(*parent_key), + children_keys: vec![], + }; + // insert new child in the set of children of the parent + parent_node.children_keys.push(new_child_key); + assert!( + self.nodes.insert(new_child_key, child_node).is_none(), + "Node with key {:?} already found in the tree", + new_child_key + ); + new_child_key + } else { + // insert as root + let root = ProvingTreeNode { + parent_key: None, + children_keys: vec![], + }; + let root_key = UTKey((0, 0)); + assert!( + self.nodes.insert(root_key, root).is_none(), + "Error: root node inserted multiple times" + ); + root_key + } + } + + /// Build a full subtree containing `num_leaves` leaf nodes. The subtree + /// is full since `num_leaves` is expected to be `ARITY^exp`, for `exp >= 0`. + /// `parent_node_key` is the key of the parent node of the root of the subtree + fn build_full_subtree(&mut self, num_leaves: usize, parent_node_key: &UTKey) { + let root_key = self.insert_as_child_of(Some(parent_node_key)); + if num_leaves > 1 { + for _ in 0..ARITY { + self.build_full_subtree(num_leaves / ARITY, &root_key); + } + } else { + // it's a leaf node, so we add it to leaves + self.leaves.push(root_key); + } + } + + /// Build a subtree containing `num_leaves` leaf nodes. + /// `parent_node_key` is the key of the parent node of the root of the subtree, if any + fn build_subtree(&mut self, num_leaves: usize, parent_node_key: Option<&UTKey>) { + let root_key = self.insert_as_child_of(parent_node_key); + if num_leaves == 1 { + // we are done, we just add the root node as a leaf + return self.leaves.push(root_key); + } + // we compute the number of full subtrees we can employ to place leaves. + // A full subtree is a subtree that contains ARITY^exp leaves for some exp >= 0. + // Given `num_leaves`, we know we can always build at least 1 full subtree + // for `exp = ceil(log_{ARITY}(num_leaves))-1`, i.e., a full subtree + // containing `ARITY^exp` leaves.
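+        // Illustrative walk-through (assuming ARITY = 4 and num_leaves = 10): the
+        // smallest power of ARITY >= 10 is 16, so num_leaves_in_subtree = 16 / 4 = 4
+        // and num_full_subtrees = 10 / 4 = 2; those 2 full subtrees hold 8 leaves,
+        // and the 2 remaining leaves fit in the 2 slots left under the root.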
+ let num_leaves_in_subtree = smallest_greater_power::(num_leaves) / ARITY; + let num_full_subtrees = num_leaves / num_leaves_in_subtree; + for _ in 0..num_full_subtrees { + self.build_full_subtree(num_leaves_in_subtree, &root_key); + } + // overall number of leaves placed in the `num_full_subtrees` full subtrees + let inserted_leaves = num_leaves_in_subtree * num_full_subtrees; + let remaining_leaves = num_leaves - inserted_leaves; + // number of nodes still available at the current level + let available_nodes_at_level = ARITY - num_full_subtrees; + // compute the number of leaves to be placed at the current level + let num_leaves_at_level = if remaining_leaves > available_nodes_at_level { + // we place `available_nodes_at_level - 1` leaf nodes in this level, + // while all the other remaining nodes are placed in a subtree, which is built + // recursively + let num_leaves_at_level = available_nodes_at_level - 1; + self.build_subtree(remaining_leaves - num_leaves_at_level, Some(&root_key)); + num_leaves_at_level + } else { + // we can place all remaining nodes at this level as leaf nodes + remaining_leaves + }; + // place the leaves at the current level + for _ in 0..num_leaves_at_level { + let leaf_key = self.insert_as_child_of(Some(&root_key)); + self.leaves.push(leaf_key); + } + } + + /// Compute the path, from root to leaf, for the leaf with index `leaf_index` + /// in `self` tree + fn compute_path_for_leaf(&self, leaf_index: usize) -> Vec> { + let leaf_key = &self.leaves[leaf_index]; + let mut path = vec![]; + let mut node_key = Some(leaf_key); + while node_key.is_some() { + // place node key in the path + let key = node_key.unwrap(); + path.push(*key); + // fetch key of the parent node, if any + node_key = self + .nodes + .get(key) + .unwrap_or_else(|| panic!("Node with key {:?} not found", key)) + .parent_key + .as_ref(); + } + + path.reverse(); + path + } +} + +impl UTForChunksBuilder { + /// This method builds an `UpdateTree` to prove and aggregate the set of chunks + /// provided as input. It also returns the set of chunks to be proven, with each + /// chunk being associated with the key of the node in the `UpdateTree` corresponding + /// to the proving task for that chunk + fn build_update_tree_with_base_chunks( + self, + epoch: Epoch, + ) -> ( + HashMap, Vec>, + UTForChunks, + ) { + let num_chunks = self.chunks.len(); + let tree = ProvingTree::::new(num_chunks); + let (chunks_with_keys, paths): (HashMap<_, _>, Vec<_>) = self + .chunks + .into_iter() + .enumerate() + .map(|(node_index, chunk)| { + let path = tree.compute_path_for_leaf(node_index); + ( + ( + *path.last().unwrap(), // chunk node is always a leaf of the tree, so it is the last node + // in the path + chunk, + ), + path, + ) + }) + .unzip(); + (chunks_with_keys, UpdateTree::from_paths(paths, epoch)) + } +} + +// Method to compute the smallest power of `BASE` greater than or equal to the provided `input`.
+// In other words, it computes `BASE^ceil(log_BASE(input))` +fn smallest_greater_power(input: usize) -> usize { + let mut pow = 1usize; + while pow < input { + pow *= BASE; + } + pow +} diff --git a/mp2-v1/src/query/mod.rs b/mp2-v1/src/query/mod.rs index 5e480858b..3a32e4610 100644 --- a/mp2-v1/src/query/mod.rs +++ b/mp2-v1/src/query/mod.rs @@ -1 +1,2 @@ +pub mod batching_planner; pub mod planner; diff --git a/mp2-v1/src/query/planner.rs b/mp2-v1/src/query/planner.rs index 99be0419b..7426da087 100644 --- a/mp2-v1/src/query/planner.rs +++ b/mp2-v1/src/query/planner.rs @@ -2,6 +2,7 @@ use alloy::primitives::U256; use anyhow::Context; use bb8::Pool; use bb8_postgres::PostgresConnectionManager; +use core::hash::Hash; use futures::stream::TryStreamExt; use itertools::Itertools; use mp2_common::types::HashOutput; @@ -10,14 +11,17 @@ use ryhope::{ storage::{ pgsql::{PgsqlStorage, ToFromBytea}, updatetree::UpdateTree, - FromSettings, PayloadStorage, TransactionalStorage, TreeStorage, + FromSettings, PayloadStorage, TransactionalStorage, TreeStorage, WideLineage, }, tree::{MutableTree, NodeContext, TreeTopology}, Epoch, MerkleTreeKvDb, NodePayload, }; -use std::fmt::Debug; +use std::{fmt::Debug, future::Future}; use tokio_postgres::{row::Row as PsqlRow, types::ToSql, NoTls}; -use verifiable_db::query::aggregation::{NodeInfo, QueryBounds}; +use verifiable_db::query::{ + api::TreePathInputs, + utils::{ChildPosition, NodeInfo, QueryBounds}, +}; use crate::indexing::{ block::BlockPrimaryIndex, @@ -35,50 +39,370 @@ pub struct NonExistenceInfo { pub proving_plan: UpdateTree, } -/// Returns the proving plan to prove the non existence of node of the query in this row tree at -/// the epoch primary. It also returns the leaf node chosen. -/// -/// The row tree is given and specialized to psql storage since that is the only official storage -/// supported. -/// The `table_name` must be the one given to parsil settings, it is the human friendly table -/// name, i.e. the vTable name. -/// The pool is to issue specific query -/// Primary is indicating the primary index over which this row tree is looked at. -/// Settings are the parsil settings corresponding to the current SQL and current table looked at. -/// Pis contain the bounds and placeholders values. -/// TODO: we should extend ryhope to offer this API directly on the tree since it's very related.
-pub async fn find_row_node_for_non_existence( - row_tree: &MerkleTreeKvDb, DBRowStorage>, - table_name: String, - pool: &DBPool, - primary: BlockPrimaryIndex, - settings: &ParsilSettings, - bounds: &QueryBounds, -) -> anyhow::Result<(RowTreeKey, UpdateTree)> +#[derive(Clone)] +pub struct NonExistenceInput<'a, C: ContextProvider> { + pub(crate) row_tree: &'a MerkleTreeKvDb, DBRowStorage>, + pub(crate) table_name: String, + pub(crate) pool: &'a DBPool, + pub(crate) settings: &'a ParsilSettings, + pub(crate) bounds: QueryBounds, +} + +impl<'a, C: ContextProvider> NonExistenceInput<'a, C> { + pub fn new( + row_tree: &'a MerkleTreeKvDb, DBRowStorage>, + table_name: String, + pool: &'a DBPool, + settings: &'a ParsilSettings, + bounds: &'a QueryBounds, + ) -> Self { + Self { + row_tree, + table_name, + pool, + settings, + bounds: bounds.clone(), + } + } + + pub async fn find_row_node_for_non_existence( + &self, + primary: BlockPrimaryIndex, + ) -> anyhow::Result { + let (query_for_min, query_for_max) = bracket_secondary_index( + &self.table_name, + self.settings, + primary as Epoch, + &self.bounds, + ); + + // try first with lower node than secondary min query bound + let to_be_proven_node = + match find_node_for_proof(self.pool, self.row_tree, query_for_min, primary, true) + .await? + { + Some(node) => node, + None => { + find_node_for_proof(self.pool, self.row_tree, query_for_max, primary, false) + .await? + .expect("No valid node found to prove non-existence, something is wrong") + } + }; + + Ok(to_be_proven_node) + } +} + +pub trait TreeFetcher: Sized { + /// Constant flag specifying whether the implementor is a `WideLineage` or not + const IS_WIDE_LINEAGE: bool; + + fn fetch_ctx_and_payload_at( + &self, + k: &K, + epoch: Epoch, + ) -> impl Future, V)>> + Send; + + fn compute_path( + &self, + node_key: &K, + epoch: Epoch, + ) -> impl Future> { + async move { + let (node_ctx, node_payload) = self.fetch_ctx_and_payload_at(node_key, epoch).await?; + let mut current_node_key = node_ctx.parent.clone(); + let mut previous_node_key = node_key.clone(); + let mut path = vec![]; + while current_node_key.is_some() { + let (ctx, payload) = self + .fetch_ctx_and_payload_at(current_node_key.as_ref().unwrap(), epoch) + .await + .unwrap_or_else(|| { + panic!("node with key {:?} not found in tree", current_node_key) + }); + let child_position = match ctx + .iter_children() + .find_position(|child| { + child.is_some() && child.unwrap().clone() == previous_node_key + }) + .unwrap() + .0 + { + 0 => ChildPosition::Left, + 1 => ChildPosition::Right, + _ => unreachable!(), + }; + previous_node_key = current_node_key.unwrap(); + current_node_key = ctx.parent.clone(); + let node_info = self.compute_node_info(ctx, payload, epoch).await; + path.push((node_info, child_position)); + } + let (node_info, left_child, right_child) = + get_node_info_from_ctx_and_payload(self, node_ctx, node_payload, epoch).await; + + Some(TreePathInputs::new( + node_info, + path, + [left_child, right_child], + )) + } + } + + fn compute_node_info( + &self, + node_ctx: NodeContext, + node_payload: V, + at: Epoch, + ) -> impl Future { + async move { + let child_hash = async |k: Option| -> Option { + match k { + Some(child_key) => self + .fetch_ctx_and_payload_at(&child_key, at) + .await + .map(|(_ctx, payload)| payload.hash()), + None => None, + } + }; + + let left_child_hash = child_hash(node_ctx.left).await; + let right_child_hash = child_hash(node_ctx.right).await; + NodeInfo::new( + &node_payload.embedded_hash(), + left_child_hash.as_ref(), 
+ right_child_hash.as_ref(), + node_payload.value(), + node_payload.min(), + node_payload.max(), + ) + } + } + + /// This method computes the successor of the node with context `node_ctx` in the input `tree` + /// at the given `epoch`. It returns the context of the successor node and its payload + fn get_successor( + &self, + node_ctx: &NodeContext, + epoch: Epoch, + ) -> impl Future, V)>> + where + K: Clone + Debug + Eq + PartialEq, + { + async move { + if node_ctx.right.is_some() { + if let Some((right_child_ctx, right_child_payload)) = + fetch_existing_node_from_tree(self, node_ctx.right.as_ref().unwrap(), epoch) + .await + { + // find successor in the subtree rooted in the right child: it is + // the leftmost node in such a subtree + let (mut successor_ctx, mut successor_payload) = + (right_child_ctx, right_child_payload); + while successor_ctx.left.is_some() { + let Some((ctx, payload)) = fetch_existing_node_from_tree( + self, + successor_ctx.left.as_ref().unwrap(), + epoch, + ) + .await + else { + // we didn't find the left child node in the tree, which means that the + // successor might be out of range, so we return None + return None; + }; + successor_ctx = ctx; + successor_payload = payload; + } + Some((successor_ctx, successor_payload)) + } else { + // we didn't find the right child node in the tree, which means that the + // successor might be out of range, so we return None + return None; + } + } else { + // find successor among the ancestors of current node: we go up in the path + // until we either find a node whose left child is the previous node in the + // path, or we get to the root of the tree + let mut candidate_successor_ctx = node_ctx.clone(); + let mut successor = None; + while candidate_successor_ctx.parent.is_some() { + let (parent_ctx, parent_payload) = self + .fetch_ctx_and_payload_at( + candidate_successor_ctx.parent.as_ref().unwrap(), + epoch, + ) + .await + .unwrap_or_else(|| { + panic!( + "Node context not found for parent of node {:?}", + candidate_successor_ctx.node_id + ) + }); + if parent_ctx + .iter_children() + .find_position(|child| { + child.is_some() + && child.unwrap().clone() == candidate_successor_ctx.node_id + }) + .unwrap() + .0 + == 0 + { + // successor_ctx.node_id is left child of parent_ctx node, so parent_ctx is + // the successor + successor = Some((parent_ctx, parent_payload)); + break; + } else { + candidate_successor_ctx = parent_ctx; + } + } + successor + } + } + } + + fn get_predecessor( + &self, + node_ctx: &NodeContext, + epoch: Epoch, + ) -> impl Future, V)>> + where + K: Clone + Debug + Eq + PartialEq, + { + async move { + if node_ctx.left.is_some() { + if let Some((left_child_ctx, left_child_payload)) = + fetch_existing_node_from_tree(self, node_ctx.left.as_ref().unwrap(), epoch) + .await + { + // find predecessor in the subtree rooted in the left child: it is + // the rightmost node in such a subtree + let (mut predecessor_ctx, mut predecessor_payload) = + (left_child_ctx, left_child_payload); + while predecessor_ctx.right.is_some() { + let Some((ctx, payload)) = fetch_existing_node_from_tree( + self, + predecessor_ctx.right.as_ref().unwrap(), + epoch, + ) + .await + else { + // we didn't find the right child node in the tree, which means that the + // predecessor might be out of range, so we return None + return None; + }; + predecessor_ctx = ctx; + predecessor_payload = payload; + } + Some((predecessor_ctx, predecessor_payload)) + } else { + // we didn't find the left child node in the tree, which means that the + //
 predecessor might be out of range, so we return None + return None; + } + } else { + // find predecessor among the ancestors of current node: we go up in the path + // until we either find a node whose right child is the previous node in the + // path, or we get to the root of the tree + let mut candidate_predecessor_ctx = node_ctx.clone(); + let mut predecessor = None; + while candidate_predecessor_ctx.parent.is_some() { + let (parent_ctx, parent_payload) = self + .fetch_ctx_and_payload_at( + candidate_predecessor_ctx.parent.as_ref().unwrap(), + epoch, + ) + .await + .unwrap_or_else(|| { + panic!( + "Node context not found for parent of node {:?}", + candidate_predecessor_ctx.node_id + ) + }); + if parent_ctx + .iter_children() + .find_position(|child| { + child.is_some() + && child.unwrap().clone() == candidate_predecessor_ctx.node_id + }) + .unwrap() + .0 + == 1 + { + // predecessor_ctx.node_id is right child of parent_ctx node, so parent_ctx is + // the predecessor + predecessor = Some((parent_ctx, parent_payload)); + break; + } else { + candidate_predecessor_ctx = parent_ctx; + } + } + predecessor + } + } + } +} + +impl TreeFetcher for WideLineage where - C: ContextProvider, + K: Debug + Hash + Eq + Clone + Sync + Send, { - let (query_for_min, query_for_max) = - bracket_secondary_index(&table_name, settings, primary as Epoch, bounds); + const IS_WIDE_LINEAGE: bool = true; - // try first with lower node than secondary min query bound - let to_be_proven_node = - match find_node_for_proof(pool, row_tree, query_for_min, primary, true).await? { - Some(node) => node, - None => find_node_for_proof(pool, row_tree, query_for_max, primary, false) - .await? - .expect("No valid node found to prove non-existence, something is wrong"), - }; + async fn fetch_ctx_and_payload_at(&self, k: &K, epoch: Epoch) -> Option<(NodeContext, V)> { + self.ctx_and_payload_at(epoch, k) + } +} - let path = row_tree - // since the epoch starts at genesis we can directly give the block number ! - .lineage_at(&to_be_proven_node, primary as Epoch) - .await? - .expect("node doesn't have a lineage?") - .into_full_path() - .collect_vec(); - let proving_tree = UpdateTree::from_paths([path], primary as Epoch); - Ok((to_be_proven_node.clone(), proving_tree)) +impl< + V: NodePayload + Send + Sync + LagrangeNode, + T: TreeTopology + MutableTree + 'static, + S: TransactionalStorage + + TreeStorage + + PayloadStorage + + FromSettings + + 'static, + > TreeFetcher for MerkleTreeKvDb +{ + const IS_WIDE_LINEAGE: bool = false; + + async fn fetch_ctx_and_payload_at( + &self, + k: &T::Key, + epoch: Epoch, + ) -> Option<(NodeContext, V)> { + self.try_fetch_with_context_at(k, epoch) + .await + .expect("Failed to fetch context") + } +} + +/// Fetch a key `k` from a tree, assuming that the key is in the +/// tree.
Therefore, it handles the case when `k` is not found differently: +/// - If `T::IS_WIDE_LINEAGE` is true, then `k` might not be found because the +/// node associated with key `k` is in the tree, but not in the lineage +/// - Otherwise, it panics because it's not expected to happen, as we assume +/// this method is called only on keys which are in the tree +async fn fetch_existing_node_from_tree>( + tree: &T, + k: &K, + epoch: Epoch, +) -> Option<(NodeContext, V)> +where + K: Clone + Debug + Eq + PartialEq, +{ + if T::IS_WIDE_LINEAGE { + // we simply return the result, since in case of `WideLineage` + // fetching might fail because the node was not in the lineage + tree.fetch_ctx_and_payload_at(k, epoch).await + } else { + // Otherwise, we are fetching from an entire tree, so the node must exist; + // panic if it doesn't + Some( + tree.fetch_ctx_and_payload_at(k, epoch) + .await + .unwrap_or_else(|| panic!("Node context not found for node {:?}", k)), + ) + } } // this method returns the `NodeContext` of the successor of the node provided as input, @@ -89,81 +413,19 @@ async fn get_successor_node_with_same_value( node_ctx: &NodeContext, value: U256, primary: BlockPrimaryIndex, -) -> anyhow::Result>> { - if node_ctx.right.is_some() { - let (right_child_ctx, payload) = row_tree - .fetch_with_context_at(node_ctx.right.as_ref().unwrap(), primary as Epoch) - .await? - .expect("right is checked to be Some"); - // the value of the successor in this case is `payload.min`, since the successor is the - // minimum of the subtree rooted in the right child - if payload.min() != value { - // the value of successor is different from `value`, so we don't return the - // successor node - return Ok(None); - } - // find successor in the subtree rooted in the right child: it is - // the leftmost node in such a subtree - let mut successor_ctx = right_child_ctx; - while successor_ctx.left.is_some() { - successor_ctx = row_tree - .node_context_at(successor_ctx.left.as_ref().unwrap(), primary as Epoch) - .await? - .unwrap_or_else(|| { - panic!( - "Node context not found for left child of node {:?}", - successor_ctx.node_id - ) - }); - } - Ok(Some(successor_ctx)) - } else { - // find successor among the ancestors of current node: we go up in the path - // until we either found a node whose left child is the previous node in the - // path, or we get to the root of the tree - let (mut candidate_successor_ctx, mut candidate_successor_val) = (node_ctx.clone(), value); - let mut successor_found = false; - while candidate_successor_ctx.parent.is_some() { - let (parent_ctx, parent_payload) = row_tree - .fetch_with_context_at( - candidate_successor_ctx.parent.as_ref().unwrap(), - primary as Epoch, - ) - .await?
- .unwrap(); - candidate_successor_val = parent_payload.value(); - if parent_ctx - .iter_children() - .find_position(|child| { - child.is_some() && child.unwrap().clone() == candidate_successor_ctx.node_id - }) - .unwrap() - .0 - == 0 - { - // successor_ctx.node_id is left child of parent_ctx node, so parent_ctx is - // the successor - candidate_successor_ctx = parent_ctx; - successor_found = true; - break; - } else { - candidate_successor_ctx = parent_ctx; - } - } - if successor_found { - if candidate_successor_val != value { +) -> Option> { + row_tree + .get_successor(node_ctx, primary as Epoch) + .await + .and_then(|(successor_ctx, successor_payload)| { + if successor_payload.value() != value { // the value of successor is different from `value`, so we don't return the // successor node - return Ok(None); + None + } else { + Some(successor_ctx) } - Ok(Some(candidate_successor_ctx)) - } else { - // We got up to the root of the tree without finding the successor, - // which means that the input node has no successor; - // so we don't return any node - Ok(None) - } - } + }) } // this method returns the `NodeContext` of the predecessor of the node provided as input, @@ -174,82 +436,19 @@ async fn get_predecessor_node_with_same_value( node_ctx: &NodeContext, value: U256, primary: BlockPrimaryIndex, -) -> anyhow::Result>> { - if node_ctx.left.is_some() { - let (left_child_ctx, payload) = row_tree - .fetch_with_context_at(node_ctx.right.as_ref().unwrap(), primary as Epoch) - .await? - .expect("left is checked to be Some"); - // the value of the predecessor in this case is `payload.max`, since the predecessor is the - // maximum of the subtree rooted in the left child - if payload.max() != value { - // the value of predecessor is different from `value`, so we don't return the - // predecessor node - return Ok(None); - } - // find predecessor in the subtree rooted in the left child: it is - // the rightmost node in such a subtree - let mut predecessor_ctx = left_child_ctx; - while predecessor_ctx.right.is_some() { - predecessor_ctx = row_tree - .node_context_at(predecessor_ctx.right.as_ref().unwrap(), primary as Epoch) - .await? - .unwrap_or_else(|| { - panic!( - "Node context not found for right child of node {:?}", - predecessor_ctx.node_id - ) - }); - } - Ok(Some(predecessor_ctx)) - } else { - // find successor among the ancestors of current node: we go up in the path - // until we either found a node whose right child is the previous node in the - // path, or we get to the root of the tree - let (mut candidate_predecessor_ctx, mut candidate_predecessor_val) = - (node_ctx.clone(), value); - let mut predecessor_found = false; - while candidate_predecessor_ctx.parent.is_some() { - let (parent_ctx, parent_payload) = row_tree - .fetch_with_context_at( - candidate_predecessor_ctx.parent.as_ref().unwrap(), - primary as Epoch, - ) - .await? 
- .unwrap(); - candidate_predecessor_val = parent_payload.value(); - if parent_ctx - .iter_children() - .find_position(|child| { - child.is_some() && child.unwrap().clone() == candidate_predecessor_ctx.node_id - }) - .unwrap() - .0 - == 1 - { - // predecessor_ctx.node_id is right child of parent_ctx node, so parent_ctx is - // the predecessor - candidate_predecessor_ctx = parent_ctx; - predecessor_found = true; - break; +) -> Option> { + row_tree + .get_predecessor(node_ctx, primary as Epoch) + .await + .and_then(|(predecessor_ctx, predecessor_payload)| { + if predecessor_payload.value() != value { + // the value of predecessor is different from `value`, so we don't return the + // predecessor node + None } else { - candidate_predecessor_ctx = parent_ctx; + Some(predecessor_ctx) } - } - if predecessor_found { - if candidate_predecessor_val != value { - // the value of predecessor is different from `value`, so we don't return the - // predecessor node - return Ok(None); - } - Ok(Some(candidate_predecessor_ctx)) - } else { - // We got up to the root of the tree without finding the predecessor, - // which means that the input node has no predecessor; - // so we don't return any node - Ok(None) - } - } + }) } async fn find_node_for_proof( @@ -302,11 +501,11 @@ async fn find_node_for_proof( // from the value `value` stored in the node with key `row_key`; the node found is the one to be // employed to generate the non-existence proof let mut successor_ctx = - get_successor_node_with_same_value(row_tree, &node_ctx, value, primary).await?; + get_successor_node_with_same_value(row_tree, &node_ctx, value, primary).await; while successor_ctx.is_some() { node_ctx = successor_ctx.unwrap(); successor_ctx = - get_successor_node_with_same_value(row_tree, &node_ctx, value, primary).await?; + get_successor_node_with_same_value(row_tree, &node_ctx, value, primary).await; } } else { // starting from the node with key `row_key`, we iterate over its predecessor nodes in the tree, @@ -314,11 +513,11 @@ async fn find_node_for_proof( // from the value `value` stored in the node with key `row_key`; the node found is the one to be // employed to generate the non-existence proof let mut predecessor_ctx = - get_predecessor_node_with_same_value(row_tree, &node_ctx, value, primary).await?; + get_predecessor_node_with_same_value(row_tree, &node_ctx, value, primary).await; while predecessor_ctx.is_some() { node_ctx = predecessor_ctx.unwrap(); predecessor_ctx = - get_predecessor_node_with_same_value(row_tree, &node_ctx, value, primary).await?; + get_predecessor_node_with_same_value(row_tree, &node_ctx, value, primary).await; } } @@ -362,70 +561,67 @@ pub async fn execute_row_query( Ok(rows) } -pub async fn get_node_info( - lookup: &MerkleTreeKvDb, - k: &T::Key, +async fn get_node_info_from_ctx_and_payload< + K: Debug + Clone + Eq + PartialEq, + V: LagrangeNode, + T: TreeFetcher, +>( + tree: &T, + node_ctx: NodeContext, + node_payload: V, at: Epoch, -) -> anyhow::Result<(NodeInfo, Option, Option)> -where - T: TreeTopology + MutableTree + Send, - V: NodePayload + Send + Sync + LagrangeNode, - S: TransactionalStorage + TreeStorage + PayloadStorage + FromSettings, - T::Key: Debug, -{ - // look at the left child first then right child, then build the node info - let (ctx, node_payload) = lookup - .try_fetch_with_context_at(k, at) - .await?
- .expect("cache not filled"); +) -> (NodeInfo, Option, Option) { // this looks at the value of a child node (left and right), and fetches the grandchildren // information to be able to build their respective node info. - let fetch_ni = - async |k: Option| -> anyhow::Result<(Option, Option)> { - Ok(match k { - None => (None, None), - Some(child_k) => { - let (child_ctx, child_payload) = lookup - .try_fetch_with_context_at(&child_k, at) - .await? - .expect("cache not filled"); - // we need the grand child hashes for constructing the node info of the - // children of the node in argument - let child_left_hash = match child_ctx.left { - Some(left_left_k) => { - let (_, payload) = lookup - .try_fetch_with_context_at(&left_left_k, at) - .await? - .expect("cache not filled"); - Some(payload.hash()) - } - None => None, - }; - let child_right_hash = match child_ctx.right { - Some(left_right_k) => { - let (_, payload) = lookup - .try_fetch_with_context_at(&left_right_k, at) - .await? - .expect("cache not full"); - Some(payload.hash()) - } - None => None, - }; - let left_ni = NodeInfo::new( - &child_payload.embedded_hash(), - child_left_hash.as_ref(), - child_right_hash.as_ref(), - child_payload.value(), - child_payload.min(), - child_payload.max(), - ); - (Some(left_ni), Some(child_payload.hash())) - } - }) - }; - let (left_node, left_hash) = fetch_ni(ctx.left).await?; - let (right_node, right_hash) = fetch_ni(ctx.right).await?; - Ok(( + let fetch_ni = async |k: Option| -> (Option, Option) { + match k { + None => (None, None), + Some(child_k) => { + let (child_ctx, child_payload) = tree + .fetch_ctx_and_payload_at(&child_k, at) + .await + .unwrap_or_else(|| panic!("key {:?} not found in the tree", child_k)); + // we need the grand child hashes for constructing the node info of the + // children of the node in argument + let child_left_hash = match child_ctx.left { + Some(left_left_k) => { + let (_, payload) = tree + .fetch_ctx_and_payload_at(&left_left_k, at) + .await + .unwrap_or_else(|| { + panic!("key {:?} not found in the tree", left_left_k) + }); + Some(payload.hash()) + } + None => None, + }; + let child_right_hash = match child_ctx.right { + Some(left_right_k) => { + let (_, payload) = tree + .fetch_ctx_and_payload_at(&left_right_k, at) + .await + .unwrap_or_else(|| { + panic!("key {:?} not found in the tree", left_right_k) + }); + Some(payload.hash()) + } + None => None, + }; + let left_ni = NodeInfo::new( + &child_payload.embedded_hash(), + child_left_hash.as_ref(), + child_right_hash.as_ref(), + child_payload.value(), + child_payload.min(), + child_payload.max(), + ); + (Some(left_ni), Some(child_payload.hash())) + } + } + }; + let (left_node, left_hash) = fetch_ni(node_ctx.left).await; + let (right_node, right_hash) = fetch_ni(node_ctx.right).await; + ( NodeInfo::new( &node_payload.embedded_hash(), left_hash.as_ref(), @@ -436,5 +632,21 @@ where ), left_node, right_node, - )) + ) +} + +pub async fn get_node_info< + K: Debug + Clone + Eq + PartialEq, + V: LagrangeNode, + T: TreeFetcher, +>( + tree: &T, + k: &K, + at: Epoch, +) -> (NodeInfo, Option, Option) { + let (node_ctx, node_payload) = tree + .fetch_ctx_and_payload_at(k, at) + .await + .unwrap_or_else(|| panic!("key {:?} not found in the tree", k)); + get_node_info_from_ctx_and_payload(tree, node_ctx, node_payload, at).await } diff --git a/mp2-v1/tests/common/cases/mod.rs b/mp2-v1/tests/common/cases/mod.rs index c6445467e..991c2eef8 100644 --- a/mp2-v1/tests/common/cases/mod.rs +++ b/mp2-v1/tests/common/cases/mod.rs @@ -11,7 +11,6 @@ 
use super::table::Table; pub mod contract; pub mod indexing; -pub mod planner; pub mod query; pub mod table_source; diff --git a/mp2-v1/tests/common/cases/planner.rs b/mp2-v1/tests/common/cases/planner.rs deleted file mode 100644 index 7a45d551b..000000000 --- a/mp2-v1/tests/common/cases/planner.rs +++ /dev/null @@ -1,423 +0,0 @@ -use std::{collections::HashSet, future::Future}; - -use log::info; -use mp2_v1::indexing::{ - block::BlockPrimaryIndex, - index::IndexNode, - row::{RowPayload, RowTreeKey}, -}; -use parsil::{assembler::DynamicCircuitPis, ParsilSettings}; -use ryhope::{storage::WideLineage, tree::NodeContext, Epoch}; - -use crate::common::{ - cases::query::aggregated_queries::prove_non_existence_row, - index_tree::MerkleIndexTree, - proof_storage::{PlaceholderValues, ProofKey, ProofStorage, QueryID}, - rowtree::MerkleRowTree, - table::{Table, TableColumns}, - TestContext, -}; - -use super::query::{aggregated_queries::prove_single_row, QueryCooking}; - -pub(crate) struct QueryPlanner<'a> { - pub(crate) query: QueryCooking, - pub(crate) pis: &'a DynamicCircuitPis, - pub(crate) ctx: &'a mut TestContext, - pub(crate) settings: &'a ParsilSettings<&'a Table>, - // useful for non existence since we need to search in both trees the places to prove - // the fact a given node doesn't exist - pub(crate) table: &'a Table, - pub(crate) columns: TableColumns, -} - -pub trait TreeInfo { - fn is_row_tree(&self) -> bool; - fn is_satisfying_query(&self, k: &K) -> bool; - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &K, - placeholder_values: PlaceholderValues, - ) -> anyhow::Result>; - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &K, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> anyhow::Result<()>; - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &K, - v: &V, - ) -> anyhow::Result>>; - - fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &K, - ) -> impl Future, V)>> + Send; -} - -impl TreeInfo> - for WideLineage> -{ - fn is_row_tree(&self) -> bool { - true - } - - fn is_satisfying_query(&self, k: &RowTreeKey) -> bool { - self.is_touched_key(k) - } - - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &RowTreeKey, - placeholder_values: PlaceholderValues, - ) -> anyhow::Result> { - // TODO export that in single function - let proof_key = ProofKey::QueryAggregateRow(( - query_id.clone(), - placeholder_values, - primary, - key.clone(), - )); - ctx.storage.get_proof_exact(&proof_key) - } - - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &RowTreeKey, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> anyhow::Result<()> { - // TODO export that in single function - let proof_key = ProofKey::QueryAggregateRow(( - query_id.clone(), - placeholder_values, - primary, - key.clone(), - )); - ctx.storage.store_proof(proof_key, proof) - } - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &RowTreeKey, - _v: &RowPayload, - ) -> anyhow::Result>> { - // TODO export that in single function - Ok(if self.is_satisfying_query(k) { - let ctx = &mut planner.ctx; - Some( - prove_single_row( - ctx, - self, - &planner.columns, - primary, - k, - planner.pis, - &planner.query, - ) - .await?, - ) - } else { - None - }) - } - - async 
fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &RowTreeKey, - ) -> Option<(NodeContext, RowPayload)> { - self.ctx_and_payload_at(epoch, key) - } -} - -pub struct RowInfo<'a> { - pub(crate) satisfiying_rows: HashSet, - pub(crate) tree: &'a MerkleRowTree, -} - -impl<'a> RowInfo<'a> { - pub fn no_satisfying_rows(tree: &'a MerkleRowTree) -> Self { - Self { - satisfiying_rows: Default::default(), - tree, - } - } -} - -impl TreeInfo> for RowInfo<'_> { - fn is_row_tree(&self) -> bool { - true - } - - fn is_satisfying_query(&self, k: &RowTreeKey) -> bool { - self.satisfiying_rows.contains(k) - } - - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &RowTreeKey, - placeholder_values: PlaceholderValues, - ) -> anyhow::Result> { - let proof_key = ProofKey::QueryAggregateRow(( - query_id.clone(), - placeholder_values, - primary, - key.clone(), - )); - ctx.storage.get_proof_exact(&proof_key) - } - - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &RowTreeKey, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> anyhow::Result<()> { - let proof_key = ProofKey::QueryAggregateRow(( - query_id.clone(), - placeholder_values, - primary, - key.clone(), - )); - ctx.storage.store_proof(proof_key, proof) - } - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &RowTreeKey, - _v: &RowPayload, - ) -> anyhow::Result>> { - Ok(if self.is_satisfying_query(k) { - let ctx = &mut planner.ctx; - Some( - prove_single_row( - ctx, - self, - &planner.columns, - primary, - k, - planner.pis, - &planner.query, - ) - .await?, - ) - } else { - None - }) - } - - async fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &RowTreeKey, - ) -> Option<(NodeContext, RowPayload)> { - self.tree - .try_fetch_with_context_at(key, epoch) - .await - .unwrap() - } -} - -impl TreeInfo> - for WideLineage> -{ - fn is_row_tree(&self) -> bool { - false - } - - fn is_satisfying_query(&self, k: &BlockPrimaryIndex) -> bool { - self.is_touched_key(k) - } - - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &BlockPrimaryIndex, - placeholder_values: PlaceholderValues, - ) -> anyhow::Result> { - // TODO export that in single function - repetition - info!("loading proof for {primary} -> {key:?}"); - let proof_key = ProofKey::QueryAggregateIndex((query_id.clone(), placeholder_values, *key)); - ctx.storage.get_proof_exact(&proof_key) - } - - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - _primary: BlockPrimaryIndex, - key: &BlockPrimaryIndex, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> anyhow::Result<()> { - // TODO export that in single function - let proof_key = ProofKey::QueryAggregateIndex((query_id.clone(), placeholder_values, *key)); - ctx.storage.store_proof(proof_key, proof) - } - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &BlockPrimaryIndex, - v: &IndexNode, - ) -> anyhow::Result>> { - load_or_prove_embedded_index(self, planner, primary, k, v).await - } - - async fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &BlockPrimaryIndex, - ) -> Option<(NodeContext, IndexNode)> { - self.ctx_and_payload_at(epoch, key) - } -} - -pub struct IndexInfo<'a> { - pub(crate) bounds: (BlockPrimaryIndex, BlockPrimaryIndex), - pub(crate) tree: &'a MerkleIndexTree, -} - -impl<'a> 
IndexInfo<'a> { - pub fn non_satisfying_info(tree: &'a MerkleIndexTree) -> Self { - Self { - // so it never returns true to is satisfying query - bounds: (BlockPrimaryIndex::MAX, BlockPrimaryIndex::MIN), - tree, - } - } -} - -impl TreeInfo> for IndexInfo<'_> { - fn is_row_tree(&self) -> bool { - false - } - - fn is_satisfying_query(&self, k: &BlockPrimaryIndex) -> bool { - self.bounds.0 <= *k && *k <= self.bounds.1 - } - - fn load_proof( - &self, - ctx: &TestContext, - query_id: &QueryID, - primary: BlockPrimaryIndex, - key: &BlockPrimaryIndex, - placeholder_values: PlaceholderValues, - ) -> anyhow::Result> { - //assert_eq!(primary, *key); - info!("loading proof for {primary} -> {key:?}"); - let proof_key = ProofKey::QueryAggregateIndex((query_id.clone(), placeholder_values, *key)); - ctx.storage.get_proof_exact(&proof_key) - } - - fn save_proof( - &self, - ctx: &mut TestContext, - query_id: &QueryID, - _primary: BlockPrimaryIndex, - key: &BlockPrimaryIndex, - placeholder_values: PlaceholderValues, - proof: Vec, - ) -> anyhow::Result<()> { - //assert_eq!(primary, *key); - let proof_key = ProofKey::QueryAggregateIndex((query_id.clone(), placeholder_values, *key)); - ctx.storage.store_proof(proof_key, proof) - } - - async fn load_or_prove_embedded( - &self, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &BlockPrimaryIndex, - v: &IndexNode, - ) -> anyhow::Result>> { - load_or_prove_embedded_index(self, planner, primary, k, v).await - } - - async fn fetch_ctx_and_payload_at( - &self, - epoch: Epoch, - key: &BlockPrimaryIndex, - ) -> Option<(NodeContext, IndexNode)> { - self.tree - .try_fetch_with_context_at(key, epoch) - .await - .unwrap() - } -} - -async fn load_or_prove_embedded_index< - T: TreeInfo>, ->( - info: &T, - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, - k: &BlockPrimaryIndex, - v: &IndexNode, -) -> anyhow::Result>> { - //assert_eq!(primary, *k); - info!("loading embedded proof for node {primary} -> {k:?}"); - Ok(if info.is_satisfying_query(k) { - // load the proof of the row root for this query, if it is already proven; - // otherwise, it means that there are no rows in the rows tree embedded in this - // node that satisfies the query bounds on secondary index, so we need to - // generate a non-existence proof for the row tree - let row_root_proof_key = ProofKey::QueryAggregateRow(( - planner.query.query.clone(), - planner.query.placeholders.placeholder_values(), - *k, - v.row_tree_root_key.clone(), - )); - let proof = match planner.ctx.storage.get_proof_exact(&row_root_proof_key) { - Ok(proof) => proof, - Err(_) => { - prove_non_existence_row(planner, *k).await?; - info!("non existence proved for {primary} -> {k:?}"); - // fetch again the generated proof - planner - .ctx - .storage - .get_proof_exact(&row_root_proof_key) - .unwrap_or_else(|_| { - panic!("non-existence root proof not found for key {row_root_proof_key:?}") - }) - } - }; - Some(proof) - } else { - None - }) -} diff --git a/mp2-v1/tests/common/cases/query/aggregated_queries.rs b/mp2-v1/tests/common/cases/query/aggregated_queries.rs index 31ab18d00..7af10123b 100644 --- a/mp2-v1/tests/common/cases/query/aggregated_queries.rs +++ b/mp2-v1/tests/common/cases/query/aggregated_queries.rs @@ -1,28 +1,23 @@ use plonky2::{ field::types::PrimeField64, hash::hash_types::HashOut, plonk::config::GenericHashOut, }; -use std::{ - collections::{HashMap, HashSet}, - fmt::Debug, - hash::Hash, -}; +use std::collections::HashMap; use crate::common::{ cases::{ indexing::BLOCK_COLUMN_NAME, - 
planner::{IndexInfo, QueryPlanner, RowInfo, TreeInfo}, - query::{QueryCooking, SqlReturn, SqlType}, + query::{QueryCooking, SqlReturn, SqlType, NUM_CHUNKS, NUM_ROWS}, table_source::BASE_VALUE, }, proof_storage::{ProofKey, ProofStorage}, rowtree::MerkleRowTree, - table::{Table, TableColumns}, + table::Table, TableInfo, }; use crate::context::TestContext; use alloy::primitives::U256; -use anyhow::bail; +use anyhow::{bail, Result}; use futures::{stream, FutureExt, StreamExt}; use itertools::Itertools; @@ -40,146 +35,133 @@ use mp2_v1::{ block::BlockPrimaryIndex, cell::MerkleCell, row::{Row, RowPayload, RowTreeKey}, - LagrangeNode, }, - query::planner::{execute_row_query, find_row_node_for_non_existence}, - values_extraction::identifier_block_column, + query::{ + batching_planner::{generate_chunks_and_update_tree, UTForChunkProofs, UTKey}, + planner::{execute_row_query, NonExistenceInput, TreeFetcher}, + }, }; use parsil::{ assembler::{DynamicCircuitPis, StaticCircuitPis}, queries::{core_keys_for_index_tree, core_keys_for_row_tree}, - ParsilSettings, DEFAULT_MAX_BLOCK_PLACEHOLDER, DEFAULT_MIN_BLOCK_PLACEHOLDER, + DEFAULT_MAX_BLOCK_PLACEHOLDER, DEFAULT_MIN_BLOCK_PLACEHOLDER, }; use ryhope::{ storage::{ - updatetree::{Next, UpdateTree, WorkplanItem}, + updatetree::{Next, WorkplanItem}, EpochKvStorage, RoEpochKvStorage, TreeTransactionalStorage, }, - tree::NodeContext, - Epoch, NodePayload, + Epoch, }; use sqlparser::ast::Query; use tokio_postgres::Row as PsqlRow; use verifiable_db::{ ivc::PublicInputs as IndexingPIS, query::{ - aggregation::{ChildPosition, NodeInfo, QueryHashNonExistenceCircuits, SubProof}, computational_hash_ids::{ColumnIDs, Identifiers}, - universal_circuit::universal_circuit_inputs::{ - ColumnCell, PlaceholderId, Placeholders, RowCells, - }, + universal_circuit::universal_circuit_inputs::{ColumnCell, PlaceholderId, Placeholders}, }, revelation::PublicInputs, }; use super::{ - GlobalCircuitInput, QueryCircuitInput, RevelationCircuitInput, MAX_NUM_COLUMNS, - MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, + GlobalCircuitInput, QueryCircuitInput, QueryPlanner, RevelationCircuitInput, + MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, MAX_NUM_PLACEHOLDERS, }; pub type RevelationPublicInputs<'a> = PublicInputs<'a, F, MAX_NUM_OUTPUTS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_PLACEHOLDERS>; /// Execute a query to know all the touched rows, and then call the universal circuit on all rows -#[allow(clippy::too_many_arguments)] pub(crate) async fn prove_query( - ctx: &mut TestContext, - table: &Table, - query: QueryCooking, mut parsed: Query, - settings: &ParsilSettings<&Table>, res: Vec, metadata: MetadataHash, - pis: DynamicCircuitPis, -) -> anyhow::Result<()> { - let row_cache = table + planner: &mut QueryPlanner<'_>, +) -> Result<()> { + let row_cache = planner + .table .row .wide_lineage_between( - table.row.current_epoch(), - &core_keys_for_row_tree(&query.query, settings, &pis.bounds, &query.placeholders)?, - (query.min_block as Epoch, query.max_block as Epoch), + planner.table.row.current_epoch(), + &core_keys_for_row_tree( + &planner.query.query, + planner.settings, + &planner.pis.bounds, + &planner.query.placeholders, + )?, + ( + planner.query.min_block as Epoch, + planner.query.max_block as Epoch, + ), ) .await?; - // the query to use to fetch all the rows keys involved in the result tree. 
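This hunk collapses `prove_query`'s seven loose parameters into a single `QueryPlanner` (declared later in this diff, in query/mod.rs). A hedged sketch of the resulting call shape; `run_aggregation` is a hypothetical wrapper standing in for the test harness, not code from the repository:

```rust
// Hedged sketch, not part of the diff: `run_aggregation` is invented;
// `QueryPlanner`, `prove_query`, and `MetadataHash` are the items this
// diff declares or imports.
async fn run_aggregation(
    planner: &mut QueryPlanner<'_>,
    parsed: sqlparser::ast::Query,
    res: Vec<tokio_postgres::Row>,
    metadata: mp2_v1::api::MetadataHash,
) -> anyhow::Result<()> {
    // everything prove_query used to take separately now travels in `planner`
    prove_query(parsed, res, metadata, planner).await
}
```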
- let pis = parsil::assembler::assemble_dynamic(&parsed, settings, &query.placeholders)?; - let row_keys_per_epoch = row_cache.keys_by_epochs(); - let mut planner = QueryPlanner { - ctx, - query: query.clone(), - settings, - pis: &pis, - table, - columns: table.columns.clone(), - }; - - // prove the different versions of the row tree for each of the involved rows for each block - for (epoch, keys) in row_keys_per_epoch { - let up = row_cache - .update_tree_for(epoch as Epoch) - .expect("this epoch should exist"); - let info = RowInfo { - tree: &table.row, - satisfiying_rows: keys, - }; - prove_query_on_tree(&mut planner, info, up, epoch as BlockPrimaryIndex).await?; - } - // prove the index tree, on a single version. Both path can be taken depending if we do have // some nodes or not - let initial_epoch = table.index.initial_epoch() as BlockPrimaryIndex; - let current_epoch = table.index.current_epoch() as BlockPrimaryIndex; - let block_range = query.min_block.max(initial_epoch + 1)..=query.max_block.min(current_epoch); + let initial_epoch = planner.table.index.initial_epoch() as BlockPrimaryIndex; + let current_epoch = planner.table.index.current_epoch() as BlockPrimaryIndex; + let block_range = + planner.query.min_block.max(initial_epoch + 1)..=planner.query.max_block.min(current_epoch); info!( "found {} blocks in range: {:?}", block_range.clone().count(), block_range ); - if block_range.is_empty() { + let column_ids = ColumnIDs::from(&planner.table.columns); + let query_proof_id = if block_range.is_empty() { info!("Running INDEX TREE proving for EMPTY query"); // no valid blocks in the query range, so we need to choose a block to prove // non-existence. Either the one after genesis or the last one - let to_be_proven_node = if query.max_block < initial_epoch { + let to_be_proven_node = if planner.query.max_block < initial_epoch { initial_epoch + 1 - } else if query.min_block > current_epoch { + } else if planner.query.min_block > current_epoch { current_epoch } else { bail!( "Empty block range to be proven for query bounds {}, {}, but no node to be proven with non-existence circuit was found. Something is wrong", - query.min_block, - query.max_block + planner.query.min_block, + planner.query.max_block ); } as BlockPrimaryIndex; - prove_non_existence_index(&mut planner, to_be_proven_node).await?; - // we get the lineage of the node that proves the non existence of the index nodes - // required for the query. We specify the epoch at which we want to get this lineage as - // of the current epoch. - let lineage = table + let index_path = planner + .table .index - .lineage_at(&to_be_proven_node, current_epoch as Epoch) - .await? 
- .expect("can't get lineage") - .into_full_path() - .collect(); - let up = UpdateTree::from_path(lineage, current_epoch as Epoch); - let info = IndexInfo { - tree: &table.index, - bounds: (query.min_block, query.max_block), - }; - prove_query_on_tree( - &mut planner, - info, - up, - table.index.current_epoch() as BlockPrimaryIndex, - ) - .await?; + .compute_path(&to_be_proven_node, current_epoch as Epoch) + .await + .unwrap_or_else(|| { + panic!("Compute path for index node with key {to_be_proven_node} failed") + }); + let input = QueryCircuitInput::new_non_existence_input( + index_path, + &column_ids, + &planner.pis.predication_operations, + &planner.pis.result, + &planner.query.placeholders, + &planner.pis.bounds, + )?; + let query_proof = planner + .ctx + .run_query_proof("batching::non_existence", GlobalCircuitInput::Query(input))?; + let proof_key = ProofKey::QueryAggregate(( + planner.query.query.clone(), + planner.query.placeholders.placeholder_values(), + UTKey::default(), + )); + planner + .ctx + .storage + .store_proof(proof_key.clone(), query_proof)?; + proof_key } else { info!("Running INDEX tree proving from cache"); // Only here can we run the SQL query for the index so it doesn't crash - let index_query = - core_keys_for_index_tree(current_epoch as Epoch, (query.min_block, query.max_block))?; + let index_query = core_keys_for_index_tree( + current_epoch as Epoch, + (planner.query.min_block, planner.query.max_block), + )?; + let big_index_cache = planner + .table + .index // The bounds here mean between which versions of the tree we should look. For the index tree, // we only look at _one_ version of the tree. @@ -189,36 +171,101 @@ pub(crate) async fn prove_query( (current_epoch as Epoch, current_epoch as Epoch), ) .await?; - // since we only analyze the index tree for one epoch - assert_eq!(big_index_cache.keys_by_epochs().len(), 1); - // This is ok because the cache only has the blocks that are in the range, so the - // filter_check is going to return the same thing - // TODO: @franklin is that correct ? 
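The range clamping and empty-range fallback above are easy to misread; a worked numeric sketch, with all values invented for illustration:

```rust
fn main() {
    // Invented numbers: valid blocks live in (initial_epoch, current_epoch].
    let (initial_epoch, current_epoch) = (10u64, 20u64);

    // Query entirely before genesis: the range clamps to 11..=7 (empty), and
    // the first block after genesis is chosen for the non-existence proof.
    let (min_block, max_block) = (3u64, 7u64);
    let block_range = min_block.max(initial_epoch + 1)..=max_block.min(current_epoch);
    assert!(block_range.is_empty());
    let to_be_proven_node = if max_block < initial_epoch {
        initial_epoch + 1
    } else {
        current_epoch // query starts after the last block
    };
    assert_eq!(to_be_proven_node, 11);

    // Overlapping query: clamping trims it to the existing blocks 15..=20.
    let overlapping = 15u64.max(initial_epoch + 1)..=25u64.min(current_epoch);
    assert_eq!(overlapping.count(), 6);
}
```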
- let up = big_index_cache - // this is the epoch we choose how to prove - .update_tree_for(current_epoch as Epoch) - .expect("this epoch should exist"); - prove_query_on_tree( - &mut planner, - big_index_cache, - up, - table.index.current_epoch() as BlockPrimaryIndex, - ) - .await?; - } + let (proven_chunks, update_tree) = + generate_chunks_and_update_tree::( + row_cache, + big_index_cache, + &column_ids, + NonExistenceInput::new( + &planner.table.row, + planner.table.public_name.clone(), + &planner.table.db_pool, + planner.settings, + &planner.pis.bounds, + ), + current_epoch as Epoch, + ) + .await?; + info!("Root of update tree is {:?}", update_tree.root()); + let mut workplan = update_tree.into_workplan(); + let mut proof_id = None; + while let Some(Next::Ready(wk)) = workplan.next() { + let (k, is_path_end) = if let WorkplanItem::Node { k, is_path_end } = &wk { + (k, *is_path_end) + } else { + unreachable!("this update tree has been created with a batch size of 1") + }; + let proof = if is_path_end { + // this is a row chunk to be proven + let to_be_proven_chunk = proven_chunks + .get(k) + .unwrap_or_else(|| panic!("chunk for key {:?} not found", k)); + let input = QueryCircuitInput::new_row_chunks_input( + to_be_proven_chunk, + &planner.pis.predication_operations, + &planner.query.placeholders, + &planner.pis.bounds, + &planner.pis.result, + )?; + info!("Proving chunk {:?}", k); + planner.ctx.run_query_proof( + "batching::chunk_processing", + GlobalCircuitInput::Query(input), + ) + } else { + let children_keys = workplan.t.get_children_keys(k); + info!("children keys: {:?}", children_keys); + // fetch the proof for each child from the storage + let child_proofs = children_keys + .into_iter() + .map(|child_key| { + let proof_key = ProofKey::QueryAggregate(( + planner.query.query.clone(), + planner.query.placeholders.placeholder_values(), + child_key, + )); + planner.ctx.storage.get_proof_exact(&proof_key) + }) + .collect::>>()?; + let input = QueryCircuitInput::new_chunk_aggregation_input(&child_proofs)?; + info!("Aggregating chunk {:?}", k); + planner.ctx.run_query_proof( + "batching::chunk_aggregation", + GlobalCircuitInput::Query(input), + ) + }?; + let proof_key = ProofKey::QueryAggregate(( + planner.query.query.clone(), + planner.query.placeholders.placeholder_values(), + *k, + )); + planner.ctx.storage.store_proof(proof_key.clone(), proof)?; + proof_id = Some(proof_key); + workplan.done(&wk)?; + } + proof_id.unwrap() + }; + + info!("proving revelation"); - info!("Query proofs done! Generating revelation proof..."); - let proof = prove_revelation(ctx, table, &query, &pis, table.index.current_epoch()).await?; + let proof = prove_revelation( + planner.ctx, + &planner.query, + planner.pis, + planner.table.index.current_epoch(), + &query_proof_id, + ) + .await?; info!("Revelation proof done! 
Checking public inputs..."); // get `StaticPublicInputs`, i.e., the data about the query available only at query registration time, // to check the public inputs - let pis = parsil::assembler::assemble_static(&parsed, settings)?; + let pis = parsil::assembler::assemble_static(&parsed, planner.settings)?; // get number of matching rows - let mut exec_query = parsil::executor::generate_query_keys(&mut parsed, settings)?; - let query_params = exec_query.convert_placeholders(&query.placeholders); + let mut exec_query = parsil::executor::generate_query_keys(&mut parsed, planner.settings)?; + let query_params = exec_query.convert_placeholders(&planner.query.placeholders); let num_touched_rows = execute_row_query( - &table.db_pool, + &planner.table.db_pool, &exec_query .normalize_placeholder_names() .to_pgsql_string_with_placeholder(), @@ -229,11 +276,11 @@ pub(crate) async fn prove_query( check_final_outputs( proof, - ctx, - table, - &query, + planner.ctx, + planner.table, + &planner.query, &pis, - table.index.current_epoch(), + planner.table.index.current_epoch(), num_touched_rows, res, metadata, @@ -244,21 +291,13 @@ pub(crate) async fn prove_query( async fn prove_revelation( ctx: &TestContext, - table: &Table, query: &QueryCooking, pis: &DynamicCircuitPis, tree_epoch: Epoch, -) -> anyhow::Result> { + query_proof_id: &ProofKey, +) -> Result> { // load the query proof, which is at the root of the tree - let query_proof = { - let root_key = table.index.root_at(tree_epoch).await?.unwrap(); - let proof_key = ProofKey::QueryAggregateIndex(( - query.query.clone(), - query.placeholders.placeholder_values(), - root_key, - )); - ctx.storage.get_proof_exact(&proof_key)? - }; + let query_proof = ctx.storage.get_proof_exact(query_proof_id)?; // load the preprocessing proof at the same epoch let indexing_proof = { let pk = ProofKey::IVC(tree_epoch as BlockPrimaryIndex); @@ -290,7 +329,7 @@ pub(crate) fn check_final_outputs( num_touched_rows: usize, res: Vec, offcircuit_md: MetadataHash, -) -> anyhow::Result<()> { +) -> Result<()> { // fetch indexing proof, whose public inputs are needed to check correctness of revelation proof outputs let indexing_proof = { let pk = ProofKey::IVC(tree_epoch as BlockPrimaryIndex); @@ -383,547 +422,12 @@ pub(crate) fn check_final_outputs( Ok(()) } -/// Generic function as to how to handle the aggregation. It handles both aggregation in the row -/// tree as well as in the index tree the same way. 
The TreeInfo trait is just here to bring some -/// context for saving and loading the proof at the right location, depending on whether it's a row or index -/// tree -/// clippy doesn't see that it cannot be done -#[allow(clippy::needless_lifetimes)] -async fn prove_query_on_tree<'a, I, K, V>( - planner: &mut QueryPlanner<'a>, - info: I, - update: UpdateTree, - primary: BlockPrimaryIndex, -) -> anyhow::Result> -where - I: TreeInfo, - V: NodePayload + Send + Sync + LagrangeNode + Clone, - K: Debug + Hash + Clone + Eq + Sync + Send, -{ - let query_id = planner.query.query.clone(); - let placeholder_values = planner.query.placeholders.placeholder_values(); - let mut workplan = update.into_workplan(); - let mut proven_nodes = HashSet::new(); - let fetch_only_proven_child = |nctx: NodeContext, - cctx: &TestContext, - proven: &HashSet| - -> (ChildPosition, Vec) { - let (child_key, pos) = match (nctx.left, nctx.right) { - (Some(left), Some(right)) => { - assert!( - proven.contains(&left) ^ proven.contains(&right), - "only one child should be already proven, not both" - ); - if proven.contains(&left) { - (left, ChildPosition::Left) - } else { - (right, ChildPosition::Right) - } - } - (Some(left), None) if proven.contains(&left) => (left, ChildPosition::Left), - (None, Some(right)) if proven.contains(&right) => (right, ChildPosition::Right), - _ => panic!("something's wrong in the tree"), - }; - let child_proof = info - .load_proof( - cctx, - &query_id, - primary, - &child_key, - placeholder_values.clone(), - ) - .expect("key should already have been proven"); - (pos, child_proof) - }; - while let Some(Next::Ready(wk)) = workplan.next() { - let k = wk.k(); - // closure performing all the operations necessary before jumping to the next iteration - let mut end_iteration = |proven_nodes: &mut HashSet| -> anyhow::Result<()> { - proven_nodes.insert(k.clone()); - workplan.done(&wk)?; - Ok(()) - }; - // since epoch starts at genesis now, we can directly give the value of the block - // number as epoch number - let (node_ctx, node_payload) = info - .fetch_ctx_and_payload_at(primary as Epoch, k) - .await - .expect("cache is not full"); - let is_satisfying_query = info.is_satisfying_query(k); - let embedded_proof = info - .load_or_prove_embedded(planner, primary, k, &node_payload) - .await; - if node_ctx.is_leaf() && info.is_row_tree() { - // NOTE: if it is a leaf of the row tree, then there is no need to prove anything, - // since we're not "aggregating" anything from below. For the index tree however, we - // need to always generate an aggregate proof. Therefore, in this test, we just copy the - // proof to the expected aggregation location and move on. Note that we need to - // save the proof only if the current row is satisfying the query: indeed, if - // this is not the case, then the proof should have already been generated and stored - // with the non-existence circuit - if is_satisfying_query { - // unwrap is safe since we are guaranteed the row is satisfying the query - info.save_proof( - planner.ctx, - &query_id, - primary, - k, - placeholder_values.clone(), - embedded_proof?.unwrap(), - )?; - } - - end_iteration(&mut proven_nodes)?; - continue; - } - - // In the case we haven't proven anything under this node, it's the single path case. - // It is sufficient to check if this node is one of the leaves in this update tree. Note - // it is not the same meaning as a "leaf of a tree"; here it just means it is the first - // node in the merkle path. 
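The `fetch_only_proven_child` closure above leans on an exactly-one-proven-child invariant; a self-contained restatement of just that check, with toy `u32` keys instead of the crate's key types:

```rust
use std::collections::HashSet;

// Toy restatement of the invariant asserted above: with two children
// present, exactly one of them may already carry a proof.
fn pick_proven_child(left: u32, right: u32, proven: &HashSet<u32>) -> (u32, bool) {
    let (l, r) = (proven.contains(&left), proven.contains(&right));
    assert!(l ^ r, "only one child should be already proven, not both");
    if l { (left, true) } else { (right, false) } // (child_key, is_left)
}

fn main() {
    let proven: HashSet<u32> = [2].into_iter().collect();
    assert_eq!(pick_proven_child(2, 3, &proven), (2, true));
}
```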
- let (k, is_path_end) = if let WorkplanItem::Node { k, is_path_end } = &wk { - (k, *is_path_end) - } else { - unreachable!("this update tree has been created with a batch size of 1") - }; - - let (name, input) = if is_path_end { - info!("node {primary} -> {k:?} is at path end"); - if !is_satisfying_query { - // if the node of the key does not satisfy the query, but this node is at the end of - // a path to be proven, then it means we are in a tree with no satisfying nodes, and - // so the current node is the node chosen to be proven with non-existence circuits. - // Since the node has already been proven, we just mark it as done and continue - end_iteration(&mut proven_nodes)?; - continue; - } - assert!( - info.is_satisfying_query(k), - "first node in merkle path should always be a valid query one" - ); - let (node_info, left_info, right_info) = - // we can use primary as epoch now that tree stores epoch from genesis - get_node_info(&info, k, primary as Epoch).await; - ( - "querying::aggregation::single", - QueryCircuitInput::new_single_path( - SubProof::new_embedded_tree_proof(embedded_proof?.unwrap())?, - left_info, - right_info, - node_info, - info.is_row_tree(), - &planner.pis.bounds, - ) - .expect("can't create leaf input"), - ) - } else { - // here we are guaranteed there is a node below that we have already proven. - // It cannot be a single path with the embedded tree only, since that falls into the - // previous category ("is_path_end" == true), because the update plan starts with the "leaves" - // of all the paths it has been given. - // So it means there is at least one child of this node that we have proven before. - // If this node satisfies the query, then we use One/TwoProvenChildNode; - // if this node is not in the query's touched rows, we use a SinglePath with a proven child tree. 
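The branching spelled out in the comments above spans roughly a hundred lines of the deleted loop; an illustrative condensation of its decision table, with enum and function names invented here rather than taken from the crate:

```rust
// Invented names: a condensed decision table for the deleted loop above.
// (The row-tree-leaf copy case handled earlier in the loop is omitted.)
#[derive(Debug, PartialEq)]
enum Step {
    Skip,               // path end, not satisfying: non-existence proof already stored
    SingleWithEmbedded, // path end, satisfying: wrap the embedded row/universal proof
    SingleWithChild,    // node outside the query: forward the single proven child
    Full,               // satisfying node with both children proven
    Partial,            // satisfying node with one proven child
}

fn choose_step(is_path_end: bool, satisfies_query: bool, proven_children: usize) -> Step {
    match (is_path_end, satisfies_query, proven_children) {
        (true, false, _) => Step::Skip,
        (true, true, _) => Step::SingleWithEmbedded,
        (false, false, _) => Step::SingleWithChild,
        (false, true, 2) => Step::Full,
        (false, true, _) => Step::Partial,
    }
}

fn main() {
    assert_eq!(choose_step(false, true, 1), Step::Partial);
}
```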
- // - if !is_satisfying_query { - let (child_pos, child_proof) = - fetch_only_proven_child(node_ctx, planner.ctx, &proven_nodes); - let (node_info, left_info, right_info) = get_node_info( - &info, - k, - // we can use primary as epoch since storage starts epoch at genesis - primary as Epoch, - ) - .await; - // we look which child is the one to load from storage, the one we already proved - ( - "querying::aggregation::single", - QueryCircuitInput::new_single_path( - SubProof::new_child_proof(child_proof, child_pos)?, - left_info, - right_info, - node_info, - info.is_row_tree(), - &planner.pis.bounds, - ) - .expect("can't create leaf input"), - ) - } else { - // this case is easy, since all that's left is partial or full where both - // child(ren) and current node belong to query - let is_correct_left = node_ctx.left.is_some() - && proven_nodes.contains(node_ctx.left.as_ref().unwrap()); - let is_correct_right = node_ctx.right.is_some() - && proven_nodes.contains(node_ctx.right.as_ref().unwrap()); - if is_correct_left && is_correct_right { - // full node case - let left_proof = info.load_proof( - planner.ctx, - &query_id, - primary, - node_ctx.left.as_ref().unwrap(), - placeholder_values.clone(), - )?; - let right_proof = info.load_proof( - planner.ctx, - &query_id, - primary, - node_ctx.right.as_ref().unwrap(), - placeholder_values.clone(), - )?; - ( - "querying::aggregation::full", - QueryCircuitInput::new_full_node( - left_proof, - right_proof, - embedded_proof?.expect("should be a embedded_proof here"), - info.is_row_tree(), - &planner.pis.bounds, - ) - .expect("can't create full node circuit input"), - ) - } else { - // partial case - let (child_pos, child_proof) = - fetch_only_proven_child(node_ctx, planner.ctx, &proven_nodes); - let (_, left_info, right_info) = - get_node_info(&info, k, primary as Epoch).await; - let unproven = match child_pos { - ChildPosition::Left => right_info, - ChildPosition::Right => left_info, - }; - ( - "querying::aggregation::partial", - QueryCircuitInput::new_partial_node( - child_proof, - embedded_proof?.expect("should be an embedded_proof here too"), - unproven, - child_pos, - info.is_row_tree(), - &planner.pis.bounds, - ) - .expect("can't build new partial node input"), - ) - } - } - }; - info!("AGGREGATE query proof RUNNING for {primary} -> {k:?} "); - let proof = planner - .ctx - .run_query_proof(name, GlobalCircuitInput::Query(input))?; - info.save_proof( - planner.ctx, - &query_id, - primary, - k, - placeholder_values.clone(), - proof, - )?; - info!("query proof DONE for {primary} -> {k:?} "); - end_iteration(&mut proven_nodes)?; - } - Ok(vec![]) -} - -// TODO: make it recursive with async - tentative in `fetch_child_info` but it doesn't work, -// recursion with async is weird. -pub(crate) async fn get_node_info>( - lookup: &T, - k: &K, - at: Epoch, -) -> (NodeInfo, Option, Option) -where - K: Debug + Hash + Clone + Send + Sync + Eq, - // NOTICE the ToValue here to get the value associated to a node - V: NodePayload + Send + Sync + LagrangeNode + Clone, -{ - // look at the left child first then right child, then build the node info - let (ctx, node_payload) = lookup - .fetch_ctx_and_payload_at(at, k) - .await - .expect("cache not filled"); - // this looks at the value of a child node (left and right), and fetches the grandchildren - // information to be able to build their respective node info. 
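The two-level fetch described in the comment above follows from the shape of the node info: each record embeds the hashes of its children, so building the info of a child needs the grandchildren's stored hashes. A minimal toy model of that dependency, with types that only stand in for the crate's `NodeInfo`:

```rust
// Toy types, not the crate's NodeInfo: an info record embeds the *hashes*
// of its children, so constructing left_info/right_info for a node's
// children requires fetching one extra level, the grandchildren, just to
// read their stored hashes.
struct ToyInfo {
    embedded_hash: u64,
    left_child_hash: Option<u64>,
    right_child_hash: Option<u64>,
}

// Stand-in for NodeInfo::new(&embedded_hash, left_hash, right_hash, ...):
fn toy_info(embedded: u64, lh: Option<u64>, rh: Option<u64>) -> ToyInfo {
    ToyInfo { embedded_hash: embedded, left_child_hash: lh, right_child_hash: rh }
}

fn main() {
    // For (node_info, left_info, right_info) of node N, three levels are read:
    //   N             -> payload, plus hash(L) and hash(R)
    //   L, R          -> payloads, plus the grandchildren's hashes
    //   grandchildren -> fetched only for their stored hashes
    let left_info = toy_info(7, Some(1), Some(2)); // 1, 2: grandchildren hashes
    assert_eq!(left_info.left_child_hash, Some(1));
}
```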
- let fetch_ni = async |k: Option| -> (Option, Option) { - match k { - None => (None, None), - Some(child_k) => { - let (child_ctx, child_payload) = lookup - .fetch_ctx_and_payload_at(at, &child_k) - .await - .expect("cache not filled"); - // we need the grand child hashes for constructing the node info of the - // children of the node in argument - let child_left_hash = match child_ctx.left { - Some(left_left_k) => { - let (_, payload) = lookup - .fetch_ctx_and_payload_at(at, &left_left_k) - .await - .expect("cache not filled"); - Some(payload.hash()) - } - None => None, - }; - let child_right_hash = match child_ctx.right { - Some(left_right_k) => { - let (_, payload) = lookup - .fetch_ctx_and_payload_at(at, &left_right_k) - .await - .expect("cache not full"); - Some(payload.hash()) - } - None => None, - }; - let left_ni = NodeInfo::new( - &child_payload.embedded_hash(), - child_left_hash.as_ref(), - child_right_hash.as_ref(), - child_payload.value(), - child_payload.min(), - child_payload.max(), - ); - (Some(left_ni), Some(child_payload.hash())) - } - } - }; - let (left_node, left_hash) = fetch_ni(ctx.left).await; - let (right_node, right_hash) = fetch_ni(ctx.right).await; - ( - NodeInfo::new( - &node_payload.embedded_hash(), - left_hash.as_ref(), - right_hash.as_ref(), - node_payload.value(), - node_payload.min(), - node_payload.max(), - ), - left_node, - right_node, - ) -} - -pub fn generate_non_existence_proof( - node_info: NodeInfo, - left_child_info: Option, - right_child_info: Option, - primary: BlockPrimaryIndex, - planner: &mut QueryPlanner<'_>, - is_rows_tree_node: bool, -) -> anyhow::Result> { - let index_ids = [ - planner.table.columns.primary_column().identifier, - planner.table.columns.secondary_column().identifier, - ]; - assert_eq!(index_ids[0], identifier_block_column()); - let column_ids = ColumnIDs::new( - index_ids[0], - index_ids[1], - planner - .table - .columns - .non_indexed_columns() - .iter() - .map(|column| column.identifier) - .collect_vec(), - ); - let query_hashes = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_ITEMS_PER_OUTPUT, - >( - &column_ids, - &planner.pis.predication_operations, - &planner.pis.result, - &planner.query.placeholders, - &planner.pis.bounds, - is_rows_tree_node, - )?; - let input = QueryCircuitInput::new_non_existence_input( - node_info, - left_child_info, - right_child_info, - U256::from(primary), - &index_ids, - &planner.pis.query_aggregations, - query_hashes, - is_rows_tree_node, - &planner.pis.bounds, - &planner.query.placeholders, - )?; - planner - .ctx - .run_query_proof("querying::non_existence", GlobalCircuitInput::Query(input)) -} - -/// Generate a proof for a node of the index tree which is outside of the query bounds -async fn prove_non_existence_index( - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, -) -> anyhow::Result<()> { - let tree = &planner.table.index; - let current_epoch = tree.current_epoch(); - let (node_info, left_child_info, right_child_info) = get_node_info( - &IndexInfo::non_satisfying_info(tree), - &primary, - current_epoch, - ) - .await; - let proof_key = ProofKey::QueryAggregateIndex(( - planner.query.query.clone(), - planner.query.placeholders.placeholder_values(), - primary, - )); - info!("Non-existence circuit proof RUNNING for {current_epoch} -> {primary} "); - let proof = generate_non_existence_proof( - node_info, - left_child_info, - right_child_info, - primary, - planner, - false, - ) - .unwrap_or_else(|_| { - 
panic!("unable to generate non-existence proof for {current_epoch} -> {primary}") - }); - info!("Non-existence circuit proof DONE for {current_epoch} -> {primary} "); - planner.ctx.storage.store_proof(proof_key, proof.clone())?; - - Ok(()) -} - -pub async fn prove_non_existence_row( - planner: &mut QueryPlanner<'_>, - primary: BlockPrimaryIndex, -) -> anyhow::Result<()> { - let (chosen_node, plan) = find_row_node_for_non_existence( - &planner.table.row, - planner.table.public_name.clone(), - &planner.table.db_pool, - primary, - planner.settings, - &planner.pis.bounds, - ) - .await?; - let (node_info, left_child_info, right_child_info) = - mp2_v1::query::planner::get_node_info(&planner.table.row, &chosen_node, primary as Epoch) - .await?; - - let proof_key = ProofKey::QueryAggregateRow(( - planner.query.query.clone(), - planner.query.placeholders.placeholder_values(), - primary, - chosen_node.clone(), - )); - info!( - "Non-existence circuit proof RUNNING for {primary} -> {:?} ", - proof_key - ); - let proof = generate_non_existence_proof( - node_info, - left_child_info, - right_child_info, - primary, - planner, - true, - ) - .unwrap_or_else(|_| { - panic!( - "unable to generate non-existence proof for {primary} -> {:?}", - chosen_node - ) - }); - info!( - "Non-existence circuit proof DONE for {primary} -> {:?} ", - chosen_node - ); - planner.ctx.storage.store_proof(proof_key, proof.clone())?; - - let tree_info = RowInfo::no_satisfying_rows(&planner.table.row); - let mut planner = QueryPlanner { - ctx: planner.ctx, - table: planner.table, - query: planner.query.clone(), - pis: planner.pis, - columns: planner.columns.clone(), - settings: planner.settings, - }; - prove_query_on_tree(&mut planner, tree_info, plan, primary).await?; - - Ok(()) -} - -pub async fn prove_single_row>>( - ctx: &mut TestContext, - tree: &T, - columns: &TableColumns, - primary: BlockPrimaryIndex, - row_key: &RowTreeKey, - pis: &DynamicCircuitPis, - query: &QueryCooking, -) -> anyhow::Result> { - // 1. Get the all the cells including primary and secondary index - // Note we can use the primary as epoch since now epoch == primary in the storage - let (row_ctx, row_payload) = tree - .fetch_ctx_and_payload_at(primary as Epoch, row_key) - .await - .expect("cache not full"); - - // API is gonna change on this but right now, we have to sort all the "rest" cells by index - // in the tree, and put the primary one and secondary one in front - let rest_cells = columns - .non_indexed_columns() - .iter() - .map(|tc| tc.identifier) - .filter_map(|id| { - row_payload - .cells - .find_by_column(id) - .map(|info| ColumnCell::new(id, info.value)) - }) - .collect::>(); - - let secondary_cell = ColumnCell::new( - row_payload.secondary_index_column, - row_payload.secondary_index_value(), - ); - let primary_cell = ColumnCell::new(identifier_block_column(), U256::from(primary)); - let row = RowCells::new(primary_cell, secondary_cell, rest_cells); - // 2. create input - let input = QueryCircuitInput::new_universal_circuit( - &row, - &pis.predication_operations, - &pis.result, - &query.placeholders, - row_ctx.is_leaf(), - &pis.bounds, - ) - .expect("unable to create universal query circuit inputs"); - // 3. 
run proof if not ran already - let proof_key = ProofKey::QueryUniversal(( - query.query.clone(), - query.placeholders.placeholder_values(), - primary, - row_key.clone(), - )); - let proof = { - info!("Universal query proof RUNNING for {primary} -> {row_key:?} "); - let proof = ctx - .run_query_proof("querying::universal", GlobalCircuitInput::Query(input)) - .expect("unable to generate universal proof for {epoch} -> {row_key:?}"); - info!("Universal query proof DONE for {primary} -> {row_key:?} "); - ctx.storage.store_proof(proof_key, proof.clone())?; - proof - }; - Ok(proof) -} - type BlockRange = (BlockPrimaryIndex, BlockPrimaryIndex); pub(crate) async fn cook_query_between_blocks( table: &Table, info: &TableInfo, -) -> anyhow::Result { +) -> Result { let max = table.row.current_epoch(); let min = max - 1; @@ -950,7 +454,7 @@ pub(crate) async fn cook_query_between_blocks( pub(crate) async fn cook_query_secondary_index_nonexisting_placeholder( table: &Table, info: &TableInfo, -) -> anyhow::Result { +) -> Result { let (longest_key, (min_block, max_block)) = find_longest_lived_key(table, false).await?; let key_value = hex::encode(longest_key.value.to_be_bytes_trimmed_vec()); info!( @@ -998,7 +502,7 @@ pub(crate) async fn cook_query_secondary_index_nonexisting_placeholder( pub(crate) async fn cook_query_secondary_index_placeholder( table: &Table, info: &TableInfo, -) -> anyhow::Result { +) -> Result { let (longest_key, (min_block, max_block)) = find_longest_lived_key(table, false).await?; let key_value = hex::encode(longest_key.value.to_be_bytes_trimmed_vec()); info!( @@ -1043,7 +547,7 @@ pub(crate) async fn cook_query_secondary_index_placeholder( pub(crate) async fn cook_query_unique_secondary_index( table: &Table, info: &TableInfo, -) -> anyhow::Result { +) -> Result { let (longest_key, (min_block, max_block)) = find_longest_lived_key(table, false).await?; let key_value = hex::encode(longest_key.value.to_be_bytes_trimmed_vec()); info!( @@ -1119,7 +623,7 @@ pub(crate) async fn cook_query_unique_secondary_index( pub(crate) async fn cook_query_partial_block_range( table: &Table, info: &TableInfo, -) -> anyhow::Result { +) -> Result { let (longest_key, (min_block, max_block)) = find_longest_lived_key(table, false).await?; let key_value = hex::encode(longest_key.value.to_be_bytes_trimmed_vec()); info!( @@ -1155,7 +659,7 @@ pub(crate) async fn cook_query_partial_block_range( pub(crate) async fn cook_query_no_matching_entries( table: &Table, info: &TableInfo, -) -> anyhow::Result { +) -> Result { let initial_epoch = table.row.initial_epoch(); // choose query bounds outside of the range [initial_epoch, last_epoch] let min_block = 0; @@ -1187,7 +691,7 @@ pub(crate) async fn cook_query_no_matching_entries( pub(crate) async fn cook_query_non_matching_entries_some_blocks( table: &Table, info: &TableInfo, -) -> anyhow::Result { +) -> Result { let (longest_key, (min_block, max_block)) = find_longest_lived_key(table, true).await?; let key_value = hex::encode(longest_key.value.to_be_bytes_trimmed_vec()); info!( @@ -1223,7 +727,7 @@ pub(crate) async fn cook_query_non_matching_entries_some_blocks( /// Utility function to associated to each row in the tree, the blocks where the row /// was valid -async fn extract_row_liveness(table: &Table) -> anyhow::Result>> { +async fn extract_row_liveness(table: &Table) -> Result>> { let mut all_table = HashMap::new(); let max = table.row.current_epoch(); let min = table.row.initial_epoch() + 1; @@ -1254,7 +758,7 @@ async fn extract_row_liveness(table: &Table) -> 
anyhow::Result anyhow::Result<(RowTreeKey, BlockRange)> { +) -> Result<(RowTreeKey, BlockRange)> { let initial_epoch = table.row.initial_epoch() + 1; let last_epoch = table.row.current_epoch(); let all_table = extract_row_liveness(table).await?; @@ -1290,10 +794,7 @@ pub(crate) async fn find_longest_lived_key( Ok((longest_key.clone(), (min_block, max_block))) } -async fn collect_all_at( - tree: &MerkleRowTree, - at: Epoch, -) -> anyhow::Result>> { +async fn collect_all_at(tree: &MerkleRowTree, at: Epoch) -> Result>> { let root_key = tree.root_at(at).await?.unwrap(); let (ctx, payload) = tree .try_fetch_with_context_at(&root_key, at) @@ -1352,7 +853,7 @@ fn find_longest_consecutive_sequence(v: Vec) -> (usize, i64) { async fn check_correct_cells_tree( all_cells: &[ColumnCell], payload: &RowPayload, -) -> anyhow::Result<()> { +) -> Result<()> { let local_cells = all_cells.to_vec(); let expected_cells_root = payload .cell_root_hash diff --git a/mp2-v1/tests/common/cases/query/mod.rs b/mp2-v1/tests/common/cases/query/mod.rs index 9e40f47d5..7887e635d 100644 --- a/mp2-v1/tests/common/cases/query/mod.rs +++ b/mp2-v1/tests/common/cases/query/mod.rs @@ -11,7 +11,10 @@ use log::info; use mp2_v1::{ api::MetadataHash, indexing::block::BlockPrimaryIndex, query::planner::execute_row_query, }; -use parsil::{parse_and_validate, utils::ParsilSettingsBuilder, PlaceholderSettings}; +use parsil::{ + assembler::DynamicCircuitPis, parse_and_validate, utils::ParsilSettingsBuilder, ParsilSettings, + PlaceholderSettings, +}; use simple_select_queries::{ cook_query_no_matching_rows, cook_query_too_big_offset, cook_query_with_distinct, cook_query_with_matching_rows, cook_query_with_max_num_matching_rows, @@ -23,13 +26,18 @@ use verifiable_db::query::{ computational_hash_ids::Output, universal_circuit::universal_circuit_inputs::Placeholders, }; -use crate::common::{cases::planner::QueryPlanner, table::Table, TableInfo, TestContext}; +use crate::common::{ + table::{Table, TableColumns}, + TableInfo, TestContext, +}; use super::table_source::TableSource; pub mod aggregated_queries; pub mod simple_select_queries; +pub const NUM_CHUNKS: usize = 5; +pub const NUM_ROWS: usize = 3; pub const MAX_NUM_RESULT_OPS: usize = 20; pub const MAX_NUM_OUTPUTS: usize = 3; pub const MAX_NUM_ITEMS_PER_OUTPUT: usize = 5; @@ -40,6 +48,8 @@ pub const ROW_TREE_MAX_DEPTH: usize = 10; pub const INDEX_TREE_MAX_DEPTH: usize = 15; pub type GlobalCircuitInput = verifiable_db::api::QueryCircuitInput< + NUM_CHUNKS, + NUM_ROWS, ROW_TREE_MAX_DEPTH, INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, @@ -51,6 +61,10 @@ pub type GlobalCircuitInput = verifiable_db::api::QueryCircuitInput< >; pub type QueryCircuitInput = verifiable_db::query::api::CircuitInput< + NUM_CHUNKS, + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -78,6 +92,17 @@ pub struct QueryCooking { pub(crate) offset: Option, } +pub(crate) struct QueryPlanner<'a> { + pub(crate) query: QueryCooking, + pub(crate) pis: &'a DynamicCircuitPis, + pub(crate) ctx: &'a mut TestContext, + pub(crate) settings: &'a ParsilSettings<&'a Table>, + // useful for non existence since we need to search in both trees the places to prove + // the fact a given node doesn't exist + pub(crate) table: &'a Table, + pub(crate) columns: TableColumns, +} + pub async fn test_query(ctx: &mut TestContext, table: Table, t: TableInfo) -> Result<()> { match &t.source { TableSource::Mapping(_) | TableSource::Merge(_) => query_mapping(ctx, &table, &t).await?, @@ -198,17 
+223,7 @@ async fn test_query_mapping( match pis.result.query_variant() { Output::Aggregation => { - prove_aggregation_query( - ctx, - table, - query_info, - parsed, - &settings, - res, - *table_hash, - pis, - ) - .await + prove_aggregation_query(parsed, res, *table_hash, &mut planner).await } Output::NoAggregation => { prove_no_aggregation_query(parsed, table_hash, &mut planner, res).await diff --git a/mp2-v1/tests/common/cases/query/simple_select_queries.rs b/mp2-v1/tests/common/cases/query/simple_select_queries.rs index e29226a8b..df01b78f7 100644 --- a/mp2-v1/tests/common/cases/query/simple_select_queries.rs +++ b/mp2-v1/tests/common/cases/query/simple_select_queries.rs @@ -5,12 +5,17 @@ use log::info; use mp2_common::types::HashOutput; use mp2_v1::{ api::MetadataHash, - indexing::{block::BlockPrimaryIndex, row::RowTreeKey, LagrangeNode}, - query::planner::execute_row_query, + indexing::{ + block::BlockPrimaryIndex, + row::{RowPayload, RowTreeKey}, + LagrangeNode, + }, + query::planner::{execute_row_query, get_node_info, TreeFetcher}, + values_extraction::identifier_block_column, }; use parsil::{ - executor::generate_query_execution_with_keys, DEFAULT_MAX_BLOCK_PLACEHOLDER, - DEFAULT_MIN_BLOCK_PLACEHOLDER, + assembler::DynamicCircuitPis, executor::generate_query_execution_with_keys, + DEFAULT_MAX_BLOCK_PLACEHOLDER, DEFAULT_MIN_BLOCK_PLACEHOLDER, }; use ryhope::{ storage::{pgsql::ToFromBytea, RoEpochKvStorage}, @@ -21,9 +26,11 @@ use std::{fmt::Debug, hash::Hash}; use tokio_postgres::Row as PgSqlRow; use verifiable_db::{ query::{ - aggregation::{ChildPosition, NodeInfo}, computational_hash_ids::ColumnIDs, - universal_circuit::universal_circuit_inputs::{PlaceholderId, Placeholders}, + universal_circuit::universal_circuit_inputs::{ + ColumnCell, PlaceholderId, Placeholders, RowCells, + }, + utils::{ChildPosition, NodeInfo}, }, revelation::{api::MatchingRow, RowPath}, test_utils::MAX_NUM_OUTPUTS, @@ -32,20 +39,17 @@ use verifiable_db::{ use crate::common::{ cases::{ indexing::BLOCK_COLUMN_NAME, - planner::{IndexInfo, QueryPlanner, RowInfo, TreeInfo}, query::{ - aggregated_queries::{ - check_final_outputs, find_longest_lived_key, get_node_info, prove_single_row, - }, - GlobalCircuitInput, RevelationCircuitInput, SqlReturn, SqlType, + aggregated_queries::{check_final_outputs, find_longest_lived_key}, + GlobalCircuitInput, QueryPlanner, RevelationCircuitInput, SqlReturn, SqlType, }, }, proof_storage::{ProofKey, ProofStorage}, - table::Table, + table::{Table, TableColumns}, TableInfo, }; -use super::QueryCooking; +use super::{QueryCircuitInput, QueryCooking, TestContext}; pub(crate) async fn prove_query( mut parsed: Query, @@ -80,24 +84,12 @@ pub(crate) async fn prove_query( }) .collect::>>()?; // compute input for each matching row - let row_tree_info = RowInfo { - satisfiying_rows: matching_rows - .iter() - .map(|(key, _, _)| key) - .cloned() - .collect(), - tree: &planner.table.row, - }; - let index_tree_info = IndexInfo { - bounds: (planner.query.min_block, planner.query.max_block), - tree: &planner.table.index, - }; - let current_epoch = index_tree_info.tree.current_epoch(); + let current_epoch = planner.table.index.current_epoch(); let mut matching_rows_input = vec![]; for (key, epoch, result) in matching_rows.into_iter() { let row_proof = prove_single_row( planner.ctx, - &row_tree_info, + &planner.table.row, &planner.columns, epoch as BlockPrimaryIndex, &key, @@ -105,13 +97,14 @@ pub(crate) async fn prove_query( &planner.query, ) .await?; - let (row_node_info, _, _) = 
get_node_info(&row_tree_info, &key, epoch).await; - let (row_tree_path, row_tree_siblings) = get_path_info(&key, &row_tree_info, epoch).await?; + let (row_node_info, _, _) = get_node_info(&planner.table.row, &key, epoch).await; + let (row_tree_path, row_tree_siblings) = + get_path_info(&key, &planner.table.row, epoch).await?; let index_node_key = epoch as BlockPrimaryIndex; let (index_node_info, _, _) = - get_node_info(&index_tree_info, &index_node_key, current_epoch).await; + get_node_info(&planner.table.index, &index_node_key, current_epoch).await; let (index_tree_path, index_tree_siblings) = - get_path_info(&index_node_key, &index_tree_info, current_epoch).await?; + get_path_info(&index_node_key, &planner.table.index, current_epoch).await?; let path = RowPath::new( row_node_info, row_tree_path, @@ -163,7 +156,7 @@ pub(crate) async fn prove_query( Ok(()) } -async fn get_path_info>( +async fn get_path_info>( key: &K, tree_info: &T, epoch: Epoch, @@ -175,7 +168,7 @@ where let mut tree_path = vec![]; let mut siblings = vec![]; let (mut node_ctx, mut node_payload) = tree_info - .fetch_ctx_and_payload_at(epoch, key) + .fetch_ctx_and_payload_at(key, epoch) .await .ok_or(Error::msg(format!("Node not found for key {:?}", key)))?; let mut previous_node_hash = node_payload.hash(); @@ -183,7 +176,7 @@ where while node_ctx.parent.is_some() { let parent_key = node_ctx.parent.unwrap(); (node_ctx, node_payload) = tree_info - .fetch_ctx_and_payload_at(epoch, &parent_key) + .fetch_ctx_and_payload_at(&parent_key, epoch) .await .ok_or(Error::msg(format!( "Node not found for key {:?}", @@ -199,7 +192,7 @@ where match node_ctx.right { Some(k) => { let (_, payload) = tree_info - .fetch_ctx_and_payload_at(epoch, &k) + .fetch_ctx_and_payload_at(&k, epoch) .await .ok_or(Error::msg(format!("Node not found for key {:?}", k)))?; Some(payload.hash()) @@ -212,7 +205,7 @@ where match node_ctx.left { Some(k) => { let (_, payload) = tree_info - .fetch_ctx_and_payload_at(epoch, &k) + .fetch_ctx_and_payload_at(&k, epoch) .await .ok_or(Error::msg(format!("Node not found for key {:?}", k)))?; Some(payload.hash()) @@ -250,6 +243,71 @@ where Ok((tree_path, siblings)) } +pub(crate) async fn prove_single_row>>( + ctx: &mut TestContext, + tree: &T, + columns: &TableColumns, + primary: BlockPrimaryIndex, + row_key: &RowTreeKey, + pis: &DynamicCircuitPis, + query: &QueryCooking, +) -> Result> { + // 1. Get the all the cells including primary and secondary index + // Note we can use the primary as epoch since now epoch == primary in the storage + let (row_ctx, row_payload) = tree + .fetch_ctx_and_payload_at(row_key, primary as Epoch) + .await + .expect("cache not full"); + + // API is gonna change on this but right now, we have to sort all the "rest" cells by index + // in the tree, and put the primary one and secondary one in front + let rest_cells = columns + .non_indexed_columns() + .iter() + .map(|tc| tc.identifier) + .filter_map(|id| { + row_payload + .cells + .find_by_column(id) + .map(|info| ColumnCell::new(id, info.value)) + }) + .collect::>(); + + let secondary_cell = ColumnCell::new( + row_payload.secondary_index_column, + row_payload.secondary_index_value(), + ); + let primary_cell = ColumnCell::new(identifier_block_column(), U256::from(primary)); + let row = RowCells::new(primary_cell, secondary_cell, rest_cells); + // 2. 
create input + let input = QueryCircuitInput::new_universal_circuit( + &row, + &pis.predication_operations, + &pis.result, + &query.placeholders, + row_ctx.is_leaf(), + &pis.bounds, + ) + .expect("unable to create universal query circuit inputs"); + // 3. run proof if not ran already + let proof_key = ProofKey::QueryUniversal(( + query.query.clone(), + query.placeholders.placeholder_values(), + primary, + row_key.clone(), + )); + let proof = { + info!("Universal query proof RUNNING for {primary} -> {row_key:?} "); + let proof = ctx + .run_query_proof("querying::universal", GlobalCircuitInput::Query(input)) + .expect("unable to generate universal proof for {epoch} -> {row_key:?}"); + info!("Universal query proof DONE for {primary} -> {row_key:?} "); + ctx.storage.store_proof(proof_key, proof.clone())?; + proof + }; + Ok(proof) +} + /// Cook a query where the number of matching rows is the same as the maximum number of /// outputs allowed pub(crate) async fn cook_query_with_max_num_matching_rows( diff --git a/mp2-v1/tests/common/context.rs b/mp2-v1/tests/common/context.rs index f7ae3e7a0..83ea8fe1a 100644 --- a/mp2-v1/tests/common/context.rs +++ b/mp2-v1/tests/common/context.rs @@ -28,7 +28,8 @@ use super::{ self, query::{ INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, - MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, ROW_TREE_MAX_DEPTH, + MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, NUM_CHUNKS, NUM_ROWS, + ROW_TREE_MAX_DEPTH, }, }, proof_storage::ProofKV, @@ -56,6 +57,8 @@ pub(crate) struct TestContext { pub(crate) params: Option, pub(crate) query_params: Option< verifiable_db::api::QueryParameters< + NUM_CHUNKS, + NUM_ROWS, ROW_TREE_MAX_DEPTH, INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, diff --git a/mp2-v1/tests/common/proof_storage.rs b/mp2-v1/tests/common/proof_storage.rs index f0cc91058..059d8aa56 100644 --- a/mp2-v1/tests/common/proof_storage.rs +++ b/mp2-v1/tests/common/proof_storage.rs @@ -3,7 +3,7 @@ use std::{ path::{Path, PathBuf}, }; -use super::{context::TestContextConfig, mkdir_all, table::TableID}; +use super::{cases::query::NUM_CHUNKS, context::TestContextConfig, mkdir_all, table::TableID}; use alloy::primitives::{Address, U256}; use anyhow::{bail, Context, Result}; use envconfig::Envconfig; @@ -68,8 +68,13 @@ pub enum ProofKey { #[allow(clippy::upper_case_acronyms)] IVC(BlockPrimaryIndex), QueryUniversal((QueryID, PlaceholderValues, BlockPrimaryIndex, RowTreeKey)), - QueryAggregateRow((QueryID, PlaceholderValues, BlockPrimaryIndex, RowTreeKey)), - QueryAggregateIndex((QueryID, PlaceholderValues, BlockPrimaryIndex)), + QueryAggregate( + ( + QueryID, + PlaceholderValues, + mp2_v1::query::batching_planner::UTKey, + ), + ), } impl ProofKey { @@ -123,12 +128,8 @@ impl Hash for ProofKey { "query_universal".hash(state); n.hash(state); } - ProofKey::QueryAggregateRow(n) => { - "query_aggregate_row".hash(state); - n.hash(state); - } - ProofKey::QueryAggregateIndex(n) => { - "query_aggregate_index".hash(state); + ProofKey::QueryAggregate(n) => { + "query_aggregate".hash(state); n.hash(state); } } diff --git a/parsil/src/assembler.rs b/parsil/src/assembler.rs index bb4e22c1d..e847f9c0a 100644 --- a/parsil/src/assembler.rs +++ b/parsil/src/assembler.rs @@ -15,11 +15,11 @@ use sqlparser::ast::{ SelectItem, SetExpr, TableAlias, TableFactor, UnaryOperator, Value, }; use verifiable_db::query::{ - aggregation::{QueryBoundSource, QueryBounds}, computational_hash_ids::{AggregationOperation, Operation, PlaceholderIdentifier}, 
universal_circuit::universal_circuit_inputs::{ BasicOperation, InputOperand, OutputItem, Placeholders, ResultStructure, }, + utils::{QueryBoundSource, QueryBounds}, }; use crate::{ diff --git a/parsil/src/bracketer.rs b/parsil/src/bracketer.rs index 6b7358a2c..7a4908716 100644 --- a/parsil/src/bracketer.rs +++ b/parsil/src/bracketer.rs @@ -1,6 +1,6 @@ use alloy::primitives::U256; use ryhope::{KEY, PAYLOAD, VALID_FROM, VALID_UNTIL}; -use verifiable_db::query::aggregation::QueryBounds; +use verifiable_db::query::utils::QueryBounds; use crate::{symbols::ContextProvider, ParsilSettings}; diff --git a/parsil/src/isolator.rs b/parsil/src/isolator.rs index 66014d903..ca145145d 100644 --- a/parsil/src/isolator.rs +++ b/parsil/src/isolator.rs @@ -3,7 +3,7 @@ use anyhow::*; use log::warn; use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SelectItem, TableAlias, TableFactor}; -use verifiable_db::query::aggregation::QueryBounds; +use verifiable_db::query::utils::QueryBounds; use crate::{ errors::ValidationError, diff --git a/parsil/src/lib.rs b/parsil/src/lib.rs index 499f4b06d..df88aeaa4 100644 --- a/parsil/src/lib.rs +++ b/parsil/src/lib.rs @@ -7,7 +7,7 @@ pub use utils::ParsilSettings; pub use utils::PlaceholderSettings; pub use utils::DEFAULT_MAX_BLOCK_PLACEHOLDER; pub use utils::DEFAULT_MIN_BLOCK_PLACEHOLDER; -use verifiable_db::query::aggregation::QueryBounds; +use verifiable_db::query::utils::QueryBounds; pub mod assembler; pub mod bracketer; diff --git a/parsil/src/queries.rs b/parsil/src/queries.rs index 2efeefc1a..506fdb731 100644 --- a/parsil/src/queries.rs +++ b/parsil/src/queries.rs @@ -5,7 +5,7 @@ use crate::{keys_in_index_boundaries, symbols::ContextProvider, ParsilSettings}; use anyhow::*; use ryhope::{tree::sbbst::NodeIdx, Epoch, EPOCH, KEY, VALID_FROM, VALID_UNTIL}; use verifiable_db::query::{ - aggregation::QueryBounds, universal_circuit::universal_circuit_inputs::Placeholders, + universal_circuit::universal_circuit_inputs::Placeholders, utils::QueryBounds, }; /// Return a query read to be injected in the wide lineage computation for the diff --git a/rust-toolchain b/rust-toolchain index bf867e0ae..a7a456242 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly +nightly-2024-12-03 diff --git a/ryhope/src/lib.rs b/ryhope/src/lib.rs index ea40dc5cd..3caaf90f7 100644 --- a/ryhope/src/lib.rs +++ b/ryhope/src/lib.rs @@ -165,7 +165,7 @@ where let mut child_data = vec![]; for c in c.iter_children() { if let Some(k) = c { - child_data.push(Some(self.storage.data().try_fetch(k).await?)); + child_data.push(self.storage.data().try_fetch(k).await?); } else { child_data.push(None); } @@ -177,7 +177,7 @@ where .try_fetch(item.k()) .await? 
.expect("the node can not not be present"); - payload.aggregate(child_data.into_iter().flatten()); + payload.aggregate(child_data.into_iter()); plan.done(&item)?; self.storage .data_mut() diff --git a/ryhope/src/storage/updatetree.rs b/ryhope/src/storage/updatetree.rs index 2efd9594f..dc59b3f8c 100644 --- a/ryhope/src/storage/updatetree.rs +++ b/ryhope/src/storage/updatetree.rs @@ -54,6 +54,10 @@ impl UpdateTree { fn node_mut(&mut self, i: usize) -> &mut UpdateTreeNode { &mut self.nodes[i] } + + pub fn node_from_key(&self, k: &K) -> Option<&UpdateTreeNode> { + self.idx.get(k).map(|i| self.node(*i)) + } } impl UpdateTree { diff --git a/verifiable-db/Cargo.toml b/verifiable-db/Cargo.toml index bd5ab8c75..8b50d33d5 100644 --- a/verifiable-db/Cargo.toml +++ b/verifiable-db/Cargo.toml @@ -20,6 +20,7 @@ serde.workspace = true mp2_common = { path = "../mp2-common" } recursion_framework = { path = "../recursion-framework" } ryhope = { path = "../ryhope" } +mp2_test = { path = "../mp2-test" } [dev-dependencies] futures.workspace = true @@ -27,7 +28,6 @@ rand.workspace = true serial_test.workspace = true tokio.workspace = true -mp2_test = { path = "../mp2-test" } - [features] original_poseidon = ["mp2_common/original_poseidon"] +results_tree = [] # temporary features to disable compiling results_tree code by default, as it is still WiP diff --git a/verifiable-db/src/api.rs b/verifiable-db/src/api.rs index 2429353ed..9c1ca5324 100644 --- a/verifiable-db/src/api.rs +++ b/verifiable-db/src/api.rs @@ -5,9 +5,7 @@ use crate::{ extraction::{ExtractionPI, ExtractionPIWrap}, ivc, query::{self, api::Parameters as QueryParams, pi_len as query_pi_len}, - revelation::{ - self, api::Parameters as RevelationParams, num_query_io, pi_len as revelation_pi_len, - }, + revelation::{self, api::Parameters as RevelationParams, pi_len as revelation_pi_len}, row_tree::{self}, }; use anyhow::Result; @@ -193,24 +191,35 @@ where #[derive(Serialize, Deserialize)] pub struct QueryParameters< - const ROW_TREE_MAX_DEPTH: usize, - const INDEX_TREE_MAX_DEPTH: usize, - const MAX_NUM_COLUMNS: usize, - const MAX_NUM_PREDICATE_OPS: usize, - const MAX_NUM_RESULT_OPS: usize, - const MAX_NUM_OUTPUTS: usize, - const MAX_NUM_ITEMS_PER_OUTPUT: usize, - const MAX_NUM_PLACEHOLDERS: usize, + const NUM_CHUNKS: usize, // Maximum number of chunks that can be aggregated in a single proof + const NUM_ROWS: usize, // Maximum number of rows that can be proven in a single proof + const ROW_TREE_MAX_DEPTH: usize, // Maximum depth of rows tree supported in circuits + const INDEX_TREE_MAX_DEPTH: usize, // Maximum depth of index tree supported in circuits + const MAX_NUM_COLUMNS: usize, // Maximum number of columns for a table supported in circuits + const MAX_NUM_PREDICATE_OPS: usize, // Maximum number of operations that can be employed in a query + // to evaluate the filtering predicate (i.e, the operations in `WHERE` clause of the query) + const MAX_NUM_RESULT_OPS: usize, // Maximum number of operations that can be employed in a query + // to compute the results of the query in each row (i.e, the operations in the `SELECT` clause of the query) + const MAX_NUM_OUTPUTS: usize, // Maximum number of outputs that can be returned for a query with a single + // proof. 
It basically corresponds to the maximum value that can be used for `LIMIT` keyword + const MAX_NUM_ITEMS_PER_OUTPUT: usize, // Maximum number of items that can be returned for each output row of the + // query (i.e., the maximum number of items that can be specified in the `SELECT` clause of the query) + const MAX_NUM_PLACEHOLDERS: usize, // Maximum number of placeholders (including the special block range + // placeholders) that can be employed in a query > where [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:, - [(); num_query_io::()]:, + [(); query_pi_len::()]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, { query_params: QueryParams< + NUM_CHUNKS, + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -233,6 +242,8 @@ pub struct QueryParameters< #[derive(Serialize, Deserialize)] #[allow(clippy::large_enum_variant)] pub enum QueryCircuitInput< + const NUM_CHUNKS: usize, + const NUM_ROWS: usize, const ROW_TREE_MAX_DEPTH: usize, const INDEX_TREE_MAX_DEPTH: usize, const MAX_NUM_COLUMNS: usize, @@ -249,6 +260,10 @@ pub enum QueryCircuitInput< { Query( query::api::CircuitInput< + NUM_CHUNKS, + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -270,6 +285,8 @@ pub enum QueryCircuitInput< } impl< + const NUM_CHUNKS: usize, + const NUM_ROWS: usize, const ROW_TREE_MAX_DEPTH: usize, const INDEX_TREE_MAX_DEPTH: usize, const MAX_NUM_COLUMNS: usize, @@ -280,6 +297,8 @@ impl< const MAX_NUM_PLACEHOLDERS: usize, > QueryParameters< + NUM_CHUNKS, + NUM_ROWS, ROW_TREE_MAX_DEPTH, INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, @@ -292,7 +311,6 @@ impl< where [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:, - [(); num_query_io::()]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, @@ -306,7 +324,8 @@ where let query_params = QueryParams::build(); info!("Building the revelation circuit parameters..."); let revelation_params = RevelationParams::build( - query_params.get_circuit_set(), + query_params.get_circuit_set(), // unused, so we provide same query params + query_params.get_universal_circuit().data.verifier_data(), ¶ms_info.preprocessing_circuit_set, ¶ms_info.preprocessing_vk, ); @@ -323,6 +342,8 @@ where pub fn generate_proof( &self, input: QueryCircuitInput< + NUM_CHUNKS, + NUM_ROWS, ROW_TREE_MAX_DEPTH, INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, @@ -339,7 +360,7 @@ where let proof = self.revelation_params.generate_proof( input, self.query_params.get_circuit_set(), - Some(&self.query_params), + Some(self.query_params.get_universal_circuit()), )?; self.wrap_circuit.generate_proof( self.revelation_params.get_circuit_set(), @@ -361,6 +382,8 @@ mod tests { use std::{fs::File, io::BufReader}; // Constants associating with test data. 
+ const NUM_CHUNKS: usize = 5; + const NUM_ROWS: usize = 5; const MAX_NUM_COLUMNS: usize = 20; const MAX_NUM_PREDICATE_OPS: usize = 20; const MAX_NUM_RESULT_OPS: usize = 20; @@ -381,6 +404,8 @@ mod tests { let file = File::open(QUERY_PARAMS_FILE_PATH).unwrap(); let reader = BufReader::new(file); let query_params: QueryParameters< + NUM_CHUNKS, + NUM_ROWS, ROW_TREE_MAX_DEPTH, INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, diff --git a/verifiable-db/src/lib.rs b/verifiable-db/src/lib.rs index e67983aba..b9e0856fe 100644 --- a/verifiable-db/src/lib.rs +++ b/verifiable-db/src/lib.rs @@ -4,6 +4,7 @@ // Add this to allow generic const items, e.g. `const IO_LEN` #![feature(generic_const_items)] #![feature(variant_count)] +#![feature(async_closure)] pub mod api; pub mod block_tree; pub mod cells_tree; @@ -11,6 +12,7 @@ pub mod extraction; pub mod ivc; /// Module for circuits for simple queries pub mod query; +#[cfg(feature = "results_tree")] pub mod results_tree; /// Module for the query revelation circuits pub mod revelation; diff --git a/verifiable-db/src/query/aggregation/child_proven_single_path_node.rs b/verifiable-db/src/query/aggregation/child_proven_single_path_node.rs deleted file mode 100644 index 90f0c5120..000000000 --- a/verifiable-db/src/query/aggregation/child_proven_single_path_node.rs +++ /dev/null @@ -1,366 +0,0 @@ -use std::iter; - -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - hash::hash_maybe_first, - public_inputs::PublicInputCommon, - serialization::{deserialize, serialize}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; - -use crate::query::public_inputs::PublicInputs; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ChildProvenSinglePathNodeWires { - value: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - subtree_hash: HashOutTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - sibling_hash: HashOutTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_left_child: BoolTarget, - unproven_min: UInt256Target, - unproven_max: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ChildProvenSinglePathNodeCircuit { - /// Value stored in the current node - pub(crate) value: U256, - /// Hash of the row/rows tree stored in the current node - pub(crate) subtree_hash: HashOut, - /// Hash of the sibling of the proven node child - pub(crate) sibling_hash: HashOut, - /// Flag indicating whether the proven child is the left child or the right one - pub(crate) is_left_child: bool, - /// Minimum value of the indexed column to be employed to compute the hash of the current node - pub(crate) unproven_min: U256, - /// Maximum value of the indexed column to be employed to compute the hash of the current node - pub(crate) unproven_max: U256, - /// Boolean flag specifying whether the proof is being generated for a node - /// in a rows tree of for a node in the index tree - pub(crate) is_rows_tree_node: bool, -} - -impl ChildProvenSinglePathNodeCircuit { - pub fn 
build( - b: &mut CBuilder, - child_proof: &PublicInputs, - ) -> ChildProvenSinglePathNodeWires { - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let is_left_child = b.add_virtual_bool_target_unsafe(); - let value = b.add_virtual_u256(); - let subtree_hash = b.add_virtual_hash(); - let sibling_hash = b.add_virtual_hash(); - let unproven_min = b.add_virtual_u256_unsafe(); - let unproven_max = b.add_virtual_u256_unsafe(); - - let node_min = b.select_u256( - is_left_child, - &child_proof.min_value_target(), - &unproven_min, - ); - let node_max = b.select_u256( - is_left_child, - &unproven_max, - &child_proof.max_value_target(), - ); - let column_id = b.select( - is_rows_tree_node, - child_proof.index_ids_target()[1], - child_proof.index_ids_target()[0], - ); - // Compute the node hash: - // node_hash = H(left_child_hash||right_child_hash||node_min||node_max||column_id||value||subtree_hash) - let rest: Vec<_> = node_min - .to_targets() - .into_iter() - .chain(node_max.to_targets()) - .chain(iter::once(column_id)) - .chain(value.to_targets()) - .chain(subtree_hash.elements) - .collect(); - - let node_hash = hash_maybe_first( - b, - is_left_child, - sibling_hash.elements, - child_proof.tree_hash_target().elements, - &rest, - ); - - // if is_left_child: - // value > child_proof.max_query - // else: - // value < child_proof.min_query - let is_greater_than_max = b.is_greater_than_u256(&value, &child_proof.max_query_target()); - let is_less_than_min = b.is_less_than_u256(&value, &child_proof.min_query_target()); - let condition = b.select( - is_left_child, - is_greater_than_max.target, - is_less_than_min.target, - ); - let ttrue = b._true(); - b.connect(condition, ttrue.target); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - child_proof.to_values_raw(), - &[child_proof.num_matching_rows_target()], - child_proof.to_ops_raw(), - child_proof.to_index_value_raw(), - &node_min.to_targets(), - &node_max.to_targets(), - child_proof.to_index_ids_raw(), - child_proof.to_min_query_raw(), - child_proof.to_max_query_raw(), - &[*child_proof.to_overflow_raw()], - child_proof.to_computational_hash_raw(), - child_proof.to_placeholder_hash_raw(), - ) - .register(b); - - ChildProvenSinglePathNodeWires { - value, - subtree_hash, - sibling_hash, - is_left_child, - unproven_min, - unproven_max, - is_rows_tree_node, - } - } - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &ChildProvenSinglePathNodeWires, - ) { - pw.set_u256_target(&wires.value, self.value); - pw.set_hash_target(wires.subtree_hash, self.subtree_hash); - pw.set_hash_target(wires.sibling_hash, self.sibling_hash); - pw.set_bool_target(wires.is_left_child, self.is_left_child); - pw.set_u256_target(&wires.unproven_min, self.unproven_min); - pw.set_u256_target(&wires.unproven_max, self.unproven_max); - pw.set_bool_target(wires.is_rows_tree_node, self.is_rows_tree_node); - } -} - -pub(crate) const NUM_VERIFIED_PROOFS: usize = 1; - -impl CircuitLogicWires - for ChildProvenSinglePathNodeWires -{ - type CircuitBuilderParams = (); - type Inputs = ChildProvenSinglePathNodeCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - let child_proof = PublicInputs::from_slice(&verified_proofs[0].public_inputs); - - Self::Inputs::build(builder, &child_proof) - } - - fn assign_input(&self, 
inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::pi_len, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{poseidon::H, utils::ToFields, C, D}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::gen_random_field_hash, - }; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; - use rand::{thread_rng, Rng}; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestChildProvenSinglePathNodeCircuit<'a> { - c: ChildProvenSinglePathNodeCircuit, - child_proof: &'a [F], - } - - impl UserCircuit for TestChildProvenSinglePathNodeCircuit<'_> { - type Wires = (ChildProvenSinglePathNodeWires, Vec); - - fn build(b: &mut CBuilder) -> Self::Wires { - let child_proof = b - .add_virtual_target_arr::<{ pi_len::() }>() - .to_vec(); - let pi = PublicInputs::::from_slice(&child_proof); - - let wires = ChildProvenSinglePathNodeCircuit::build(b, &pi); - - (wires, child_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.child_proof); - } - } - - fn test_child_proven_single_path_node_circuit(is_rows_tree_node: bool, is_left_child: bool) { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Build the child proof. - let [child_proof] = random_aggregation_public_inputs(&ops); - let child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&child_proof); - - let index_ids = child_pi.index_ids(); - let index_value = child_pi.index_value(); - let min_query = child_pi.min_query_value(); - let max_query = child_pi.max_query_value(); - - // Construct the witness. - let mut rng = thread_rng(); - let mut value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let subtree_hash = gen_random_field_hash(); - let sibling_hash = gen_random_field_hash(); - let unproven_min = index_value - .checked_sub(U256::from(100)) - .unwrap_or(index_value); - let unproven_max = index_value - .checked_add(U256::from(100)) - .unwrap_or(index_value); - - if is_left_child { - while value <= max_query { - value = U256::from_limbs(rng.gen::<[u64; 4]>()); - } - } else { - while value >= min_query { - value = U256::from_limbs(rng.gen::<[u64; 4]>()); - } - } - - // Construct the test circuit. - let test_circuit = TestChildProvenSinglePathNodeCircuit { - c: ChildProvenSinglePathNodeCircuit { - value, - subtree_hash, - sibling_hash, - is_left_child, - unproven_min, - unproven_max, - is_rows_tree_node, - }, - child_proof: &child_proof, - }; - - // Proof for the test circuit. - let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - let [node_min, node_max] = if is_left_child { - [child_pi.min_value(), unproven_max] - } else { - [unproven_min, child_pi.max_value()] - }; - // Check the public inputs. 
- // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - - let child_hash = child_pi.tree_hash(); - let [left_child_hash, right_child_hash] = if is_left_child { - [child_hash, sibling_hash] - } else { - [sibling_hash, child_hash] - }; - - // H(left_child_hash||right_child_hash||node_min||node_max||column_id||value||subtree_hash) - let input: Vec<_> = left_child_hash - .to_fields() - .into_iter() - .chain(right_child_hash.to_fields()) - .chain(node_min.to_fields()) - .chain(node_max.to_fields()) - .chain(iter::once(column_id)) - .chain(value.to_fields()) - .chain(subtree_hash.to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&input); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values - assert_eq!(pi.to_values_raw(), child_pi.to_values_raw()); - // Count - assert_eq!(pi.num_matching_rows(), child_pi.num_matching_rows()); - // Operation IDs - assert_eq!(pi.operation_ids(), child_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), node_min); - // Maximum value - assert_eq!(pi.max_value(), node_max); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Overflow flag - assert_eq!(pi.overflow_flag(), child_pi.overflow_flag()); - // Computational hash - assert_eq!(pi.computational_hash(), child_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), child_pi.placeholder_hash()); - } - - #[test] - fn test_child_proven_node_for_row_node_with_left_child() { - test_child_proven_single_path_node_circuit(true, true); - } - #[test] - fn test_child_proven_node_for_row_node_with_right_child() { - test_child_proven_single_path_node_circuit(true, false); - } - #[test] - fn test_child_proven_node_for_index_node_with_left_child() { - test_child_proven_single_path_node_circuit(false, true); - } - #[test] - fn test_child_proven_node_for_index_node_with_right_child() { - test_child_proven_single_path_node_circuit(false, false); - } -} diff --git a/verifiable-db/src/query/aggregation/embedded_tree_proven_single_path_node.rs b/verifiable-db/src/query/aggregation/embedded_tree_proven_single_path_node.rs deleted file mode 100644 index 2e2c26056..000000000 --- a/verifiable-db/src/query/aggregation/embedded_tree_proven_single_path_node.rs +++ /dev/null @@ -1,572 +0,0 @@ -use std::iter; - -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - poseidon::{empty_poseidon_hash, H}, - public_inputs::PublicInputCommon, - serialization::{deserialize, deserialize_array, serialize, serialize_array}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{SelectHashBuilder, ToTargets}, - D, F, -}; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::array; - -use crate::query::public_inputs::PublicInputs; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct EmbeddedTreeProvenSinglePathNodeWires { - left_child_min: UInt256Target, - left_child_max: UInt256Target, - left_child_value: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - left_tree_hash: HashOutTarget, 
- #[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - left_grand_children: [HashOutTarget; 2], - right_child_min: UInt256Target, - right_child_max: UInt256Target, - right_child_value: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - right_tree_hash: HashOutTarget, - #[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - right_grand_children: [HashOutTarget; 2], - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - left_child_exists: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - right_child_exists: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct EmbeddedTreeProvenSinglePathNodeCircuit { - /// Minimum value associated to the left child - pub(crate) left_child_min: U256, - /// Maximum value associated to the left child - pub(crate) left_child_max: U256, - /// Value stored in the left child - pub(crate) left_child_value: U256, - /// Hashes of the row/rows tree stored in the left child - pub(crate) left_tree_hash: HashOut, - /// Hashes of the children nodes of the left child - pub(crate) left_grand_children: [HashOut; 2], - /// Minimum value associated to the right child - pub(crate) right_child_min: U256, - /// Maximum value associated to the right child - pub(crate) right_child_max: U256, - /// Value stored in the right child - pub(crate) right_child_value: U256, - /// Hashes of the row/rows tree stored in the right child - pub(crate) right_tree_hash: HashOut, - /// Hashes of the children nodes of the right child - pub(crate) right_grand_children: [HashOut; 2], - /// Boolean flag specifying whether there is a left child for the current node - pub(crate) left_child_exists: bool, - /// Boolean flag specifying whether there is a right child for the current node - pub(crate) right_child_exists: bool, - /// Boolean flag specifying whether the proof is being generated - /// for a node in a rows tree or for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// Minimum range bound specified in the query for the indexed column - pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl EmbeddedTreeProvenSinglePathNodeCircuit { - pub fn build( - b: &mut CBuilder, - embedded_tree_proof: &PublicInputs, - ) -> EmbeddedTreeProvenSinglePathNodeWires { - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - - let [left_child_min, left_child_max, left_child_value, right_child_min, right_child_max, right_child_value, min_query, max_query] = - array::from_fn(|_| b.add_virtual_u256_unsafe()); - let [left_tree_hash, right_tree_hash] = array::from_fn(|_| b.add_virtual_hash()); - let left_grand_children: [HashOutTarget; 2] = array::from_fn(|_| b.add_virtual_hash()); - let right_grand_children: [HashOutTarget; 2] = array::from_fn(|_| b.add_virtual_hash()); - let [left_child_exists, right_child_exists, is_rows_tree_node] = - array::from_fn(|_| b.add_virtual_bool_target_safe()); - - let index_value = embedded_tree_proof.index_value_target(); - - let column_id = b.select( - is_rows_tree_node, - embedded_tree_proof.index_ids_target()[1], - embedded_tree_proof.index_ids_target()[0], - ); - - let node_value = b.select_u256( - 
is_rows_tree_node, - &embedded_tree_proof.min_value_target(), - &index_value, - ); - - // H(left_grandchild_1||left_grandchild_2||left_min||left_max||column_id||left_value||left_tree_hash) - let left_child_inputs = left_grand_children[0] - .to_targets() - .into_iter() - .chain(left_grand_children[1].to_targets()) - .chain(left_child_min.to_targets()) - .chain(left_child_max.to_targets()) - .chain(iter::once(column_id)) - .chain(left_child_value.to_targets()) - .chain(left_tree_hash.to_targets()) - .collect(); - let left_hash_exists = b.hash_n_to_hash_no_pad::(left_child_inputs); - let left_child_hash = b.select_hash(left_child_exists, &left_hash_exists, &empty_hash); - - // H(right_grandchild_1||right_grandchild_2||right_min||right_max||column_id||right_value||right_tree_hash) - let right_child_inputs = right_grand_children[0] - .to_targets() - .into_iter() - .chain(right_grand_children[1].to_targets()) - .chain(right_child_min.to_targets()) - .chain(right_child_max.to_targets()) - .chain(iter::once(column_id)) - .chain(right_child_value.to_targets()) - .chain(right_tree_hash.to_targets()) - .collect(); - let right_hash_exists = b.hash_n_to_hash_no_pad::(right_child_inputs); - let right_child_hash = b.select_hash(right_child_exists, &right_hash_exists, &empty_hash); - - let node_min = b.select_u256(left_child_exists, &left_child_min, &node_value); - let node_max = b.select_u256(right_child_exists, &right_child_max, &node_value); - - // If the current node is not a rows tree, we need to ensure that - // the value of the primary indexed column for all the records stored in the rows tree - // found in this node is within the range specified by the query: - // min_i1 <= index_value <= max_i1 - // -> NOT((index_value < min_i1) OR (index_value > max_i1)) - let is_less_than = b.is_less_than_u256(&index_value, &min_query); - let is_greater_than = b.is_greater_than_u256(&index_value, &max_query); - let is_out_of_range = b.or(is_less_than, is_greater_than); - let is_within_range = b.not(is_out_of_range); - - // If the current node is in a rows tree, we need to ensure that - // the query bounds exposed as public inputs are the same as the one exposed - // by the proof for the row associated to the current node - let is_min_same = b.is_equal_u256(&embedded_tree_proof.min_query_target(), &min_query); - let is_max_same = b.is_equal_u256(&embedded_tree_proof.max_query_target(), &max_query); - let are_query_bounds_same = b.and(is_min_same, is_max_same); - - // if is_rows_tree_node: - // embedded_tree_proof.min_query == min_query && - // embedded_tree_proof.max_query == max_query - // else if not is_rows_tree_node: - // min_query <= index_value <= max_query - let rows_tree_condition = b.select( - is_rows_tree_node, - are_query_bounds_same.target, - is_within_range.target, - ); - let ttrue = b._true(); - b.connect(rows_tree_condition, ttrue.target); - - // Enforce that the subtree rooted in the left child contains - // only nodes outside of the range specified by the query - let is_less_than_min = b.is_less_than_u256(&left_child_max, &min_query); - let left_condition = b.and(left_child_exists, is_less_than_min); - // (left_child_exists AND (left_child_max < min_query)) == left_child_exists - b.connect(left_condition.target, left_child_exists.target); - - // Enforce that the subtree rooted in the right child contains - // only nodes outside of the range specified by the query - let is_greater_than_max = b.is_greater_than_u256(&right_child_min, &max_query); - let right_condition = b.and(right_child_exists, 
is_greater_than_max); - // (right_child_exists AND (right_child_min > min_query)) == right_child_exists - b.connect(right_condition.target, right_child_exists.target); - - // H(left_child_hash||right_child_hash||node_min||node_max||column_id||node_value||p.H) - let node_hash_inputs = left_child_hash - .elements - .into_iter() - .chain(right_child_hash.elements) - .chain(node_min.to_targets()) - .chain(node_max.to_targets()) - .chain(iter::once(column_id)) - .chain(node_value.to_targets()) - .chain(embedded_tree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = b.hash_n_to_hash_no_pad::(node_hash_inputs); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - embedded_tree_proof.to_values_raw(), - &[embedded_tree_proof.num_matching_rows_target()], - embedded_tree_proof.to_ops_raw(), - embedded_tree_proof.to_index_value_raw(), - &node_min.to_targets(), - &node_max.to_targets(), - embedded_tree_proof.to_index_ids_raw(), - &min_query.to_targets(), - &max_query.to_targets(), - &[*embedded_tree_proof.to_overflow_raw()], - embedded_tree_proof.to_computational_hash_raw(), - embedded_tree_proof.to_placeholder_hash_raw(), - ) - .register(b); - - EmbeddedTreeProvenSinglePathNodeWires { - left_child_min, - left_child_max, - left_child_value, - left_tree_hash, - left_grand_children, - right_child_min, - right_child_max, - right_child_value, - right_tree_hash, - right_grand_children, - left_child_exists, - right_child_exists, - is_rows_tree_node, - min_query, - max_query, - } - } - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &EmbeddedTreeProvenSinglePathNodeWires, - ) { - [ - (&wires.left_child_min, self.left_child_min), - (&wires.left_child_max, self.left_child_max), - (&wires.left_child_value, self.left_child_value), - (&wires.right_child_min, self.right_child_min), - (&wires.right_child_max, self.right_child_max), - (&wires.right_child_value, self.right_child_value), - (&wires.min_query, self.min_query), - (&wires.max_query, self.max_query), - ] - .iter() - .for_each(|(t, v)| pw.set_u256_target(t, *v)); - [ - (wires.left_tree_hash, self.left_tree_hash), - (wires.right_tree_hash, self.right_tree_hash), - ] - .iter() - .for_each(|(t, v)| pw.set_hash_target(*t, *v)); - wires - .left_grand_children - .iter() - .zip(self.left_grand_children) - .for_each(|(t, v)| pw.set_hash_target(*t, v)); - wires - .right_grand_children - .iter() - .zip(self.right_grand_children) - .for_each(|(t, v)| pw.set_hash_target(*t, v)); - [ - (wires.left_child_exists, self.left_child_exists), - (wires.right_child_exists, self.right_child_exists), - (wires.is_rows_tree_node, self.is_rows_tree_node), - ] - .iter() - .for_each(|(t, v)| pw.set_bool_target(*t, *v)); - } -} - -pub(crate) const NUM_VERIFIED_PROOFS: usize = 1; - -impl CircuitLogicWires - for EmbeddedTreeProvenSinglePathNodeWires -{ - type CircuitBuilderParams = (); - type Inputs = EmbeddedTreeProvenSinglePathNodeCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - let child_proof = PublicInputs::from_slice(&verified_proofs[0].public_inputs); - - Self::Inputs::build(builder, &child_proof) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use 
mp2_common::{utils::ToFields, C}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::gen_random_field_hash, - }; - use plonky2::plonk::config::Hasher; - use rand::{thread_rng, Rng}; - - use crate::{ - query::pi_len, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestEmbeddedTreeProvenSinglePathNodeCircuit<'a> { - c: EmbeddedTreeProvenSinglePathNodeCircuit, - embedded_tree_proof: &'a [F], - } - - impl UserCircuit for TestEmbeddedTreeProvenSinglePathNodeCircuit<'_> { - type Wires = ( - EmbeddedTreeProvenSinglePathNodeWires, - Vec, - ); - - fn build(b: &mut CBuilder) -> Self::Wires { - let embedded_tree_proof = b - .add_virtual_target_arr::<{ pi_len::() }>() - .to_vec(); - let pi = PublicInputs::::from_slice(&embedded_tree_proof); - - let wires = EmbeddedTreeProvenSinglePathNodeCircuit::build(b, &pi); - - (wires, embedded_tree_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.embedded_tree_proof); - } - } - - fn test_embedded_tree_proven_single_path_node_circuit( - is_rows_tree_node: bool, - left_child_exists: bool, - right_child_exists: bool, - ) { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Build the subtree proof. - let [embdeed_tree_proof] = random_aggregation_public_inputs(&ops); - let embedded_tree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&embdeed_tree_proof); - - let index_ids = embedded_tree_pi.index_ids(); - let index_value = embedded_tree_pi.index_value(); - - // Construct the witness. - let mut rng = thread_rng(); - let [left_child_min, mut left_child_max, left_child_value, mut right_child_min, right_child_max, right_child_value] = - array::from_fn(|_| U256::from_limbs(rng.gen::<[u64; 4]>())); - let left_tree_hash = gen_random_field_hash(); - let left_grand_children: [HashOut; 2] = array::from_fn(|_| gen_random_field_hash()); - let right_tree_hash = gen_random_field_hash(); - let right_grand_children: [HashOut; 2] = array::from_fn(|_| gen_random_field_hash()); - let mut min_query = U256::from(100); - let mut max_query = U256::from(200); - - if is_rows_tree_node { - min_query = embedded_tree_pi.min_query_value(); - max_query = embedded_tree_pi.max_query_value(); - } else { - if min_query > index_value { - min_query = index_value - U256::from(1); - } - if max_query < index_value { - max_query = index_value + U256::from(1); - } - } - - if left_child_exists { - left_child_max = min_query - U256::from(1); - } - - if right_child_exists { - right_child_min = max_query + U256::from(1); - } - - // Construct the test circuit. - let test_circuit = TestEmbeddedTreeProvenSinglePathNodeCircuit { - c: EmbeddedTreeProvenSinglePathNodeCircuit { - left_child_min, - left_child_max, - left_child_value, - left_tree_hash, - left_grand_children, - right_child_min, - right_child_max, - right_child_value, - right_tree_hash, - right_grand_children, - left_child_exists, - right_child_exists, - is_rows_tree_node, - min_query, - max_query, - }, - embedded_tree_proof: &embdeed_tree_proof, - }; - - // Proof for the test circuit. 
- let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - let node_value = if is_rows_tree_node { - embedded_tree_pi.min_value() - } else { - index_value - }; - let node_min = if left_child_exists { - left_child_min - } else { - node_value - }; - let node_max = if right_child_exists { - right_child_max - } else { - node_value - }; - // Check the public inputs. - // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - - let empty_hash = empty_poseidon_hash(); - // H(left_grandchild_1||left_grandchild_2||left_min||left_max||column_id||left_value||left_subtree_hash) - let left_child_inputs: Vec<_> = left_grand_children[0] - .to_fields() - .into_iter() - .chain(left_grand_children[1].to_fields()) - .chain(left_child_min.to_fields()) - .chain(left_child_max.to_fields()) - .chain(iter::once(column_id)) - .chain(left_child_value.to_fields()) - .chain(left_tree_hash.to_fields()) - .collect(); - let left_hash_exists = H::hash_no_pad(&left_child_inputs); - let left_child_hash = if left_child_exists { - left_hash_exists - } else { - *empty_hash - }; - // H(right_grandchild_1||right_grandchild_2||right_min||right_max||column_id||right_value||right_subtree_hash) - let right_child_inputs: Vec<_> = right_grand_children[0] - .to_fields() - .into_iter() - .chain(right_grand_children[1].to_fields()) - .chain(right_child_min.to_fields()) - .chain(right_child_max.to_fields()) - .chain(iter::once(column_id)) - .chain(right_child_value.to_fields()) - .chain(right_tree_hash.to_fields()) - .collect(); - let right_hash_exists = H::hash_no_pad(&right_child_inputs); - let right_child_hash = if right_child_exists { - right_hash_exists - } else { - *empty_hash - }; - - let node_hash_input: Vec<_> = left_child_hash - .to_fields() - .into_iter() - .chain(right_child_hash.to_fields()) - .chain(node_min.to_fields()) - .chain(node_max.to_fields()) - .chain(iter::once(column_id)) - .chain(node_value.to_fields()) - .chain(embedded_tree_pi.tree_hash().to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&node_hash_input); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values - assert_eq!(pi.to_values_raw(), embedded_tree_pi.to_values_raw()); - // Count - assert_eq!(pi.num_matching_rows(), embedded_tree_pi.num_matching_rows()); - // Operation IDs - assert_eq!(pi.operation_ids(), embedded_tree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), node_min); - // Maximum value - assert_eq!(pi.max_value(), node_max); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Overflow flag - assert_eq!(pi.overflow_flag(), embedded_tree_pi.overflow_flag()); - // Computational hash - assert_eq!( - pi.computational_hash(), - embedded_tree_pi.computational_hash() - ); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), embedded_tree_pi.placeholder_hash()); - } - - #[test] - fn test_embedded_tree_proven_node_for_row_node_with_no_child() { - test_embedded_tree_proven_single_path_node_circuit(true, false, false); - } - #[test] - fn test_embedded_tree_proven_node_for_row_node_with_left_child() { - test_embedded_tree_proven_single_path_node_circuit(true, true, false); - } - #[test] - fn test_embedded_tree_proven_node_for_row_node_with_right_child() { - 
test_embedded_tree_proven_single_path_node_circuit(true, false, true); - } - #[test] - fn test_embedded_tree_proven_node_for_row_node_with_both_children() { - test_embedded_tree_proven_single_path_node_circuit(true, true, true); - } - #[test] - fn test_embedded_tree_proven_node_for_index_node_with_no_child() { - test_embedded_tree_proven_single_path_node_circuit(false, false, false); - } - #[test] - fn test_embedded_tree_proven_node_for_index_node_with_left_child() { - test_embedded_tree_proven_single_path_node_circuit(false, true, false); - } - #[test] - fn test_embedded_tree_proven_node_for_index_node_with_right_child() { - test_embedded_tree_proven_single_path_node_circuit(false, false, true); - } - #[test] - fn test_embedded_tree_proven_node_for_index_node_with_both_children() { - test_embedded_tree_proven_single_path_node_circuit(false, true, true); - } -} diff --git a/verifiable-db/src/query/aggregation/full_node_index_leaf.rs b/verifiable-db/src/query/aggregation/full_node_index_leaf.rs deleted file mode 100644 index ffe02d5aa..000000000 --- a/verifiable-db/src/query/aggregation/full_node_index_leaf.rs +++ /dev/null @@ -1,246 +0,0 @@ -//! Module handling the leaf full node of the index tree for query aggregation circuits - -use crate::query::public_inputs::PublicInputs; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - poseidon::{empty_poseidon_hash, H}, - public_inputs::PublicInputCommon, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - iop::{target::Target, witness::PartialWitness}, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::iter; - -/// Leaf wires -/// The constant generic parameter is only used for impl `CircuitLogicWires`. 
-#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeIndexLeafWires { - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeIndexLeafCircuit { - /// Minimum range bound specified in the query for the indexed column - pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl FullNodeIndexLeafCircuit { - pub fn build( - b: &mut CBuilder, - subtree_proof: &PublicInputs, - ) -> FullNodeIndexLeafWires { - let ttrue = b._true(); - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - let empty_hash_targets = empty_hash.to_targets(); - - let [min_query, max_query] = [0; 2].map(|_| b.add_virtual_u256()); - - let index_ids = subtree_proof.index_ids_target(); - let index_value = subtree_proof.index_value_target(); - let index_value_targets = subtree_proof.to_index_value_raw(); - - // Ensure the value of the indexed column for all the records stored in the - // subtree found in this node is within the range specified by the query: - // p.I >= MIN_query AND p.I <= MAX_query - let is_not_less_than_min = b.is_less_or_equal_than_u256(&min_query, &index_value); - let is_not_greater_than_max = b.is_less_or_equal_than_u256(&index_value, &max_query); - let is_in_range = b.and(is_not_less_than_min, is_not_greater_than_max); - b.connect(is_in_range.target, ttrue.target); - - // Compute the node hash: - // node_hash = H(H("") || H("") || p.I || p.I || p.index_ids[0] || p.I || p.H)) - let inputs = empty_hash_targets - .iter() - .chain(empty_hash_targets.iter()) - .chain(index_value_targets) - .chain(index_value_targets) - .chain(iter::once(&index_ids[0])) - .chain(index_value_targets) - .cloned() - .chain(subtree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = b.hash_n_to_hash_no_pad::(inputs); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - subtree_proof.to_values_raw(), - &[subtree_proof.num_matching_rows_target()], - subtree_proof.to_ops_raw(), - index_value_targets, - index_value_targets, - index_value_targets, - subtree_proof.to_index_ids_raw(), - &min_query.to_targets(), - &max_query.to_targets(), - &[*subtree_proof.to_overflow_raw()], - subtree_proof.to_computational_hash_raw(), - subtree_proof.to_placeholder_hash_raw(), - ) - .register(b); - - FullNodeIndexLeafWires { - min_query, - max_query, - } - } - - fn assign(&self, pw: &mut PartialWitness, wires: &FullNodeIndexLeafWires) { - pw.set_u256_target(&wires.min_query, self.min_query); - pw.set_u256_target(&wires.max_query, self.max_query); - } -} - -/// Subtree proof number = 1, child proof number = 0 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 1; - -impl CircuitLogicWires - for FullNodeIndexLeafWires -{ - type CircuitBuilderParams = (); - type Inputs = FullNodeIndexLeafCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - // The first one is the subtree proof. 
- let subtree_proof = PublicInputs::from_slice(&verified_proofs[0].public_inputs); - - Self::Inputs::build(builder, &subtree_proof) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::{aggregation::utils::tests::unify_subtree_proof, pi_len}, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{utils::ToFields, C}; - use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestFullNodeIndexLeafCircuit<'a> { - c: FullNodeIndexLeafCircuit, - subtree_proof: &'a [F], - } - - impl UserCircuit for TestFullNodeIndexLeafCircuit<'_> { - // Circuit wires + subtree proof - type Wires = (FullNodeIndexLeafWires, Vec); - - fn build(b: &mut CBuilder) -> Self::Wires { - let subtree_proof = b - .add_virtual_target_arr::<{ pi_len::() }>() - .to_vec(); - let subtree_pi = PublicInputs::::from_slice(&subtree_proof); - - let wires = FullNodeIndexLeafCircuit::build(b, &subtree_pi); - - (wires, subtree_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.subtree_proof); - } - } - - #[test] - fn test_query_agg_full_node_index_leaf() { - let min_query = U256::from(100); - let max_query = U256::from(200); - - // Generate the subtree proof. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - let [mut subtree_proof] = random_aggregation_public_inputs(&ops); - unify_subtree_proof::(&mut subtree_proof, false, min_query, max_query); - let subtree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&subtree_proof); - - let index_value = subtree_pi.index_value(); - let index_value_fields = subtree_pi.to_index_value_raw(); - let index_ids = subtree_pi.index_ids(); - - // Construct the test circuit. - let test_circuit = TestFullNodeIndexLeafCircuit { - c: FullNodeIndexLeafCircuit { - min_query, - max_query, - }, - subtree_proof: &subtree_proof, - }; - - // Prove for the test circuit. - let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - // Check the public inputs. 
- // Tree hash - { - // H(H("") || H("") || p.I || p.I || p.index_ids[0] || p.I || p.H)) - let empty_hash = empty_poseidon_hash(); - let empty_hash_fields = empty_hash.to_fields(); - let inputs: Vec<_> = empty_hash_fields - .iter() - .chain(empty_hash_fields.iter()) - .chain(index_value_fields) - .chain(index_value_fields) - .chain(iter::once(&index_ids[0])) - .chain(index_value_fields) - .chain(subtree_pi.to_hash_raw()) - .cloned() - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values - assert_eq!(pi.to_values_raw(), subtree_pi.to_values_raw()); - // Count - assert_eq!(pi.num_matching_rows(), subtree_pi.num_matching_rows()); - // Operation IDs - assert_eq!(pi.operation_ids(), subtree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), index_value); - // Maximum value - assert_eq!(pi.max_value(), index_value); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Overflow flag - assert_eq!(pi.overflow_flag(), subtree_pi.overflow_flag()); - // Computational hash - assert_eq!(pi.computational_hash(), subtree_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), subtree_pi.placeholder_hash()); - } -} diff --git a/verifiable-db/src/query/aggregation/full_node_with_one_child.rs b/verifiable-db/src/query/aggregation/full_node_with_one_child.rs deleted file mode 100644 index 8ac0b9ef1..000000000 --- a/verifiable-db/src/query/aggregation/full_node_with_one_child.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! Module handling the full node with one child for query aggregation circuits - -use crate::query::{ - aggregation::{output_computation::compute_output_item, utils::constrain_input_proofs}, - public_inputs::PublicInputs, -}; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - hash::hash_maybe_first, - poseidon::empty_poseidon_hash, - public_inputs::PublicInputCommon, - serialization::{deserialize, serialize}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::{iter, slice}; - -/// Full node wires with one child -/// The constant generic parameter is only used for impl `CircuitLogicWires`. 
-#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeWithOneChildWires { - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_left_child: BoolTarget, - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeWithOneChildCircuit { - /// The flag specified if the proof is generated for a node in a rows tree or - /// for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// The flag specified if the child node is the left or right child - pub(crate) is_left_child: bool, - /// Minimum range bound specified in the query for the indexed column - /// It's a range bound for the primary indexed column for index tree, - /// and secondary indexed column for rows tree. - pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl FullNodeWithOneChildCircuit { - pub fn build( - b: &mut CBuilder, - subtree_proof: &PublicInputs, - child_proof: &PublicInputs, - ) -> FullNodeWithOneChildWires - where - [(); MAX_NUM_RESULTS - 1]:, - { - let zero = b.zero(); - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let is_left_child = b.add_virtual_bool_target_unsafe(); - let [min_query, max_query] = [0; 2].map(|_| b.add_virtual_u256_unsafe()); - - // Check the consistency for the subtree proof and child proof. - constrain_input_proofs( - b, - is_rows_tree_node, - &min_query, - &max_query, - subtree_proof, - slice::from_ref(child_proof), - ); - - // Choose the column ID and node value to be hashed depending on which tree - // the current node belongs to. - let index_ids = subtree_proof.index_ids_target(); - let column_id = b.select(is_rows_tree_node, index_ids[1], index_ids[0]); - let index_value = subtree_proof.index_value_target(); - let node_value = b.select_u256( - is_rows_tree_node, - &subtree_proof.min_value_target(), - &index_value, - ); - - let node_min = b.select_u256(is_left_child, &child_proof.min_value_target(), &node_value); - let node_max = b.select_u256(is_left_child, &node_value, &child_proof.max_value_target()); - - // Compute the node hash: - // H(left_child.H || right_child.H || node_min || node_max || column_id || node_value || p.H)) - let rest: Vec<_> = node_min - .to_targets() - .into_iter() - .chain(node_max.to_targets()) - .chain(iter::once(column_id)) - .chain(node_value.to_targets()) - .chain(subtree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = hash_maybe_first( - b, - is_left_child, - empty_hash.elements, - child_proof.tree_hash_target().elements, - &rest, - ); - - // Aggregate the output values of children and the overflow number. 
- let mut num_overflows = zero; - let mut aggregated_values = vec![]; - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = compute_output_item(b, i, &[subtree_proof, child_proof]); - - aggregated_values.append(&mut output); - num_overflows = b.add(num_overflows, overflow); - } - - // count = current.count + child.count - let count = b.add( - subtree_proof.num_matching_rows_target(), - child_proof.num_matching_rows_target(), - ); - - // overflow = (pC.overflow + pR.overflow + num_overflows) != 0 - let overflow = b.add_many([ - subtree_proof.to_overflow_raw(), - child_proof.to_overflow_raw(), - &num_overflows, - ]); - let overflow = b.is_not_equal(overflow, zero); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - aggregated_values.as_slice(), - &[count], - subtree_proof.to_ops_raw(), - subtree_proof.to_index_value_raw(), - &node_min.to_targets(), - &node_max.to_targets(), - subtree_proof.to_index_ids_raw(), - &min_query.to_targets(), - &max_query.to_targets(), - &[overflow.target], - subtree_proof.to_computational_hash_raw(), - subtree_proof.to_placeholder_hash_raw(), - ) - .register(b); - - FullNodeWithOneChildWires { - is_rows_tree_node, - is_left_child, - min_query, - max_query, - } - } - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &FullNodeWithOneChildWires, - ) { - pw.set_bool_target(wires.is_rows_tree_node, self.is_rows_tree_node); - pw.set_bool_target(wires.is_left_child, self.is_left_child); - pw.set_u256_target(&wires.min_query, self.min_query); - pw.set_u256_target(&wires.max_query, self.max_query); - } -} - -/// Subtree proof number = 1, child proof number = 1 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 2; - -impl CircuitLogicWires - for FullNodeWithOneChildWires -where - [(); MAX_NUM_RESULTS - 1]:, -{ - type CircuitBuilderParams = (); - type Inputs = FullNodeWithOneChildCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - // The first one is the subtree proof, and the second is the child proof. 
- let [subtree_proof, child_proof] = - verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); - - Self::Inputs::build(builder, &subtree_proof, &child_proof) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::{ - aggregation::{ - tests::compute_output_item_value, - utils::tests::{unify_child_proof, unify_subtree_proof}, - }, - pi_len, - }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{poseidon::H, utils::ToFields, C}; - use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; - use std::array; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestFullNodeWithOneChildCircuit<'a> { - c: FullNodeWithOneChildCircuit, - subtree_proof: &'a [F], - child_proof: &'a [F], - } - - impl UserCircuit for TestFullNodeWithOneChildCircuit<'_> { - // Circuit wires + subtree proof + child proof - type Wires = ( - FullNodeWithOneChildWires, - Vec, - Vec, - ); - - fn build(b: &mut CBuilder) -> Self::Wires { - let proofs = array::from_fn(|_| { - b.add_virtual_target_arr::<{ pi_len::() }>() - .to_vec() - }); - let [subtree_pi, child_pi] = - array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); - - let wires = FullNodeWithOneChildCircuit::build(b, &subtree_pi, &child_pi); - - let [subtree_proof, child_proof] = proofs; - - (wires, subtree_proof, child_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.subtree_proof); - pw.set_target_arr(&wires.2, self.child_proof); - } - } - - fn test_full_node_with_one_child_circuit(is_rows_tree_node: bool, is_left_child: bool) { - let min_query = U256::from(100); - let max_query = U256::from(200); - - // Generate the input proofs. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - let [mut subtree_proof, mut child_proof] = random_aggregation_public_inputs(&ops); - unify_subtree_proof::( - &mut subtree_proof, - is_rows_tree_node, - min_query, - max_query, - ); - let subtree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&subtree_proof); - unify_child_proof::( - &mut child_proof, - is_rows_tree_node, - min_query, - max_query, - &subtree_pi, - ); - let child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&child_proof); - - // Construct the expected public input values. - let index_ids = subtree_pi.index_ids(); - let index_value = subtree_pi.index_value(); - let node_value = if is_rows_tree_node { - subtree_pi.min_value() - } else { - index_value - }; - let [node_min, node_max] = if is_left_child { - [child_pi.min_value(), node_value] - } else { - [node_value, child_pi.max_value()] - }; - - // Construct the test circuit. - let test_circuit = TestFullNodeWithOneChildCircuit { - c: FullNodeWithOneChildCircuit { - is_rows_tree_node, - is_left_child, - min_query, - max_query, - }, - subtree_proof: &subtree_proof, - child_proof: &child_proof, - }; - - // Prove for the test circuit. - let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - // Check the public inputs. 
- // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - let empty_hash = empty_poseidon_hash(); - let child_hash = child_pi.tree_hash(); - let [left_child_hash, right_child_hash] = if is_left_child { - [child_hash, *empty_hash] - } else { - [*empty_hash, child_hash] - }; - - // H(left_child.H || right_child.H || node_min || node_max || column_id || node_value || p.H)) - let inputs: Vec<_> = left_child_hash - .to_fields() - .into_iter() - .chain(right_child_hash.to_fields()) - .chain(node_min.to_fields()) - .chain(node_max.to_fields()) - .chain(iter::once(column_id)) - .chain(node_value.to_fields()) - .chain(subtree_pi.tree_hash().to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values and overflow flag - { - let mut num_overflows = 0; - let mut aggregated_values = vec![]; - - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = - compute_output_item_value(i, &[&subtree_pi, &child_pi]); - - aggregated_values.append(&mut output); - num_overflows += overflow; - } - - assert_eq!(pi.to_values_raw(), aggregated_values); - assert_eq!( - pi.overflow_flag(), - subtree_pi.overflow_flag() || child_pi.overflow_flag() || num_overflows != 0 - ); - } - // Count - assert_eq!( - pi.num_matching_rows(), - subtree_pi.num_matching_rows() + child_pi.num_matching_rows(), - ); - // Operation IDs - assert_eq!(pi.operation_ids(), subtree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), node_min); - // Maximum value - assert_eq!(pi.max_value(), node_max); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Computational hash - assert_eq!(pi.computational_hash(), subtree_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), subtree_pi.placeholder_hash()); - } - - #[test] - fn test_query_agg_full_node_with_one_child_for_row_node_with_left_child() { - test_full_node_with_one_child_circuit(true, true); - } - - #[test] - fn test_query_agg_full_node_with_one_child_for_row_node_with_right_child() { - test_full_node_with_one_child_circuit(true, false); - } - - #[test] - fn test_query_agg_full_node_with_one_child_for_index_node_with_left_child() { - test_full_node_with_one_child_circuit(false, true); - } - - #[test] - fn test_query_agg_full_node_with_one_child_for_index_node_with_right_child() { - test_full_node_with_one_child_circuit(false, false); - } -} diff --git a/verifiable-db/src/query/aggregation/full_node_with_two_children.rs b/verifiable-db/src/query/aggregation/full_node_with_two_children.rs deleted file mode 100644 index 1594e2ecb..000000000 --- a/verifiable-db/src/query/aggregation/full_node_with_two_children.rs +++ /dev/null @@ -1,398 +0,0 @@ -//! 
Module handling the full node with two children for query aggregation circuits - -use crate::query::{ - aggregation::{output_computation::compute_output_item, utils::constrain_input_proofs}, - public_inputs::PublicInputs, -}; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - poseidon::H, - public_inputs::PublicInputCommon, - serialization::{deserialize, serialize}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::iter; - -/// Full node wires with two children -/// The constant generic parameter is only used for impl `CircuitLogicWires`. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeWithTwoChildrenWires { - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FullNodeWithTwoChildrenCircuit { - /// The flag specified if the proof is generated for a node in a rows tree or - /// for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// Minimum range bound specified in the query for the indexed column - /// It's a range bound for the primary indexed column for index tree, - /// and secondary indexed column for rows tree. - pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl FullNodeWithTwoChildrenCircuit { - pub fn build( - b: &mut CBuilder, - subtree_proof: &PublicInputs, - child_proofs: &[PublicInputs; 2], - ) -> FullNodeWithTwoChildrenWires - where - [(); MAX_NUM_RESULTS - 1]:, - { - let zero = b.zero(); - - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let [min_query, max_query] = [0; 2].map(|_| b.add_virtual_u256_unsafe()); - - // Check the consistency for the subtree proof and child proofs. - constrain_input_proofs( - b, - is_rows_tree_node, - &min_query, - &max_query, - subtree_proof, - child_proofs, - ); - - // Choose the column ID and node value to be hashed depending on which tree - // the current node belongs to. - let index_ids = subtree_proof.index_ids_target(); - let column_id = b.select(is_rows_tree_node, index_ids[1], index_ids[0]); - let index_value = subtree_proof.index_value_target(); - let node_value = b.select_u256( - is_rows_tree_node, - &subtree_proof.min_value_target(), - &index_value, - ); - - // Compute the node hash: - // node_hash = H(p1.H || p2.H || p1.min || p2.max || column_id || node_value || p.H) - let [child_proof1, child_proof2] = child_proofs; - let inputs = child_proof1 - .tree_hash_target() - .to_targets() - .into_iter() - .chain(child_proof2.tree_hash_target().to_targets()) - .chain(child_proof1.min_value_target().to_targets()) - .chain(child_proof2.max_value_target().to_targets()) - .chain(iter::once(column_id)) - .chain(node_value.to_targets()) - .chain(subtree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = b.hash_n_to_hash_no_pad::(inputs); - - // Aggregate the output values of children and the overflow number. 
- let mut num_overflows = zero; - let mut aggregated_values = vec![]; - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = - compute_output_item(b, i, &[subtree_proof, child_proof1, child_proof2]); - - aggregated_values.append(&mut output); - num_overflows = b.add(num_overflows, overflow); - } - - // count = p1.count + p2.count + p.count - let count = b.add( - child_proof1.num_matching_rows_target(), - child_proof2.num_matching_rows_target(), - ); - let count = b.add(count, subtree_proof.num_matching_rows_target()); - - // overflow = (p.overflow + p1.overflow + p2.overflow + num_overflows) != 0 - let overflow = b.add_many([ - subtree_proof.to_overflow_raw(), - child_proof1.to_overflow_raw(), - child_proof2.to_overflow_raw(), - &num_overflows, - ]); - let overflow = b.is_not_equal(overflow, zero); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - aggregated_values.as_slice(), - &[count], - subtree_proof.to_ops_raw(), - subtree_proof.to_index_value_raw(), - child_proof1.to_min_value_raw(), - child_proof2.to_max_value_raw(), - subtree_proof.to_index_ids_raw(), - &min_query.to_targets(), - &max_query.to_targets(), - &[overflow.target], - subtree_proof.to_computational_hash_raw(), - subtree_proof.to_placeholder_hash_raw(), - ) - .register(b); - - FullNodeWithTwoChildrenWires { - is_rows_tree_node, - min_query, - max_query, - } - } - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &FullNodeWithTwoChildrenWires, - ) { - pw.set_bool_target(wires.is_rows_tree_node, self.is_rows_tree_node); - pw.set_u256_target(&wires.min_query, self.min_query); - pw.set_u256_target(&wires.max_query, self.max_query); - } -} - -/// Subtree proof number = 1, child proof number = 2 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 3; - -impl CircuitLogicWires - for FullNodeWithTwoChildrenWires -where - [(); MAX_NUM_RESULTS - 1]:, -{ - type CircuitBuilderParams = (); - type Inputs = FullNodeWithTwoChildrenCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - // The first one is the subtree proof, and the remainings are child proofs. 
-        let [subtree_proof, child_proof1, child_proof2] =
-            verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs));
-
-        Self::Inputs::build(builder, &subtree_proof, &[child_proof1, child_proof2])
-    }
-
-    fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness<F>) -> Result<()> {
-        inputs.assign(pw, self);
-        Ok(())
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::{
-        query::{
-            aggregation::{
-                tests::compute_output_item_value,
-                utils::tests::{unify_child_proof, unify_subtree_proof},
-            },
-            pi_len,
-        },
-        test_utils::{random_aggregation_operations, random_aggregation_public_inputs},
-    };
-    use mp2_common::{utils::ToFields, C};
-    use mp2_test::circuit::{run_circuit, UserCircuit};
-    use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher};
-    use std::array;
-
-    const MAX_NUM_RESULTS: usize = 20;
-
-    #[derive(Clone, Debug)]
-    struct TestFullNodeWithTwoChildrenCircuit<'a> {
-        c: FullNodeWithTwoChildrenCircuit,
-        subtree_proof: &'a [F],
-        left_child_proof: &'a [F],
-        right_child_proof: &'a [F],
-    }
-
-    impl UserCircuit<F, D> for TestFullNodeWithTwoChildrenCircuit<'_> {
-        // Circuit wires + subtree proof + left child proof + right child proof
-        type Wires = (
-            FullNodeWithTwoChildrenWires<MAX_NUM_RESULTS>,
-            Vec<Target>,
-            Vec<Target>,
-            Vec<Target>,
-        );
-
-        fn build(b: &mut CBuilder) -> Self::Wires {
-            let proofs = array::from_fn(|_| {
-                b.add_virtual_target_arr::<{ pi_len::<MAX_NUM_RESULTS>() }>()
-                    .to_vec()
-            });
-            let [subtree_pi, left_child_pi, right_child_pi] =
-                array::from_fn(|i| PublicInputs::<Target, MAX_NUM_RESULTS>::from_slice(&proofs[i]));
-
-            let wires = FullNodeWithTwoChildrenCircuit::build(
-                b,
-                &subtree_pi,
-                &[left_child_pi, right_child_pi],
-            );
-
-            let [subtree_proof, left_child_proof, right_child_proof] = proofs;
-
-            (wires, subtree_proof, left_child_proof, right_child_proof)
-        }
-
-        fn prove(&self, pw: &mut PartialWitness<F>, wires: &Self::Wires) {
-            self.c.assign(pw, &wires.0);
-            pw.set_target_arr(&wires.1, self.subtree_proof);
-            pw.set_target_arr(&wires.2, self.left_child_proof);
-            pw.set_target_arr(&wires.3, self.right_child_proof);
-        }
-    }
-
-    fn test_full_node_with_two_children_circuit(is_rows_tree_node: bool) {
-        let min_query = U256::from(100);
-        let max_query = U256::from(200);
-
-        // Generate the input proofs.
-        let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations();
-        let [mut subtree_proof, mut left_child_proof, mut right_child_proof] =
-            random_aggregation_public_inputs(&ops);
-        unify_subtree_proof::<MAX_NUM_RESULTS>(
-            &mut subtree_proof,
-            is_rows_tree_node,
-            min_query,
-            max_query,
-        );
-        let subtree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&subtree_proof);
-        [&mut left_child_proof, &mut right_child_proof]
-            .iter_mut()
-            .for_each(|p| {
-                unify_child_proof::<MAX_NUM_RESULTS>(
-                    p,
-                    is_rows_tree_node,
-                    min_query,
-                    max_query,
-                    &subtree_pi,
-                )
-            });
-        let left_child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&left_child_proof);
-        let right_child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&right_child_proof);
-
-        // Construct the expected public input values.
-        let index_ids = subtree_pi.index_ids();
-        let index_value = subtree_pi.index_value();
-        let node_value = if is_rows_tree_node {
-            subtree_pi.min_value()
-        } else {
-            index_value
-        };
-
-        // Construct the test circuit.
-        let test_circuit = TestFullNodeWithTwoChildrenCircuit {
-            c: FullNodeWithTwoChildrenCircuit {
-                is_rows_tree_node,
-                min_query,
-                max_query,
-            },
-            subtree_proof: &subtree_proof,
-            left_child_proof: &left_child_proof,
-            right_child_proof: &right_child_proof,
-        };
-
-        // Prove for the test circuit.
- let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - // Check the public inputs. - // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - - // H(p1.H || p2.H || p1.min || p2.max || column_id || node_value || p.H) - let inputs: Vec<_> = left_child_pi - .tree_hash() - .to_fields() - .into_iter() - .chain(right_child_pi.tree_hash().to_fields()) - .chain(left_child_pi.min_value().to_fields()) - .chain(right_child_pi.max_value().to_fields()) - .chain(iter::once(column_id)) - .chain(node_value.to_fields()) - .chain(subtree_pi.tree_hash().to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values and overflow flag - { - let mut num_overflows = 0; - let mut aggregated_values = vec![]; - - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = - compute_output_item_value(i, &[&subtree_pi, &left_child_pi, &right_child_pi]); - - aggregated_values.append(&mut output); - num_overflows += overflow; - } - - assert_eq!(pi.to_values_raw(), aggregated_values); - assert_eq!( - pi.overflow_flag(), - subtree_pi.overflow_flag() - || left_child_pi.overflow_flag() - || right_child_pi.overflow_flag() - || num_overflows != 0 - ); - } - // Count - assert_eq!( - pi.num_matching_rows(), - subtree_pi.num_matching_rows() - + left_child_pi.num_matching_rows() - + right_child_pi.num_matching_rows(), - ); - // Operation IDs - assert_eq!(pi.operation_ids(), subtree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), left_child_pi.min_value()); - // Maximum value - assert_eq!(pi.max_value(), right_child_pi.max_value()); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Computational hash - assert_eq!(pi.computational_hash(), subtree_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), subtree_pi.placeholder_hash()); - } - - #[test] - fn test_query_agg_full_node_with_two_children_for_row_node() { - test_full_node_with_two_children_circuit(true); - } - - #[test] - fn test_query_agg_full_node_with_two_children_for_index_node() { - test_full_node_with_two_children_circuit(false); - } -} diff --git a/verifiable-db/src/query/aggregation/non_existence_inter.rs b/verifiable-db/src/query/aggregation/non_existence_inter.rs deleted file mode 100644 index c9287d8fc..000000000 --- a/verifiable-db/src/query/aggregation/non_existence_inter.rs +++ /dev/null @@ -1,761 +0,0 @@ -//! 
Module handling the non-existence intermediate node for query aggregation circuits - -use crate::query::{ - aggregation::output_computation::compute_dummy_output_targets, - public_inputs::PublicInputs, - universal_circuit::universal_query_circuit::{ - QueryBound, QueryBoundTarget, QueryBoundTargetInputs, - }, -}; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - poseidon::{empty_poseidon_hash, H}, - public_inputs::PublicInputCommon, - serialization::{ - deserialize, deserialize_array, deserialize_long_array, serialize, serialize_array, - serialize_long_array, - }, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{SelectHashBuilder, ToTargets}, - D, F, -}; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::{array, iter}; - -/// Non-existence intermediate node wires -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct NonExistenceInterNodeWires { - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - min_query: QueryBoundTargetInputs, - max_query: QueryBoundTargetInputs, - value: UInt256Target, - index_value: UInt256Target, - index_ids: [Target; 2], - #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" - )] - ops: [Target; MAX_NUM_RESULTS], - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - subtree_hash: HashOutTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - computational_hash: HashOutTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - placeholder_hash: HashOutTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - left_child_min: UInt256Target, - left_child_max: UInt256Target, - left_child_value: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - left_tree_hash: HashOutTarget, - #[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - left_grand_children: [HashOutTarget; 2], - right_child_min: UInt256Target, - right_child_max: UInt256Target, - right_child_value: UInt256Target, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - right_tree_hash: HashOutTarget, - #[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - right_grand_children: [HashOutTarget; 2], - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - left_child_exists: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - right_child_exists: BoolTarget, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct NonExistenceInterNodeCircuit { - /// The flag specified if the proof is generated for a node in a rows tree or - /// for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// Minimum range bound specified in the query for the indexed column - /// It's a range bound for the primary indexed column for index tree, - /// and secondary indexed column for rows tree. 
- pub(crate) min_query: QueryBound, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: QueryBound, - pub(crate) value: U256, - /// Value of the indexed column for the row stored in the current node - /// (meaningful only if the current node belongs to a rows tree, - /// can be equal to `value` if the current node belongs to the index tree) - pub(crate) index_value: U256, - /// Integer identifiers of the indexed columns - pub(crate) index_ids: [F; 2], - /// Set of identifiers of the aggregation operations for each of the `S` items found in `V` - #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" - )] - pub(crate) ops: [F; MAX_NUM_RESULTS], - /// Hash of the tree stored in the current node - pub(crate) subtree_hash: HashOut, - /// Computational hash associated to the processing of single rows of the query - /// (meaningless in this case, we just need to provide it for public input compliance) - pub(crate) computational_hash: HashOut, - /// Placeholder hash associated to the processing of single rows of the query - /// (meaningless in this case, we just need to provide it for public input compliance) - pub(crate) placeholder_hash: HashOut, - /// Minimum value associated to the left child - pub(crate) left_child_min: U256, - /// Maximum value associated to the left child - pub(crate) left_child_max: U256, - /// Value stored in the left child - pub(crate) left_child_value: U256, - /// Hashes of the row/rows tree stored in the left child - pub(crate) left_tree_hash: HashOut, - /// Hashes of the children nodes of the left child - pub(crate) left_grand_children: [HashOut; 2], - /// Minimum value associated to the right child - pub(crate) right_child_min: U256, - /// Maximum value associated to the right child - pub(crate) right_child_max: U256, - /// Value stored in the right child - pub(crate) right_child_value: U256, - /// Hashes of the row/rows tree stored in the right child - pub(crate) right_tree_hash: HashOut, - /// Hashes of the children nodes of the right child - pub(crate) right_grand_children: [HashOut; 2], - /// Boolean flag specifying whether there is a left child for the current node - pub(crate) left_child_exists: bool, - /// Boolean flag specifying whether there is a right child for the current node - pub(crate) right_child_exists: bool, -} - -impl NonExistenceInterNodeCircuit { - pub fn build(b: &mut CBuilder) -> NonExistenceInterNodeWires { - let ttrue = b._true(); - let ffalse = b._false(); - let zero = b.zero(); - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let left_child_exists = b.add_virtual_bool_target_safe(); - let right_child_exists = b.add_virtual_bool_target_safe(); - // Initialize as unsafe, since all these Uint256s are either exposed as - // public inputs or passed as inputs for hash computation. 
- let [value, index_value, left_child_value, left_child_min, left_child_max, right_child_value, right_child_min, right_child_max] = - b.add_virtual_u256_arr_unsafe(); - // compute min and max query bounds for secondary index - - let index_ids = b.add_virtual_target_arr(); - let ops = b.add_virtual_target_arr(); - let [subtree_hash, computational_hash, placeholder_hash, left_child_subtree_hash, left_grand_child_hash1, left_grand_child_hash2, right_child_subtree_hash, right_grand_child_hash1, right_grand_child_hash2] = - array::from_fn(|_| b.add_virtual_hash()); - - let min_query = QueryBoundTarget::new(b); - let max_query = QueryBoundTarget::new(b); - - let min_query_value = min_query.get_bound_value(); - let max_query_value = max_query.get_bound_value(); - - let [min_query_targets, max_query_targets] = - [&min_query_value, &max_query_value].map(|v| v.to_targets()); - let column_id = b.select(is_rows_tree_node, index_ids[1], index_ids[0]); - - // Enforce that the value associated to the current node is out of the range - // specified by the query: - // value < MIN_query OR value > MAX_query - let is_value_less_than_min = b.is_less_than_u256(&value, min_query_value); - let is_value_greater_than_max = b.is_less_than_u256(max_query_value, &value); - let is_out_of_range = b.or(is_value_less_than_min, is_value_greater_than_max); - b.connect(is_out_of_range.target, ttrue.target); - - // Enforce that the records found in the subtree rooted in the child node - // are all out of the range specified by the query. If left child exists, - // ensure left_child_max < MIN_query; if right child exists, ensure right_child_min > MAX_query. - let is_child_less_than_min = b.is_less_than_u256(&left_child_max, min_query_value); - let is_left_child_out_of_range = b.and(left_child_exists, is_child_less_than_min); - b.connect(is_left_child_out_of_range.target, left_child_exists.target); - let is_child_greater_than_max = b.is_less_than_u256(max_query_value, &right_child_min); - let is_right_child_out_of_range = b.and(right_child_exists, is_child_greater_than_max); - b.connect( - is_right_child_out_of_range.target, - right_child_exists.target, - ); - - // Compute dummy values for each of the `S` values to be returned as output. 
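Out of circuit, the three constraints enforced above amount to the predicate below; a sketch on u128 stand-ins for U256, where a missing child imposes no condition.

    // The node's own value must be outside [min_query, max_query], and each
    // existing child subtree must lie entirely outside the range on its side.
    fn non_existence_holds(
        value: u128,
        left_child: Option<(u128, u128)>,  // (min, max) of the left subtree
        right_child: Option<(u128, u128)>, // (min, max) of the right subtree
        min_query: u128,
        max_query: u128,
    ) -> bool {
        let value_out = value < min_query || value > max_query;
        let left_out = left_child.map_or(true, |(_, max)| max < min_query);
        let right_out = right_child.map_or(true, |(min, _)| min > max_query);
        value_out && left_out && right_out
    }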
- let outputs = compute_dummy_output_targets(b, &ops); - - // Recompute hash of left child node to bind left_child_min and left_child_max inputs: - // H(h1 || h2 || child_min || child_max || column_id || child_value || child_subtree_hash) - let inputs = left_grand_child_hash1 - .to_targets() - .into_iter() - .chain(left_grand_child_hash2.to_targets()) - .chain(left_child_min.to_targets()) - .chain(left_child_max.to_targets()) - .chain(iter::once(column_id)) - .chain(left_child_value.to_targets()) - .chain(left_child_subtree_hash.to_targets()) - .collect(); - let left_child_hash = b.hash_n_to_hash_no_pad::(inputs); - - let left_child_hash = b.select_hash(left_child_exists, &left_child_hash, &empty_hash); - - // Recompute hash of right child node to bind right_child_min and right_child_max inputs: - // H(h1 || h2 || child_min || child_max || column_id || child_value || child_subtree_hash) - let inputs = right_grand_child_hash1 - .to_targets() - .into_iter() - .chain(right_grand_child_hash2.to_targets()) - .chain(right_child_min.to_targets()) - .chain(right_child_max.to_targets()) - .chain(iter::once(column_id)) - .chain(right_child_value.to_targets()) - .chain(right_child_subtree_hash.to_targets()) - .collect(); - let right_child_hash = b.hash_n_to_hash_no_pad::(inputs); - - let right_child_hash = b.select_hash(right_child_exists, &right_child_hash, &empty_hash); - - // node_min = left_child_exists ? left_child_min : value - let node_min = b.select_u256(left_child_exists, &left_child_min, &value); - // node_max = right_child_exists ? right_child_max : value - let node_max = b.select_u256(right_child_exists, &right_child_max, &value); - let [node_min_targets, node_max_targets] = [node_min, node_max].map(|u| u.to_targets()); - - // Compute the node hash: - // H(left_child_hash || right_child_hash || node_min || node_max || column_id || value || subtree_hash) - let inputs = left_child_hash - .to_targets() - .into_iter() - .chain(right_child_hash.to_targets()) - .chain(node_min_targets.clone()) - .chain(node_max_targets.clone()) - .chain(iter::once(column_id)) - .chain(value.to_targets()) - .chain(subtree_hash.to_targets()) - .collect(); - let node_hash = b.hash_n_to_hash_no_pad::(inputs); - - // We add the query bounds to the placeholder hash only if the current node is in a rows tree. - let placeholder_hash_with_query_bounds = - QueryBoundTarget::add_query_bounds_to_placeholder_hash( - b, - &min_query, - &max_query, - &placeholder_hash, - ); - let new_placeholder_hash = b.select_hash( - is_rows_tree_node, - &placeholder_hash_with_query_bounds, - &placeholder_hash, - ); - // We add the query bounds to the computational hash only if the current - // node is in a rows tree. - let computational_hash_with_query_bounds = - QueryBoundTarget::add_query_bounds_to_computational_hash( - b, - &min_query, - &max_query, - &computational_hash, - ); - let new_computational_hash = b.select_hash( - is_rows_tree_node, - &computational_hash_with_query_bounds, - &computational_hash, - ); - - // Register the public inputs. 
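For reference, the public inputs registered by the call that follows are flattened in a fixed order; the sketch below mirrors that order with u64 stand-ins (the field widths shown, 4-element hashes and 8-limb U256 values, are assumptions for illustration, not the crate's layout constants).

    struct QueryPi {
        node_hash: [u64; 4],
        outputs: Vec<u64>,
        count: u64,
        ops: Vec<u64>,
        index_value: [u64; 8],
        min_value: [u64; 8],
        max_value: [u64; 8],
        index_ids: [u64; 2],
        min_query: [u64; 8],
        max_query: [u64; 8],
        overflow: u64,
        computational_hash: [u64; 4],
        placeholder_hash: [u64; 4],
    }

    impl QueryPi {
        // Flatten in the same order as the `PublicInputs::new` call below.
        fn flatten(&self) -> Vec<u64> {
            let mut out = Vec::new();
            out.extend(self.node_hash);
            out.extend(self.outputs.iter().copied());
            out.push(self.count);
            out.extend(self.ops.iter().copied());
            out.extend(self.index_value);
            out.extend(self.min_value);
            out.extend(self.max_value);
            out.extend(self.index_ids);
            out.extend(self.min_query);
            out.extend(self.max_query);
            out.push(self.overflow);
            out.extend(self.computational_hash);
            out.extend(self.placeholder_hash);
            out
        }
    }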
- PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - outputs.as_slice(), - &[zero], - &ops, - &index_value.to_targets(), - &node_min_targets, - &node_max_targets, - &index_ids, - &min_query_targets, - &max_query_targets, - &[ffalse.target], - &new_computational_hash.to_targets(), - &new_placeholder_hash.to_targets(), - ) - .register(b); - - let left_grand_children = [left_grand_child_hash1, left_grand_child_hash2]; - let right_grand_children = [right_grand_child_hash1, right_grand_child_hash2]; - - NonExistenceInterNodeWires { - is_rows_tree_node, - left_child_exists, - right_child_exists, - min_query: min_query.into(), - max_query: max_query.into(), - value, - index_value, - left_child_value, - left_child_min, - left_child_max, - right_child_value, - right_child_min, - right_child_max, - index_ids, - ops, - subtree_hash, - computational_hash, - placeholder_hash, - left_tree_hash: left_child_subtree_hash, - left_grand_children, - right_tree_hash: right_child_subtree_hash, - right_grand_children, - } - } - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &NonExistenceInterNodeWires, - ) { - [ - (wires.is_rows_tree_node, self.is_rows_tree_node), - (wires.left_child_exists, self.left_child_exists), - (wires.right_child_exists, self.right_child_exists), - ] - .iter() - .for_each(|(t, v)| pw.set_bool_target(*t, *v)); - [ - (&wires.value, self.value), - (&wires.index_value, self.index_value), - (&wires.left_child_value, self.left_child_value), - (&wires.left_child_min, self.left_child_min), - (&wires.left_child_max, self.left_child_max), - (&wires.right_child_value, self.right_child_value), - (&wires.right_child_min, self.right_child_min), - (&wires.right_child_max, self.right_child_max), - ] - .iter() - .for_each(|(t, v)| pw.set_u256_target(t, *v)); - wires.min_query.assign(pw, &self.min_query); - wires.max_query.assign(pw, &self.max_query); - pw.set_target_arr(&wires.index_ids, &self.index_ids); - pw.set_target_arr(&wires.ops, &self.ops); - [ - (wires.subtree_hash, self.subtree_hash), - (wires.computational_hash, self.computational_hash), - (wires.placeholder_hash, self.placeholder_hash), - (wires.left_tree_hash, self.left_tree_hash), - (wires.right_tree_hash, self.right_tree_hash), - ] - .iter() - .for_each(|(t, v)| pw.set_hash_target(*t, *v)); - wires - .left_grand_children - .iter() - .zip(self.left_grand_children) - .for_each(|(t, v)| pw.set_hash_target(*t, v)); - wires - .right_grand_children - .iter() - .zip(self.right_grand_children) - .for_each(|(t, v)| pw.set_hash_target(*t, v)); - } -} - -/// Verified proof number = 0 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 0; - -impl CircuitLogicWires - for NonExistenceInterNodeWires -{ - type CircuitBuilderParams = (); - type Inputs = NonExistenceInterNodeCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - _verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - Self::Inputs::build(builder) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::{ - query::{ - aggregation::{ - output_computation::tests::compute_dummy_output_values, QueryBoundSource, - QueryBounds, - }, - computational_hash_ids::{AggregationOperation, Identifiers}, - universal_circuit::universal_circuit_inputs::{PlaceholderId, Placeholders}, - }, - 
test_utils::random_aggregation_operations, - }; - use mp2_common::{array::ToField, poseidon::H, utils::ToFields, C}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::gen_random_field_hash, - }; - use plonky2::{ - field::types::{Field, Sample}, - plonk::config::Hasher, - }; - - use rand::{prelude::SliceRandom, thread_rng, Rng}; - - const MAX_NUM_RESULTS: usize = 20; - - impl UserCircuit for NonExistenceInterNodeCircuit { - type Wires = NonExistenceInterNodeWires; - - fn build(b: &mut CBuilder) -> Self::Wires { - NonExistenceInterNodeCircuit::build(b) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.assign(pw, wires); - } - } - - fn test_non_existence_inter_circuit( - is_rows_tree_node: bool, - left_child_exists: bool, - right_child_exists: bool, - ops: [F; MAX_NUM_RESULTS], - ) { - let min_query_value = U256::from(1000); - let max_query_value = U256::from(3000); - - let mut rng = &mut thread_rng(); - // value < MIN_query OR value > MAX_query - let value = *[ - min_query_value - U256::from(1), - max_query_value + U256::from(1), - ] - .choose(&mut rng) - .unwrap(); - let [left_child_min, left_child_max] = if left_child_exists { - // left_child_max < MIN_query - [U256::from_limbs(rng.gen()), min_query_value - U256::from(1)] - } else { - // no constraints otherwise - [U256::from_limbs(rng.gen()), U256::from_limbs(rng.gen())] - }; - let [right_child_min, right_child_max] = if right_child_exists { - // right_child_min > MAX_query - [max_query_value + U256::from(1), U256::from_limbs(rng.gen())] - } else { - // no constraints otherwise - [U256::from_limbs(rng.gen()), U256::from_limbs(rng.gen())] - }; - let [index_value, left_child_value, right_child_value] = - array::from_fn(|_| U256::from_limbs(rng.gen())); - let index_ids = F::rand_array(); - let [subtree_hash, computational_hash, placeholder_hash, left_child_subtree_hash, left_grand_child_hash1, left_grand_child_hash2, right_child_subtree_hash, right_grand_child_hash1, right_grand_child_hash2] = - array::from_fn(|_| gen_random_field_hash()); - let left_grand_children = [left_grand_child_hash1, left_grand_child_hash2]; - let right_grand_children = [right_grand_child_hash1, right_grand_child_hash2]; - - let first_placeholder_id = PlaceholderId::Generic(0); - - let (min_query, max_query, _placeholders) = if is_rows_tree_node { - let dummy_min_query_primary = U256::ZERO; //dummy value, circuit will employ only bounds for secondary index - let dummy_max_query_primary = U256::MAX; //dummy value, circuit will employ only bounds for secondary index - let placeholders = Placeholders::from(( - vec![(first_placeholder_id, max_query_value)], - dummy_min_query_primary, - dummy_max_query_primary, - )); - - let query_bounds = QueryBounds::new( - &placeholders, - Some(QueryBoundSource::Constant(min_query_value)), - Some(QueryBoundSource::Placeholder(first_placeholder_id)), - ) - .unwrap(); - ( - QueryBound::new_secondary_index_bound( - &placeholders, - &query_bounds.min_query_secondary, - ) - .unwrap(), - QueryBound::new_secondary_index_bound( - &placeholders, - &query_bounds.max_query_secondary, - ) - .unwrap(), - placeholders, - ) - } else { - // min_query and max_query should be primary index bounds - let placeholders = Placeholders::new_empty(min_query_value, max_query_value); - ( - QueryBound::new_primary_index_bound(&placeholders, true).unwrap(), - QueryBound::new_primary_index_bound(&placeholders, false).unwrap(), - placeholders, - ) - }; - - // Construct the test circuit. 
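The two `QueryBoundSource` variants exercised by this test differ only in where a bound's value comes from; a hypothetical resolver sketch (names and types below are illustrative, not the crate's API):

    enum BoundSource {
        // The bound is a constant baked into the query.
        Constant(u128),
        // The bound is looked up from the placeholder values supplied at proving time.
        Placeholder(usize),
    }

    fn resolve_bound(source: &BoundSource, placeholder_values: &[u128]) -> u128 {
        match source {
            BoundSource::Constant(v) => *v,
            BoundSource::Placeholder(id) => placeholder_values[*id],
        }
    }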
- let test_circuit = NonExistenceInterNodeCircuit { - is_rows_tree_node, - left_child_exists, - right_child_exists, - min_query: min_query.clone(), - max_query: max_query.clone(), - value, - index_value, - left_child_value, - left_child_min, - left_child_max, - index_ids, - ops, - subtree_hash, - computational_hash, - placeholder_hash, - left_tree_hash: left_child_subtree_hash, - left_grand_children, - right_child_value, - right_child_min, - right_child_max, - right_tree_hash: right_child_subtree_hash, - right_grand_children, - }; - - // Prove for the test circuit. - let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - // node_min = is_left_child ? child_min : value - // node_max = is_left_child ? value : child_max - let node_min = if left_child_exists { - left_child_min - } else { - value - }; - let node_max = if right_child_exists { - right_child_max - } else { - value - }; - - // Check the public inputs. - // Tree hash - { - let empty_hash = empty_poseidon_hash(); - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - - // H(h1 || h2 || child_min || child_max || column_id || child_value || child_subtree_hash) - let inputs: Vec<_> = left_grand_child_hash1 - .to_fields() - .into_iter() - .chain(left_grand_child_hash2.to_fields()) - .chain(left_child_min.to_fields()) - .chain(left_child_max.to_fields()) - .chain(iter::once(column_id)) - .chain(left_child_value.to_fields()) - .chain(left_child_subtree_hash.to_fields()) - .collect(); - let left_child_hash = H::hash_no_pad(&inputs); - - let left_child_hash = if left_child_exists { - left_child_hash - } else { - *empty_hash - }; - - // H(h1 || h2 || child_min || child_max || column_id || child_value || child_subtree_hash) - let inputs: Vec<_> = right_grand_child_hash1 - .to_fields() - .into_iter() - .chain(right_grand_child_hash2.to_fields()) - .chain(right_child_min.to_fields()) - .chain(right_child_max.to_fields()) - .chain(iter::once(column_id)) - .chain(right_child_value.to_fields()) - .chain(right_child_subtree_hash.to_fields()) - .collect(); - let right_child_hash = H::hash_no_pad(&inputs); - - let right_child_hash = if right_child_exists { - right_child_hash - } else { - *empty_hash - }; - - // H(left_child_hash || right_child_hash || node_min || node_max || column_id || value || subtree_hash) - let inputs: Vec<_> = left_child_hash - .to_fields() - .into_iter() - .chain(right_child_hash.to_fields()) - .chain(node_min.to_fields()) - .chain(node_max.to_fields()) - .chain(iter::once(column_id)) - .chain(value.to_fields()) - .chain(subtree_hash.to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values - { - let outputs = compute_dummy_output_values(&ops); - assert_eq!(pi.to_values_raw(), outputs); - } - // Count - assert_eq!(pi.num_matching_rows(), F::ZERO); - // Operation IDs - assert_eq!(pi.operation_ids(), ops); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), node_min); - // Maximum value - assert_eq!(pi.max_value(), node_max); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query_value); - // Maximum query - assert_eq!(pi.max_query_value(), max_query_value); - // overflow_flag - assert!(!pi.overflow_flag()); - // Computational hash - { - let exp_hash = if is_rows_tree_node { - 
QueryBound::add_secondary_query_bounds_to_computational_hash( - &QueryBoundSource::Constant(min_query_value), - &QueryBoundSource::Placeholder(first_placeholder_id), - &computational_hash, - ) - .unwrap() - } else { - computational_hash - }; - assert_eq!(pi.computational_hash(), exp_hash); - } - // Placeholder hash - { - let exp_hash = if is_rows_tree_node { - QueryBound::add_secondary_query_bounds_to_placeholder_hash( - &min_query, - &max_query, - &placeholder_hash, - ) - } else { - placeholder_hash - }; - - assert_eq!(pi.placeholder_hash(), exp_hash); - } - } - - #[test] - fn test_query_agg_non_existence_inter_for_row_node_and_left_child() { - // Generate the random operations. - let mut ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Set the first operation to ID for testing the digest. - // The condition of the first aggregation operation ID is not associated - // with the `is_rows_tree_node` and `is_left_child` flag. - ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - test_non_existence_inter_circuit(true, true, false, ops); - } - - #[test] - fn test_query_agg_non_existence_inter_for_row_node_and_right_child() { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - test_non_existence_inter_circuit(true, false, true, ops); - } - - #[test] - fn test_query_agg_non_existence_inter_for_index_node_and_left_child() { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - test_non_existence_inter_circuit(false, true, false, ops); - } - - #[test] - fn test_query_agg_non_existence_inter_for_index_node_and_right_child() { - // Generate the random operations. - let mut ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Set the first operation to ID for testing the digest. - // The condition of the first aggregation operation ID is not associated - // with the `is_rows_tree_node` and `is_left_child` flag. - ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - test_non_existence_inter_circuit(false, false, true, ops); - } - - #[test] - fn test_query_agg_non_existence_for_row_tree_leaf_node() { - // Generate the random operations. - let mut ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Set the first operation to ID for testing the digest. - // The condition of the first aggregation operation ID is not associated - // with the `is_rows_tree_node` and `is_left_child` flag. - ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - test_non_existence_inter_circuit(true, false, false, ops); - } - - #[test] - fn test_query_agg_non_existence_for_index_tree_leaf_node() { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - test_non_existence_inter_circuit(false, false, false, ops); - } - - #[test] - fn test_query_agg_non_existence_for_row_tree_full_node() { - // Generate the random operations. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - test_non_existence_inter_circuit(true, true, true, ops); - } - - #[test] - fn test_query_agg_non_existence_for_index_tree_full_node() { - // Generate the random operations. - let mut ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - - // Set the first operation to ID for testing the digest. 
- // The condition of the first aggregation operation ID is not associated - // with the `is_rows_tree_node` and `is_left_child` flag. - ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); - - test_non_existence_inter_circuit(false, true, true, ops); - } -} diff --git a/verifiable-db/src/query/aggregation/partial_node.rs b/verifiable-db/src/query/aggregation/partial_node.rs deleted file mode 100644 index 5e9119e6f..000000000 --- a/verifiable-db/src/query/aggregation/partial_node.rs +++ /dev/null @@ -1,519 +0,0 @@ -//! Module handling the partial node for query aggregation circuits - -use crate::query::{ - aggregation::{output_computation::compute_output_item, utils::constrain_input_proofs}, - public_inputs::PublicInputs, -}; -use alloy::primitives::U256; -use anyhow::Result; -use mp2_common::{ - hash::hash_maybe_first, - poseidon::H, - public_inputs::PublicInputCommon, - serialization::{deserialize, deserialize_array, serialize, serialize_array}, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::ToTargets, - D, F, -}; -use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::proof::ProofWithPublicInputsTarget, -}; -use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; -use std::{array, iter, slice}; - -/// Partial node wires -/// The constant generic parameter is only used for impl `CircuitLogicWires`. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PartialNodeWires { - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_rows_tree_node: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - is_left_child: BoolTarget, - #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] - sibling_tree_hash: HashOutTarget, - #[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - sibling_child_hashes: [HashOutTarget; 2], - sibling_value: UInt256Target, - sibling_min: UInt256Target, - sibling_max: UInt256Target, - min_query: UInt256Target, - max_query: UInt256Target, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PartialNodeCircuit { - /// The flag specified if the proof is generated for a node in a rows tree or - /// for a node in the index tree - pub(crate) is_rows_tree_node: bool, - /// The flag indicating if the proven child is the left child or right child - pub(crate) is_left_child: bool, - /// Hash of the rows tree stored in the sibling of the proven child - pub(crate) sibling_tree_hash: HashOut, - /// The child hashes of the proven child's sibling - pub(crate) sibling_child_hashes: [HashOut; 2], - /// Value of the indexed column for the rows tree stored in the sibling of - /// the proven child - pub(crate) sibling_value: U256, - /// Minimum value of the indexed column for the subtree rooted in the sibling - /// of the proven child - pub(crate) sibling_min: U256, - /// Maximum value of the indexed column for the subtree rooted in the sibling - /// of the proven child - pub(crate) sibling_max: U256, - /// Minimum range bound specified in the query for the indexed column - /// It's a range bound for the primary indexed column for index tree, - /// and secondary indexed column for rows tree. 
- pub(crate) min_query: U256, - /// Maximum range bound specified in the query for the indexed column - pub(crate) max_query: U256, -} - -impl PartialNodeCircuit { - pub fn build( - b: &mut CBuilder, - subtree_proof: &PublicInputs, - child_proof: &PublicInputs, - ) -> PartialNodeWires - where - [(); MAX_NUM_RESULTS - 1]:, - { - let ttrue = b._true(); - let zero = b.zero(); - - let is_rows_tree_node = b.add_virtual_bool_target_safe(); - let is_left_child = b.add_virtual_bool_target_unsafe(); - let [sibling_tree_hash, sibling_child_hash1, sibling_child_hash2] = - array::from_fn(|_| b.add_virtual_hash()); - let [sibling_value, sibling_min, sibling_max, min_query, max_query] = - array::from_fn(|_| b.add_virtual_u256_unsafe()); - - // Check the consistency for the subtree proof and child proof. - constrain_input_proofs( - b, - is_rows_tree_node, - &min_query, - &max_query, - subtree_proof, - slice::from_ref(child_proof), - ); - - // Check that the subtree rooted in sibling node contains only leaves with - // indexed columns values outside the query range. - // If the proved child is the left child, ensure sibling_min > MAX_query, - // otherwise sibling_max < MIN_query. - let is_greater_than_max = b.is_less_than_u256(&max_query, &sibling_min); - let is_less_than_min = b.is_less_than_u256(&sibling_max, &min_query); - let is_out_of_range = b.select( - is_left_child, - is_greater_than_max.target, - is_less_than_min.target, - ); - b.connect(is_out_of_range, ttrue.target); - - // Choose the column ID and node value to be hashed depending on which tree - // the current node belongs to. - let index_ids = subtree_proof.index_ids_target(); - let column_id = b.select(is_rows_tree_node, index_ids[1], index_ids[0]); - let index_value = subtree_proof.index_value_target(); - let node_value = b.select_u256( - is_rows_tree_node, - &subtree_proof.min_value_target(), - &index_value, - ); - - // Recompute the tree hash for the sibling node: - // H(h1 || h2 || sibling_min || sibling_max || column_id || sibling_value || sibling_tree_hash) - let inputs = sibling_child_hash1 - .to_targets() - .into_iter() - .chain(sibling_child_hash2.to_targets()) - .chain(sibling_min.to_targets()) - .chain(sibling_max.to_targets()) - .chain(iter::once(column_id)) - .chain(sibling_value.to_targets()) - .chain(sibling_tree_hash.to_targets()) - .collect(); - let sibling_hash = b.hash_n_to_hash_no_pad::(inputs); - - // node_min = is_left_child ? child.min : sibling_min - let node_min = b.select_u256(is_left_child, &child_proof.min_value_target(), &sibling_min); - // node_max = is_left_child ? sibling_max : child.max - let node_max = b.select_u256(is_left_child, &sibling_max, &child_proof.max_value_target()); - - // Compute the node hash: - // H(left_child_hash || right_child_hash || node_min || node_max || column_id || node_value || p.H) - let rest: Vec<_> = node_min - .to_targets() - .into_iter() - .chain(node_max.to_targets()) - .chain(iter::once(column_id)) - .chain(node_value.to_targets()) - .chain(subtree_proof.tree_hash_target().to_targets()) - .collect(); - let node_hash = hash_maybe_first( - b, - is_left_child, - sibling_hash.elements, - child_proof.tree_hash_target().elements, - &rest, - ); - - // Aggregate the output values of children and the overflow number. 
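The select-based range check in the code above reduces, out of circuit, to the following predicate (u128 stand-ins for U256): the sibling subtree must sit entirely on the far side of the query range from the proven child.

    fn sibling_out_of_range(
        is_left_child: bool,
        sibling_min: u128,
        sibling_max: u128,
        min_query: u128,
        max_query: u128,
    ) -> bool {
        if is_left_child {
            // Proven child on the left: the sibling lies entirely above the range.
            sibling_min > max_query
        } else {
            // Proven child on the right: the sibling lies entirely below the range.
            sibling_max < min_query
        }
    }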
- let mut num_overflows = zero; - let mut aggregated_values = vec![]; - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = compute_output_item(b, i, &[subtree_proof, child_proof]); - - aggregated_values.append(&mut output); - num_overflows = b.add(num_overflows, overflow); - } - - // count = p.count + child.count - let count = b.add( - subtree_proof.num_matching_rows_target(), - child_proof.num_matching_rows_target(), - ); - - // overflow = (pC.overflow + pR.overflow + num_overflows) != 0 - let overflow = b.add_many([ - subtree_proof.to_overflow_raw(), - child_proof.to_overflow_raw(), - &num_overflows, - ]); - let overflow = b.is_not_equal(overflow, zero); - - // Register the public inputs. - PublicInputs::<_, MAX_NUM_RESULTS>::new( - &node_hash.to_targets(), - aggregated_values.as_slice(), - &[count], - subtree_proof.to_ops_raw(), - subtree_proof.to_index_value_raw(), - &node_min.to_targets(), - &node_max.to_targets(), - subtree_proof.to_index_ids_raw(), - &min_query.to_targets(), - &max_query.to_targets(), - &[overflow.target], - subtree_proof.to_computational_hash_raw(), - subtree_proof.to_placeholder_hash_raw(), - ) - .register(b); - - let sibling_child_hashes = [sibling_child_hash1, sibling_child_hash2]; - - PartialNodeWires { - is_rows_tree_node, - is_left_child, - sibling_tree_hash, - sibling_child_hashes, - sibling_value, - sibling_min, - sibling_max, - min_query, - max_query, - } - } - - fn assign(&self, pw: &mut PartialWitness, wires: &PartialNodeWires) { - [ - (wires.is_rows_tree_node, self.is_rows_tree_node), - (wires.is_left_child, self.is_left_child), - ] - .iter() - .for_each(|(t, v)| pw.set_bool_target(*t, *v)); - [ - (&wires.sibling_value, self.sibling_value), - (&wires.sibling_min, self.sibling_min), - (&wires.sibling_max, self.sibling_max), - (&wires.min_query, self.min_query), - (&wires.max_query, self.max_query), - ] - .iter() - .for_each(|(t, v)| pw.set_u256_target(t, *v)); - pw.set_hash_target(wires.sibling_tree_hash, self.sibling_tree_hash); - wires - .sibling_child_hashes - .iter() - .zip(self.sibling_child_hashes) - .for_each(|(t, v)| pw.set_hash_target(*t, v)); - } -} - -/// Subtree proof number = 1, child proof number = 1 -pub(crate) const NUM_VERIFIED_PROOFS: usize = 2; - -impl CircuitLogicWires - for PartialNodeWires -where - [(); MAX_NUM_RESULTS - 1]:, -{ - type CircuitBuilderParams = (); - type Inputs = PartialNodeCircuit; - - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); - - fn circuit_logic( - builder: &mut CBuilder, - verified_proofs: [&ProofWithPublicInputsTarget; NUM_VERIFIED_PROOFS], - _builder_parameters: Self::CircuitBuilderParams, - ) -> Self { - // The first one is the subtree proof, and the second is the child proof. 
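`hash_maybe_first` conditionally swaps the first two hash inputs; out of circuit, the ordering it encodes is simply the following (a sketch using 32-byte arrays as stand-ins for Poseidon hashes), which matches the expected hash recomputed in the tests below.

    fn order_children(
        is_left_child: bool,
        proven_child_hash: [u8; 32],
        sibling_hash: [u8; 32],
    ) -> ([u8; 32], [u8; 32]) {
        if is_left_child {
            (proven_child_hash, sibling_hash) // proven child is the left child
        } else {
            (sibling_hash, proven_child_hash) // proven child is the right child
        }
    }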
- let [subtree_proof, child_proof] = - verified_proofs.map(|p| PublicInputs::from_slice(&p.public_inputs)); - - Self::Inputs::build(builder, &subtree_proof, &child_proof) - } - - fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { - inputs.assign(pw, self); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - query::{ - aggregation::{ - tests::compute_output_item_value, - utils::tests::{unify_child_proof, unify_subtree_proof}, - }, - pi_len, - }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, - }; - use mp2_common::{poseidon::H, utils::ToFields, C}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::gen_random_field_hash, - }; - use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; - use rand::{thread_rng, Rng}; - use std::array; - - const MAX_NUM_RESULTS: usize = 20; - - #[derive(Clone, Debug)] - struct TestPartialNodeCircuit<'a> { - c: PartialNodeCircuit, - subtree_proof: &'a [F], - child_proof: &'a [F], - } - - impl UserCircuit for TestPartialNodeCircuit<'_> { - // Circuit wires + query proof + child proof - type Wires = (PartialNodeWires, Vec, Vec); - - fn build(b: &mut CBuilder) -> Self::Wires { - let proofs = array::from_fn(|_| { - b.add_virtual_target_arr::<{ pi_len::() }>() - .to_vec() - }); - let [subtree_pi, child_pi] = - array::from_fn(|i| PublicInputs::::from_slice(&proofs[i])); - - let wires = PartialNodeCircuit::build(b, &subtree_pi, &child_pi); - - let [subtree_proof, child_proof] = proofs; - - (wires, subtree_proof, child_proof) - } - - fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0); - pw.set_target_arr(&wires.1, self.subtree_proof); - pw.set_target_arr(&wires.2, self.child_proof); - } - } - - fn test_partial_node_circuit(is_rows_tree_node: bool, is_left_child: bool) { - let min_query = U256::from(100); - let max_query = U256::from(200); - - let [sibling_tree_hash, sibling_child_hash1, sibling_child_hash2] = - array::from_fn(|_| gen_random_field_hash()); - - let mut rng = thread_rng(); - let sibling_value = U256::from_limbs(rng.gen()); - let [sibling_min, sibling_max] = if is_left_child { - // sibling_min > MAX_query - [max_query + U256::from(1), U256::from_limbs(rng.gen())] - } else { - [U256::from_limbs(rng.gen()), min_query - U256::from(1)] - }; - - // Generate the input proofs. - let ops: [_; MAX_NUM_RESULTS] = random_aggregation_operations(); - let [mut subtree_proof, mut child_proof] = random_aggregation_public_inputs(&ops); - unify_subtree_proof::( - &mut subtree_proof, - is_rows_tree_node, - min_query, - max_query, - ); - let subtree_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&subtree_proof); - unify_child_proof::( - &mut child_proof, - is_rows_tree_node, - min_query, - max_query, - &subtree_pi, - ); - let child_pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&child_proof); - - // Construct the expected public input values. - let index_ids = subtree_pi.index_ids(); - let index_value = subtree_pi.index_value(); - let node_value = if is_rows_tree_node { - subtree_pi.min_value() - } else { - index_value - }; - let [node_min, node_max] = if is_left_child { - [child_pi.min_value(), sibling_max] - } else { - [sibling_min, child_pi.max_value()] - }; - - // Construct the test circuit. 
- let sibling_child_hashes = [sibling_child_hash1, sibling_child_hash2]; - let test_circuit = TestPartialNodeCircuit { - c: PartialNodeCircuit { - is_rows_tree_node, - is_left_child, - sibling_tree_hash, - sibling_child_hashes, - sibling_value, - sibling_min, - sibling_max, - min_query, - max_query, - }, - subtree_proof: &subtree_proof, - child_proof: &child_proof, - }; - - // Prove for the test circuit. - let proof = run_circuit::(test_circuit); - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); - - // Check the public inputs. - // Tree hash - { - let column_id = if is_rows_tree_node { - index_ids[1] - } else { - index_ids[0] - }; - - // H(h1 || h2 || sibling_min || sibling_max || column_id || sibling_value || sibling_tree_hash) - let inputs: Vec<_> = sibling_child_hash1 - .to_fields() - .into_iter() - .chain(sibling_child_hash2.to_fields()) - .chain(sibling_min.to_fields()) - .chain(sibling_max.to_fields()) - .chain(iter::once(column_id)) - .chain(sibling_value.to_fields()) - .chain(sibling_tree_hash.to_fields()) - .collect(); - let sibling_hash = H::hash_no_pad(&inputs); - - let child_hash = child_pi.tree_hash(); - let [left_child_hash, right_child_hash] = if is_left_child { - [child_hash, sibling_hash] - } else { - [sibling_hash, child_hash] - }; - - // H(left_child_hash || right_child_hash || node_min || node_max || column_id || node_value || p.H) - let inputs: Vec<_> = left_child_hash - .to_fields() - .into_iter() - .chain(right_child_hash.to_fields()) - .chain(node_min.to_fields()) - .chain(node_max.to_fields()) - .chain(iter::once(column_id)) - .chain(node_value.to_fields()) - .chain(subtree_pi.tree_hash().to_fields()) - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - - assert_eq!(pi.tree_hash(), exp_hash); - } - // Output values and overflow flag - { - let mut num_overflows = 0; - let mut aggregated_values = vec![]; - - for i in 0..MAX_NUM_RESULTS { - let (mut output, overflow) = - compute_output_item_value(i, &[&subtree_pi, &child_pi]); - - aggregated_values.append(&mut output); - num_overflows += overflow; - } - - assert_eq!(pi.to_values_raw(), aggregated_values); - assert_eq!( - pi.overflow_flag(), - subtree_pi.overflow_flag() || child_pi.overflow_flag() || num_overflows != 0 - ); - } - // Count - assert_eq!( - pi.num_matching_rows(), - subtree_pi.num_matching_rows() + child_pi.num_matching_rows(), - ); - // Operation IDs - assert_eq!(pi.operation_ids(), subtree_pi.operation_ids()); - // Index value - assert_eq!(pi.index_value(), index_value); - // Minimum value - assert_eq!(pi.min_value(), node_min); - // Maximum value - assert_eq!(pi.max_value(), node_max); - // Index IDs - assert_eq!(pi.index_ids(), index_ids); - // Minimum query - assert_eq!(pi.min_query_value(), min_query); - // Maximum query - assert_eq!(pi.max_query_value(), max_query); - // Computational hash - assert_eq!(pi.computational_hash(), subtree_pi.computational_hash()); - // Placeholder hash - assert_eq!(pi.placeholder_hash(), subtree_pi.placeholder_hash()); - } - - #[test] - fn test_query_agg_partial_node_for_row_node_with_left_child() { - test_partial_node_circuit(true, true); - } - - #[test] - fn test_query_agg_partial_node_for_row_node_with_right_child() { - test_partial_node_circuit(true, false); - } - - #[test] - fn test_query_agg_partial_node_for_index_node_with_left_child() { - test_partial_node_circuit(false, true); - } - - #[test] - fn test_query_agg_partial_node_for_index_node_with_right_child() { - test_partial_node_circuit(false, false); - } -} diff --git 
a/verifiable-db/src/query/aggregation/utils.rs b/verifiable-db/src/query/aggregation/utils.rs deleted file mode 100644 index e9f1f4454..000000000 --- a/verifiable-db/src/query/aggregation/utils.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! Utility functions for query aggregation circuits - -use crate::query::public_inputs::PublicInputs; -use mp2_common::{ - array::Array, - types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target}, - F, -}; -use plonky2::{ - field::types::Field, - iop::target::{BoolTarget, Target}, -}; - -/// Check the consistency for the subtree proof and child proofs. -pub(crate) fn constrain_input_proofs( - b: &mut CBuilder, - is_rows_tree_node: BoolTarget, - min_query: &UInt256Target, - max_query: &UInt256Target, - subtree_proof: &PublicInputs, - child_proofs: &[PublicInputs], -) { - let ffalse = b._false(); - - let index_ids = subtree_proof.index_ids_target(); - let index_value = subtree_proof.index_value_target(); - - // Ensure the proofs in the same rows tree are employing the same value - // of the primary indexed column: - // is_rows_tree_node == is_rows_tree_node AND p.I == p1.I AND p.I == p2.I ... - let is_equals: Vec<_> = child_proofs - .iter() - .map(|p| b.is_equal_u256(&index_value, &p.index_value_target())) - .collect(); - let is_equal = is_equals - .into_iter() - .fold(is_rows_tree_node, |acc, is_equal| b.and(acc, is_equal)); - b.connect(is_equal.target, is_rows_tree_node.target); - - // Ensure the value of the indexed column for all the records stored in the - // rows tree found in this node is within the range specified by the query: - // NOT(is_rows_tree_node) == NOT(is_row_tree_node) AND p.I >= MIN_query AND p.I <= MAX_query - // And assume: is_out_of_range = p.I < MIN_query OR p.I > MAX_query - // => (1 - is_rows_tree_node) * is_out_of_range = 0 - // => is_out_of_range - is_out_of_range * is_rows_tree_node = 0 - let is_less_than_min = b.is_less_than_u256(&index_value, min_query); - let is_greater_than_max = b.is_less_than_u256(max_query, &index_value); - let is_out_of_range = b.or(is_less_than_min, is_greater_than_max); - let is_false = b.arithmetic( - F::NEG_ONE, - F::ONE, - is_rows_tree_node.target, - is_out_of_range.target, - is_out_of_range.target, - ); - b.connect(is_false, ffalse.target); - - // p.index_ids == p1.index_ids == p2.index_ids ... - let index_ids = Array::from(index_ids); - child_proofs - .iter() - .for_each(|p| index_ids.enforce_equal(b, &Array::from(p.index_ids_target()))); - - // p.C == p1.C == p2.C ... - let computational_hash = subtree_proof.computational_hash_target(); - child_proofs - .iter() - .for_each(|p| b.connect_hashes(computational_hash, p.computational_hash_target())); - - // p.H_p == p1.H_p == p2.H_p = ... - let placeholder_hash = subtree_proof.placeholder_hash_target(); - child_proofs - .iter() - .for_each(|p| b.connect_hashes(placeholder_hash, p.placeholder_hash_target())); - - // MIN_query = p1.MIN_I == p2.MIN_I ... - child_proofs - .iter() - .for_each(|p| b.enforce_equal_u256(min_query, &p.min_query_target())); - - // MAX_query = p1.MAX_I == p2.MAX_I ... 
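The `b.arithmetic` gate above encodes the implication "if this is an index-tree node, the index value lies within the query range" as a single product constraint; an equivalent plain-integer check (sketch):

    fn index_value_constraint_holds(
        is_rows_tree_node: bool,
        index_value: u128,
        min_query: u128,
        max_query: u128,
    ) -> bool {
        let is_out_of_range = (index_value < min_query || index_value > max_query) as u32;
        let is_rows = is_rows_tree_node as u32;
        // (1 - is_rows_tree_node) * is_out_of_range == 0
        (1 - is_rows) * is_out_of_range == 0
    }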
- child_proofs - .iter() - .for_each(|p| b.enforce_equal_u256(max_query, &p.max_query_target())); - - // if the subtree proof is generated for a rows tree node, - // the query bounds must be same: - // is_row_tree_node = is_row_tree_node AND MIN_query == p.MIN_I AND MAX_query == p.MAX_I - let is_min_query_equal = b.is_equal_u256(min_query, &subtree_proof.min_query_target()); - let is_max_query_equal = b.is_equal_u256(max_query, &subtree_proof.max_query_target()); - let is_equal = b.and(is_min_query_equal, is_max_query_equal); - let is_equal = b.and(is_equal, is_rows_tree_node); - b.connect(is_equal.target, is_rows_tree_node.target); -} - -#[cfg(test)] -pub(crate) mod tests { - use super::*; - use crate::query::public_inputs::QueryPublicInputs; - use alloy::primitives::U256; - use mp2_common::utils::ToFields; - - /// Assign the subtree proof to make it consistent. - pub(crate) fn unify_subtree_proof( - proof: &mut [F], - is_rows_tree_node: bool, - min_query: U256, - max_query: U256, - ) { - let [index_value_range, min_query_range, max_query_range] = [ - QueryPublicInputs::IndexValue, - QueryPublicInputs::MinQuery, - QueryPublicInputs::MaxQuery, - ] - .map(PublicInputs::::to_range); - - if is_rows_tree_node { - // p.MIN_I == MIN_query AND p.MAX_I == MAX_query - proof[min_query_range].copy_from_slice(&min_query.to_fields()); - proof[max_query_range].copy_from_slice(&max_query.to_fields()); - } else { - // p.I >= MIN_query AND p.I <= MAX_query - let index_value: U256 = (min_query + max_query) >> 1; - proof[index_value_range].copy_from_slice(&index_value.to_fields()); - } - } - - /// Assign the child proof to make it consistent. - pub(crate) fn unify_child_proof( - proof: &mut [F], - is_rows_tree_node: bool, - min_query: U256, - max_query: U256, - subtree_pi: &PublicInputs, - ) { - let [index_value_range, min_query_range, max_query_range] = [ - QueryPublicInputs::IndexValue, - QueryPublicInputs::MinQuery, - QueryPublicInputs::MaxQuery, - ] - .map(PublicInputs::::to_range); - - // child.MIN_I == MIN_query - // child.MAX_I == MAX_query - proof[min_query_range.clone()].copy_from_slice(&min_query.to_fields()); - proof[max_query_range.clone()].copy_from_slice(&max_query.to_fields()); - - if is_rows_tree_node { - // child.I == p.I - proof[index_value_range.clone()].copy_from_slice(subtree_pi.to_index_value_raw()); - } - } -} diff --git a/verifiable-db/src/query/api.rs b/verifiable-db/src/query/api.rs index 8b65f5297..58e902a9b 100644 --- a/verifiable-db/src/query/api.rs +++ b/verifiable-db/src/query/api.rs @@ -1,62 +1,15 @@ -use std::iter::repeat; +use std::iter::{repeat, repeat_with}; -use crate::query::aggregation::full_node_index_leaf::FullNodeIndexLeafCircuit; +use anyhow::{bail, ensure, Result}; -use super::{ - aggregation::{ - child_proven_single_path_node::{ - ChildProvenSinglePathNodeCircuit, ChildProvenSinglePathNodeWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_CHILD, - }, - embedded_tree_proven_single_path_node::{ - EmbeddedTreeProvenSinglePathNodeCircuit, EmbeddedTreeProvenSinglePathNodeWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_EMBEDDED, - }, - full_node_index_leaf::{FullNodeIndexLeafWires, NUM_VERIFIED_PROOFS as NUM_PROOFS_LEAF}, - full_node_with_one_child::{ - FullNodeWithOneChildCircuit, FullNodeWithOneChildWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_FN1, - }, - full_node_with_two_children::{ - FullNodeWithTwoChildrenCircuit, FullNodeWithTwoChildrenWires, - NUM_VERIFIED_PROOFS as NUM_PROOFS_FN2, - }, - non_existence_inter::{ - NonExistenceInterNodeCircuit, NonExistenceInterNodeWires, - 
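The midpoint trick used by `unify_subtree_proof` above guarantees that the chosen index value satisfies both query bounds whenever min_query <= max_query; isolated as a sketch (u128 stand-in for U256; like the U256 version, the addition can wrap for extreme inputs, which is fine for the small bounds used in these tests).

    fn midpoint_index_value(min_query: u128, max_query: u128) -> u128 {
        // (min + max) / 2 always lies in [min, max] for min <= max.
        (min_query + max_query) >> 1
    }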
NUM_VERIFIED_PROOFS as NUM_PROOFS_NE_INTER, - }, - partial_node::{ - PartialNodeCircuit, PartialNodeWires, NUM_VERIFIED_PROOFS as NUM_PROOFS_PN, - }, - ChildPosition, ChildProof, CommonInputs, NodeInfo, NonExistenceInput, - OneProvenChildNodeInput, QueryBounds, QueryHashNonExistenceCircuits, SinglePathInput, - SubProof, TwoProvenChildNodeInput, - }, - computational_hash_ids::{AggregationOperation, HashPermutation, Output}, - pi_len, - universal_circuit::{ - output_no_aggregation::Circuit as NoAggOutputCircuit, - output_with_aggregation::Circuit as AggOutputCircuit, - universal_circuit_inputs::{ - BasicOperation, PlaceholderId, Placeholders, ResultStructure, RowCells, - }, - universal_query_circuit::{ - placeholder_hash, QueryBound, UniversalCircuitInput, UniversalQueryCircuitInputs, - UniversalQueryCircuitWires, - }, - }, -}; -use alloy::primitives::U256; -use anyhow::{ensure, Result}; use itertools::Itertools; -use log::info; use mp2_common::{ array::ToField, default_config, - poseidon::H, - proof::ProofWithVK, + poseidon::{HashPermutation, H}, + proof::{serialize_proof, ProofWithVK}, types::HashOutput, - utils::{Fieldable, ToFields}, + utils::ToFields, C, D, F, }; use plonky2::{ @@ -65,20 +18,155 @@ use plonky2::{ }; use recursion_framework::{ circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, - framework::{ - prepare_recursive_circuit_for_circuit_set, RecursiveCircuitInfo, RecursiveCircuits, - }, + framework::{prepare_recursive_circuit_for_circuit_set, RecursiveCircuits}, }; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Serialize, Deserialize)] -#[allow(clippy::large_enum_variant)] // we need to clone data if we fix by put variants inside a `Box` +use crate::query::{ + circuits::{ + chunk_aggregation::{ + ChunkAggregationCircuit, ChunkAggregationInputs, ChunkAggregationWires, + }, + non_existence::{NonExistenceCircuit, NonExistenceWires}, + row_chunk_processing::{RowChunkProcessingCircuit, RowChunkProcessingWires}, + }, + computational_hash_ids::{AggregationOperation, ColumnIDs, Identifiers}, + row_chunk_gadgets::row_process_gadget::RowProcessingGadgetInputs, + universal_circuit::{ + output_no_aggregation::Circuit as OutputNoAggCircuit, + output_with_aggregation::Circuit as OutputAggCircuit, + universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure, RowCells}, + }, + utils::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits}, +}; + +use super::{ + computational_hash_ids::Output, + pi_len, + universal_circuit::{ + universal_circuit_inputs::PlaceholderId, + universal_query_circuit::{ + placeholder_hash, UniversalCircuitInput, UniversalQueryCircuitParams, + }, + }, +}; + +/// Data structure containing all the information needed to verify the membership of +/// a node in a tree and to compute info about its predecessor/successor +#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)] +pub struct TreePathInputs { + /// Info about the node + pub(crate) node_info: NodeInfo, + /// Info about the nodes in the path from the node up to the root of the tree; The `ChildPosition` refers to + /// the position of the previous node in the path as a child of the current node + pub(crate) path: Vec<(NodeInfo, ChildPosition)>, + /// Hash of the siblings of the nodes in path (except for the root) + pub(crate) siblings: Vec>, + /// Info about the children of the node + pub(crate) children: [Option; 2], +} + +impl TreePathInputs { + /// Instantiate a new instance of `TreePathInputs` for a given node from the 
+impl TreePathInputs {
+    /// Instantiate a new instance of `TreePathInputs` for a given node from the following input data:
+    /// - `node_info`: data about the given node
+    /// - `path`: data about the nodes in the path from the node up to the root of the tree;
+    ///   the `ChildPosition` refers to the position of the previous node in the path as a child of the current node
+    /// - `children`: data about the children of the given node
+    /// The hashes of the siblings of the nodes in the path are not taken as input: they are derived from `path`.
+    pub fn new(
+        node_info: NodeInfo,
+        path: Vec<(NodeInfo, ChildPosition)>,
+        children: [Option<NodeInfo>; 2],
+    ) -> Self {
+        let siblings = path
+            .iter()
+            .map(|(node, child_pos)| {
+                let sibling_index = match *child_pos {
+                    ChildPosition::Left => 1,
+                    ChildPosition::Right => 0,
+                };
+                Some(HashOutput::from(node.child_hashes[sibling_index]))
+            })
+            .collect_vec();
+        Self {
+            node_info,
+            path,
+            siblings,
+            children,
+        }
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
+/// Data structure containing the information about the paths in both the rows tree
+/// and the index tree for a node in a rows tree
+pub struct NodePath {
+    /// Info about the path of the node in the rows tree
+    pub(crate) row_tree_path: TreePathInputs,
+    /// Info about the node of the index tree storing the rows tree containing the row
+    pub(crate) index_tree_path: TreePathInputs,
+}
+
+impl NodePath {
+    /// Instantiate a new instance of `NodePath` for a given proven row from the following input data:
+    /// - `row_path`: path from the node to the root of the rows tree storing the node
+    /// - `index_path`: path from the index tree node storing the rows tree containing the node, up to the
+    ///   root of the index tree
+    pub fn new(row_path: TreePathInputs, index_path: TreePathInputs) -> Self {
+        Self {
+            row_tree_path: row_path,
+            index_tree_path: index_path,
+        }
+    }
+}
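As a concrete illustration, a minimal sketch of assembling a `NodePath` for a row stored in a leaf of a rows tree whose root is `row_root`, with that rows tree stored in the root of the index tree; `cells_hash`, `row_root`, `index_root` and `sec_value` are hypothetical placeholders, not part of this API:

// a rows tree leaf: no children, min = max = its own secondary index value
let row_leaf = NodeInfo::new(&cells_hash, None, None, sec_value, sec_value, sec_value);
let row_path = TreePathInputs::new(
    row_leaf,
    vec![(row_root, ChildPosition::Left)], // the leaf is the left child of `row_root`
    [None, None],                          // a leaf has no children
);
// the index tree node is the root itself, so its path to the root is empty
let index_path = TreePathInputs::new(index_root, vec![], [None, None]);
let path = NodePath::new(row_path, index_path);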
+#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
+/// Data structure containing the inputs necessary to prove a query for a row
+/// of the DB table.
+pub struct RowInput {
+    pub(crate) cells: RowCells,
+    pub(crate) path: NodePath,
+}
+
+impl RowInput {
+    /// Initialize `RowInput` from the set of cells of the given row and from the path,
+    /// in both the rows tree and the index tree, of the rows tree node associated with the given row
+    pub fn new(cells: &RowCells, path: &NodePath) -> Self {
+        Self {
+            cells: cells.clone(),
+            path: path.clone(),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+#[allow(clippy::large_enum_variant)]
 pub enum CircuitInput<
+    const NUM_CHUNKS: usize,
+    const NUM_ROWS: usize,
+    const ROW_TREE_MAX_DEPTH: usize,
+    const INDEX_TREE_MAX_DEPTH: usize,
     const MAX_NUM_COLUMNS: usize,
     const MAX_NUM_PREDICATE_OPS: usize,
     const MAX_NUM_RESULT_OPS: usize,
     const MAX_NUM_RESULTS: usize,
-> {
+> where
+    [(); ROW_TREE_MAX_DEPTH - 1]:,
+    [(); INDEX_TREE_MAX_DEPTH - 1]:,
+{
+    RowChunkWithAggregation(
+        RowChunkProcessingCircuit<
+            NUM_ROWS,
+            ROW_TREE_MAX_DEPTH,
+            INDEX_TREE_MAX_DEPTH,
+            MAX_NUM_COLUMNS,
+            MAX_NUM_PREDICATE_OPS,
+            MAX_NUM_RESULT_OPS,
+            MAX_NUM_RESULTS,
+            OutputAggCircuit<MAX_NUM_RESULTS>,
+        >,
+    ),
+    ChunkAggregation(ChunkAggregationInputs<NUM_CHUNKS, MAX_NUM_RESULTS>),
+    NonExistence(NonExistenceCircuit<INDEX_TREE_MAX_DEPTH, MAX_NUM_RESULTS>),
     /// Inputs for the universal query circuit
     UniversalCircuit(
         UniversalCircuitInput<
@@ -88,27 +176,160 @@ pub enum CircuitInput<
             MAX_NUM_RESULTS,
         >,
     ),
-    /// Inputs for circuits with 2 proven children and a proven embedded tree
-    TwoProvenChildNode(TwoProvenChildNodeInput),
-    /// Inputs for circuits proving a node with one proven child and a proven embedded tree
-    OneProvenChildNode(OneProvenChildNodeInput),
-    /// Inputs for circuits proving a node with only one proven subtree (either a proven child or the embedded tree)
-    SinglePath(SinglePathInput),
-    /// Inputs for circuits to prove non-existence of results for the current query
-    NonExistence(NonExistenceInput),
 }
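For reference, the `RowCells` fed to `RowInput::new` above can be built from `ColumnCell`s, mirroring how the removed tests did it; the ids, values and `rest_cells` below are hypothetical, and `ColumnCell` lives in the same `universal_circuit_inputs` module as `RowCells`:

let cells = RowCells::new(
    ColumnCell::new(primary_id, primary_value),     // primary index column
    ColumnCell::new(secondary_id, secondary_value), // secondary index column
    rest_cells,                                     // remaining column cells (Vec<ColumnCell>)
);
let row = RowInput::new(&cells, &path);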
 impl<
+        const NUM_CHUNKS: usize,
+        const NUM_ROWS: usize,
+        const ROW_TREE_MAX_DEPTH: usize,
+        const INDEX_TREE_MAX_DEPTH: usize,
         const MAX_NUM_COLUMNS: usize,
         const MAX_NUM_PREDICATE_OPS: usize,
         const MAX_NUM_RESULT_OPS: usize,
         const MAX_NUM_RESULTS: usize,
-    > CircuitInput
+    >
+    CircuitInput<
+        NUM_CHUNKS,
+        NUM_ROWS,
+        ROW_TREE_MAX_DEPTH,
+        INDEX_TREE_MAX_DEPTH,
+        MAX_NUM_COLUMNS,
+        MAX_NUM_PREDICATE_OPS,
+        MAX_NUM_RESULT_OPS,
+        MAX_NUM_RESULTS,
+    > where
+    [(); ROW_TREE_MAX_DEPTH - 1]:,
+    [(); INDEX_TREE_MAX_DEPTH - 1]:,
     [(); MAX_NUM_RESULTS - 1]:,
     [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:,
     [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:,
 {
+    /// Construct the input necessary to prove a query over a chunk of rows provided as input.
+    /// At least 1 row must be provided; if there are no rows to be proven,
+    /// `Self::new_non_existence_input` should be used instead
+    pub fn new_row_chunks_input(
+        rows: &[RowInput],
+        predicate_operations: &[BasicOperation],
+        placeholders: &Placeholders,
+        query_bounds: &QueryBounds,
+        results: &ResultStructure,
+    ) -> Result<Self> {
+        ensure!(
+            !rows.is_empty(),
+            "there must be at least 1 row to be proven"
+        );
+        ensure!(
+            rows.len() <= NUM_ROWS,
+            format!(
+                "Found {} rows provided as input, maximum allowed is {NUM_ROWS}",
+                rows.len()
+            )
+        );
+        let column_ids = &rows[0].cells.column_ids();
+        ensure!(
+            rows.iter()
+                .all(|row| row.cells.column_ids().to_vec() == column_ids.to_vec()),
+            "Rows provided as input don't have the same column ids",
+        );
+        let row_inputs = rows
+            .iter()
+            .map(RowProcessingGadgetInputs::try_from)
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Self::RowChunkWithAggregation(
+            RowChunkProcessingCircuit::new(
+                row_inputs,
+                column_ids,
+                predicate_operations,
+                placeholders,
+                query_bounds,
+                results,
+            )?,
+        ))
+    }
+
+    /// Construct the input necessary to aggregate 2 or more row chunks already proven.
+    /// It requires at least 2 chunks to be aggregated
+    pub fn new_chunk_aggregation_input(chunks_proofs: &[Vec<u8>]) -> Result<Self> {
+        ensure!(
+            chunks_proofs.len() >= 2,
+            "At least 2 chunk proofs must be provided"
+        );
+        let num_proofs = chunks_proofs.len();
+        ensure!(
+            num_proofs <= NUM_CHUNKS,
+            format!("Found {num_proofs} proofs provided as input, maximum allowed is {NUM_CHUNKS}")
+        );
+        // deserialize `chunks_proofs` and pad to NUM_CHUNKS proofs by replicating the last proof in `chunks_proofs`
+        let last_proof = chunks_proofs.last().unwrap();
+        let proofs = chunks_proofs
+            .iter()
+            .map(|p| ProofWithVK::deserialize(p))
+            .chain(repeat_with(|| ProofWithVK::deserialize(last_proof)))
+            .take(NUM_CHUNKS)
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Self::ChunkAggregation(ChunkAggregationInputs {
+            chunk_proofs: proofs.try_into().unwrap(),
+            circuit: ChunkAggregationCircuit {
+                num_non_dummy_chunks: num_proofs,
+            },
+        }))
+    }
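Putting the two constructors above together, the intended proving flow for a set of matching rows would look roughly as follows; the chunking policy and the `params` handle (an instance of the `Parameters` type defined later in this file) are assumptions of this sketch, which is meant to run inside a function returning anyhow::Result:

// prove each chunk of up to NUM_ROWS rows independently...
let chunk_proofs = rows
    .chunks(NUM_ROWS)
    .map(|chunk| {
        let input = CircuitInput::new_row_chunks_input(
            chunk,
            &predicate_operations,
            &placeholders,
            &query_bounds,
            &results,
        )?;
        params.generate_proof(input)
    })
    .collect::<Result<Vec<_>>>()?;
// ...then aggregate the chunk proofs (the builder requires at least 2 of them)
let final_proof = if chunk_proofs.len() >= 2 {
    params.generate_proof(CircuitInput::new_chunk_aggregation_input(&chunk_proofs)?)?
} else {
    chunk_proofs.into_iter().next().unwrap()
};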
+
+    /// Construct the input to prove a query in case there are no rows with a primary index value
+    /// in the primary query range. The circuit employed to prove the non-existence of such a row
+    /// requires a specific node of the index tree to be provided, as described in the docs:
+    /// https://www.notion.so/lagrangelabs/Batching-Query-10628d1c65a880b1b151d4ac017fa445?pvs=4#10e28d1c65a880498f41cd1cad0c61c3
+    pub fn new_non_existence_input(
+        index_node_path: TreePathInputs,
+        column_ids: &ColumnIDs,
+        predicate_operations: &[BasicOperation],
+        results: &ResultStructure,
+        placeholders: &Placeholders,
+        query_bounds: &QueryBounds,
+    ) -> Result<Self> {
+        let QueryHashNonExistenceCircuits {
+            computational_hash,
+            placeholder_hash,
+        } = QueryHashNonExistenceCircuits::new::<
+            MAX_NUM_COLUMNS,
+            MAX_NUM_PREDICATE_OPS,
+            MAX_NUM_RESULT_OPS,
+            MAX_NUM_RESULTS,
+        >(
+            column_ids,
+            predicate_operations,
+            results,
+            placeholders,
+            query_bounds,
+            false,
+        )?;
+
+        let aggregation_operations = results
+            .aggregation_operations()
+            .into_iter()
+            .chain(repeat(
+                Identifiers::AggregationOperations(AggregationOperation::default()).to_field(),
+            ))
+            .take(MAX_NUM_RESULTS)
+            .collect_vec()
+            .try_into()
+            .unwrap();
+
+        Ok(Self::NonExistence(NonExistenceCircuit::new(
+            &index_node_path,
+            column_ids.primary,
+            aggregation_operations,
+            computational_hash,
+            placeholder_hash,
+            query_bounds,
+        )?))
+    }
+
     pub const fn num_placeholders_ids() -> usize {
         2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)
     }
@@ -135,14 +356,9 @@ where
     ) -> Result<Self> {
         Ok(CircuitInput::UniversalCircuit(
             match results.output_variant {
-                Output::Aggregation => UniversalCircuitInput::new_query_with_agg(
-                    column_cells,
-                    predicate_operations,
-                    placeholders,
-                    is_leaf,
-                    query_bounds,
-                    results,
-                )?,
+                Output::Aggregation => bail!(
+                    "Universal query circuit should only be used for queries with no aggregation"
+                ),
                 Output::NoAggregation => UniversalCircuitInput::new_query_no_agg(
                     column_cells,
                     predicate_operations,
@@ -155,145 +371,6 @@ where
         ))
     }

-    /// Initialize input to prove a full node from the following inputs:
-    /// - `left_child_proof`: proof for the left child of the node being proven
-    /// - `right_child_proof`: proof for the right child of the node being proven
-    /// - `embedded_tree_proof`: proof for the embedded tree stored in the full node: can be either the proof for a single
-    ///   row (if proving a rows tree node) of the proof for the root node of a rows tree (if proving an index tree node)
-    /// - `is_rows_tree_node`: flag specifying whether the full node belongs to the rows tree or to the index tree
-    /// - `query_bounds`: bounds on primary and secondary indexes specified in the query
-    pub fn new_full_node(
-        left_child_proof: Vec<u8>,
-        right_child_proof: Vec<u8>,
-        embedded_tree_proof: Vec<u8>,
-        is_rows_tree_node: bool,
-        query_bounds: &QueryBounds,
-    ) -> Result<Self> {
-        Ok(CircuitInput::TwoProvenChildNode(TwoProvenChildNodeInput {
-            left_child_proof: ProofWithVK::deserialize(&left_child_proof)?,
-            right_child_proof: ProofWithVK::deserialize(&right_child_proof)?,
-            embedded_tree_proof: ProofWithVK::deserialize(&embedded_tree_proof)?,
-            common: CommonInputs::new(is_rows_tree_node, query_bounds),
-        }))
-    }
-
-    /// Initialize input to prove a partial node from the following inputs:
-    /// - `proven_child_proof`: Proof for the child being a proven node
-    /// - `embedded_tree_proof`: Proof for the embedded tree stored in the partial node: can be either the proof
-    ///   for a single row (if proving a rows tree node) of the proof for the root node of a rows
-    ///   tree (if proving an index tree node)
-    /// - `unproven_child`: Data about the child not being a proven node;
if the node has only one child, - /// then, this parameter must be `None` - /// - `proven_child_position`: Enum specifying whether the proven child is the left or right child - /// of the partial node being proven - /// - `is_rows_tree_node`: flag specifying whether the full node belongs to the rows tree or to the index tree - /// - `query_bounds`: bounds on primary and secondary indexes specified in the query - pub fn new_partial_node( - proven_child_proof: Vec, - embedded_tree_proof: Vec, - unproven_child: Option, - proven_child_position: ChildPosition, - is_rows_tree_node: bool, - query_bounds: &QueryBounds, - ) -> Result { - Ok(CircuitInput::OneProvenChildNode(OneProvenChildNodeInput { - unproven_child, - proven_child_proof: ChildProof { - proof: ProofWithVK::deserialize(&proven_child_proof)?, - child_position: proven_child_position, - }, - embedded_tree_proof: ProofWithVK::deserialize(&embedded_tree_proof)?, - common: CommonInputs::new(is_rows_tree_node, query_bounds), - })) - } - /// Initialize input to prove a single path node from the following inputs: - /// - `subtree_proof`: Proof of either a child node or of the embedded tree stored in the current node - /// - `left_child`: Data about the left child of the current node, if any; must be `None` if the node has - /// no left child - /// - `right_child`: Data about the right child of the current node, if any; must be `None` if the node has - /// no right child - /// - `node_info`: Data about the current node being proven - /// - `is_rows_tree_node`: flag specifying whether the full node belongs to the rows tree or to the index tree - /// - `query_bounds`: bounds on primary and secondary indexes specified in the query - pub fn new_single_path( - subtree_proof: SubProof, - left_child: Option, - right_child: Option, - node_info: NodeInfo, - is_rows_tree_node: bool, - query_bounds: &QueryBounds, - ) -> Result { - Ok(CircuitInput::SinglePath(SinglePathInput { - left_child, - right_child, - node_info, - subtree_proof, - common: CommonInputs::new(is_rows_tree_node, query_bounds), - })) - } - /// Initialize input to prove a node storing a value of the primary or secondary index which - /// is outside of the query bounds, from the following inputs: - /// - `node_info`: Data about the node being proven - /// - `left_child_info`: Data aboout the left child of the node being proven; must be `None` if - /// the node being proven has no left child - /// - `right_child_info`: Data aboout the right child of the node being proven; must be `None` if - /// the node being proven has no right child - /// - `primary_index_value`: Value of the primary index associated to the current node - /// - `index_ids`: Identifiers of the primary and secondary index columns - /// - `aggregation_ops`: Set of aggregation operations employed to aggregate the results of the query - /// - `query_hashes`: Computational hash and placeholder hash associated to the query; can be computed with the `new` - /// method of `QueryHashNonExistenceCircuits` data structure - /// - `is_rows_tree_node`: flag specifying whether the full node belongs to the rows tree or to the index tree - /// - `query_bounds`: bounds on primary and secondary indexes specified in the query - #[allow(clippy::too_many_arguments)] // doesn't make sense to aggregate arguments - pub fn new_non_existence_input( - node_info: NodeInfo, - left_child_info: Option, - right_child_info: Option, - primary_index_value: U256, - index_ids: &[u64; 2], - aggregation_ops: &[AggregationOperation], - query_hashes: 
QueryHashNonExistenceCircuits, - is_rows_tree_node: bool, - query_bounds: &QueryBounds, - placeholders: &Placeholders, - ) -> Result { - let aggregation_ops = aggregation_ops - .iter() - .map(|op| op.to_field()) - .chain(repeat(AggregationOperation::default().to_field())) - .take(MAX_NUM_RESULTS) - .collect_vec(); - let min_query = if is_rows_tree_node { - QueryBound::new_secondary_index_bound(placeholders, query_bounds.min_query_secondary()) - } else { - QueryBound::new_primary_index_bound(placeholders, true) - }?; - let max_query = if is_rows_tree_node { - QueryBound::new_secondary_index_bound(placeholders, query_bounds.max_query_secondary()) - } else { - QueryBound::new_primary_index_bound(placeholders, false) - }?; - Ok(CircuitInput::NonExistence(NonExistenceInput { - node_info, - left_child_info, - right_child_info, - primary_index_value, - index_ids: index_ids - .iter() - .map(|id| id.to_field()) - .collect_vec() - .try_into() - .unwrap(), - computational_hash: query_hashes.computational_hash(), - placeholder_hash: query_hashes.placeholder_hash(), - aggregation_ops: aggregation_ops.try_into().unwrap(), - is_rows_tree_node, - min_query, - max_query, - })) - } - /// This method returns the ids of the placeholders employed to compute the placeholder hash, /// in the same order, so that those ids can be provided as input to other circuits that need /// to recompute this hash @@ -303,45 +380,14 @@ where placeholders: &Placeholders, query_bounds: &QueryBounds, ) -> Result<[PlaceholderId; 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]> { - let row_cells = &RowCells::default(); - Ok(match results.output_variant { - Output::Aggregation => { - let circuit = UniversalQueryCircuitInputs::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - AggOutputCircuit, - >::new( - row_cells, - predicate_operations, - placeholders, - false, // doesn't matter for placeholder hash computation - query_bounds, - results, - )?; - circuit.ids_for_placeholder_hash() - } - Output::NoAggregation => { - let circuit = UniversalQueryCircuitInputs::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - NoAggOutputCircuit, - >::new( - row_cells, - predicate_operations, - placeholders, - false, // doesn't matter for placeholder hash computation - query_bounds, - results, - )?; - circuit.ids_for_placeholder_hash() - } - } - .try_into() - .unwrap()) + UniversalCircuitInput::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >::ids_for_placeholder_hash( + predicate_operations, results, placeholders, query_bounds + ) } /// Compute the `placeholder_hash` associated to a query @@ -372,1324 +418,218 @@ where ) } } -#[derive(Serialize, Deserialize)] -pub struct Parameters< + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Parameters< + const NUM_CHUNKS: usize, + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, const MAX_NUM_COLUMNS: usize, const MAX_NUM_PREDICATE_OPS: usize, const MAX_NUM_RESULT_OPS: usize, const MAX_NUM_RESULTS: usize, > where - [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_RESULTS - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, { - circuit_with_agg: CircuitWithUniversalVerifier< + row_chunk_agg_circuit: CircuitWithUniversalVerifier< F, C, D, 0, - UniversalQueryCircuitWires< + RowChunkProcessingWires< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, 
             MAX_NUM_COLUMNS,
             MAX_NUM_PREDICATE_OPS,
             MAX_NUM_RESULT_OPS,
             MAX_NUM_RESULTS,
-            AggOutputCircuit,
+            OutputAggCircuit<MAX_NUM_RESULTS>,
         >,
     >,
-    circuit_no_agg: CircuitWithUniversalVerifier<
+    // TODO: add row_chunk_circuit for queries without aggregation, once we integrate results tree
+    aggregation_circuit: CircuitWithUniversalVerifier<
         F,
         C,
         D,
-        0,
-        UniversalQueryCircuitWires<
-            MAX_NUM_COLUMNS,
-            MAX_NUM_PREDICATE_OPS,
-            MAX_NUM_RESULT_OPS,
-            MAX_NUM_RESULTS,
-            NoAggOutputCircuit,
-        >,
-    >,
-    full_node_two_children: CircuitWithUniversalVerifier<
-        F,
-        C,
-        D,
-        NUM_PROOFS_FN2,
-        FullNodeWithTwoChildrenWires,
-    >,
-    full_node_one_child: CircuitWithUniversalVerifier<
-        F,
-        C,
-        D,
-        NUM_PROOFS_FN1,
-        FullNodeWithOneChildWires,
-    >,
-    full_node_leaf: CircuitWithUniversalVerifier<
-        F,
-        C,
-        D,
-        NUM_PROOFS_LEAF,
-        FullNodeIndexLeafWires,
+        NUM_CHUNKS,
+        ChunkAggregationWires<NUM_CHUNKS, MAX_NUM_RESULTS>,
     >,
-    partial_node:
-        CircuitWithUniversalVerifier<F, C, D, NUM_PROOFS_PN, PartialNodeWires<MAX_NUM_RESULTS>>,
-    single_path_proven_child: CircuitWithUniversalVerifier<
+    non_existence_circuit: CircuitWithUniversalVerifier<
         F,
         C,
         D,
-        NUM_PROOFS_CHILD,
-        ChildProvenSinglePathNodeWires,
-    >,
-    single_path_embedded_tree: CircuitWithUniversalVerifier<
-        F,
-        C,
-        D,
-        NUM_PROOFS_EMBEDDED,
-        EmbeddedTreeProvenSinglePathNodeWires,
+        0,
+        NonExistenceWires<INDEX_TREE_MAX_DEPTH, MAX_NUM_RESULTS>,
     >,
-    non_existence_intermediate: CircuitWithUniversalVerifier<
-        F,
-        C,
-        D,
-        NUM_PROOFS_NE_INTER,
-        NonExistenceInterNodeWires,
+    universal_circuit: UniversalQueryCircuitParams<
+        MAX_NUM_COLUMNS,
+        MAX_NUM_PREDICATE_OPS,
+        MAX_NUM_RESULT_OPS,
+        MAX_NUM_RESULTS,
+        OutputNoAggCircuit<MAX_NUM_RESULTS>,
     >,
     circuit_set: RecursiveCircuits<F, C, D>,
 }

-const QUERY_CIRCUIT_SET_SIZE: usize = 10;
 impl<
+        const NUM_CHUNKS: usize,
+        const NUM_ROWS: usize,
+        const ROW_TREE_MAX_DEPTH: usize,
+        const INDEX_TREE_MAX_DEPTH: usize,
         const MAX_NUM_COLUMNS: usize,
         const MAX_NUM_PREDICATE_OPS: usize,
         const MAX_NUM_RESULT_OPS: usize,
         const MAX_NUM_RESULTS: usize,
-    > Parameters
+    >
+    Parameters<
+        NUM_CHUNKS,
+        NUM_ROWS,
+        ROW_TREE_MAX_DEPTH,
+        INDEX_TREE_MAX_DEPTH,
+        MAX_NUM_COLUMNS,
+        MAX_NUM_PREDICATE_OPS,
+        MAX_NUM_RESULT_OPS,
+        MAX_NUM_RESULTS,
+    >
 where
-    [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:,
+    [(); ROW_TREE_MAX_DEPTH - 1]:,
+    [(); INDEX_TREE_MAX_DEPTH - 1]:,
     [(); MAX_NUM_RESULTS - 1]:,
-    [(); pi_len::<MAX_NUM_RESULTS>()]:,
+    [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:,
     [(); <H as Hasher<F>>::HASH_SIZE]:,
+    [(); pi_len::<MAX_NUM_RESULTS>()]:,
 {
-    /// Build `Parameters` for query circuits
-    pub fn build() -> Self {
+    const CIRCUIT_SET_SIZE: usize = 3;
+
+    pub(crate) fn build() -> Self {
         let builder = CircuitWithUniversalVerifierBuilder::<F, D, { pi_len::<MAX_NUM_RESULTS>() }>::new::<C>(
             default_config(),
-            QUERY_CIRCUIT_SET_SIZE,
+            Self::CIRCUIT_SET_SIZE,
         );
-        info!("Building the query circuits parameters...");
-        info!("Building universal circuits...");
-        let circuit_with_agg = builder.build_circuit(());
-        let circuit_no_agg = builder.build_circuit(());
-        info!("Building aggregation circuits..");
-        let full_node_two_children = builder.build_circuit(());
-        let full_node_one_child = builder.build_circuit(());
-        let full_node_leaf = builder.build_circuit(());
-        let partial_node = builder.build_circuit(());
-        let single_path_proven_child = builder.build_circuit(());
-        let single_path_embedded_tree = builder.build_circuit(());
-        info!("Building non-existence circuits..");
-        let non_existence_intermediate = builder.build_circuit(());
+        let row_chunk_agg_circuit = builder.build_circuit(());
+        let aggregation_circuit = builder.build_circuit(());
+        let non_existence_circuit = builder.build_circuit(());
         let circuits = vec![
-
prepare_recursive_circuit_for_circuit_set(&circuit_with_agg), - prepare_recursive_circuit_for_circuit_set(&circuit_no_agg), - prepare_recursive_circuit_for_circuit_set(&full_node_two_children), - prepare_recursive_circuit_for_circuit_set(&full_node_one_child), - prepare_recursive_circuit_for_circuit_set(&full_node_leaf), - prepare_recursive_circuit_for_circuit_set(&partial_node), - prepare_recursive_circuit_for_circuit_set(&single_path_proven_child), - prepare_recursive_circuit_for_circuit_set(&single_path_embedded_tree), - prepare_recursive_circuit_for_circuit_set(&non_existence_intermediate), + prepare_recursive_circuit_for_circuit_set(&row_chunk_agg_circuit), + prepare_recursive_circuit_for_circuit_set(&aggregation_circuit), + prepare_recursive_circuit_for_circuit_set(&non_existence_circuit), ]; - let circuit_set = RecursiveCircuits::new(circuits); + let universal_circuit = UniversalQueryCircuitParams::build(default_config()); + Self { - circuit_with_agg, - circuit_no_agg, + row_chunk_agg_circuit, + aggregation_circuit, + non_existence_circuit, + universal_circuit, circuit_set, - full_node_two_children, - full_node_one_child, - full_node_leaf, - partial_node, - single_path_proven_child, - single_path_embedded_tree, - non_existence_intermediate, } } - pub fn generate_proof( + pub(crate) fn generate_proof( &self, input: CircuitInput< + NUM_CHUNKS, + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, >, ) -> Result> { - let proof = ProofWithVK::from(match input { - CircuitInput::UniversalCircuit(input) => match input { - UniversalCircuitInput::QueryWithAgg(input) => ( - self.circuit_set - .generate_proof(&self.circuit_with_agg, [], [], input)?, - self.circuit_with_agg.circuit_data().verifier_only.clone(), - ), - UniversalCircuitInput::QueryNoAgg(input) => ( - self.circuit_set - .generate_proof(&self.circuit_no_agg, [], [], input)?, - self.circuit_no_agg.circuit_data().verifier_only.clone(), - ), - }, - CircuitInput::TwoProvenChildNode(TwoProvenChildNodeInput { - left_child_proof, - right_child_proof, - embedded_tree_proof, - common, - }) => { - let (left_proof, left_vk) = left_child_proof.into(); - let (right_proof, right_vk) = right_child_proof.into(); - let (embedded_proof, embedded_vk) = embedded_tree_proof.into(); - let input = FullNodeWithTwoChildrenCircuit { - is_rows_tree_node: common.is_rows_tree_node, - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.full_node_two_children, - [embedded_proof, left_proof, right_proof], - [&embedded_vk, &left_vk, &right_vk], - input, - )?, - self.full_node_two_children - .circuit_data() - .verifier_only - .clone(), + match input { + CircuitInput::RowChunkWithAggregation(row_chunk_processing_circuit) => { + ProofWithVK::serialize( + &( + self.circuit_set.generate_proof( + &self.row_chunk_agg_circuit, + [], + [], + row_chunk_processing_circuit, + )?, + self.row_chunk_agg_circuit + .circuit_data() + .verifier_only + .clone(), + ) + .into(), ) } - CircuitInput::OneProvenChildNode(OneProvenChildNodeInput { - unproven_child, - proven_child_proof, - embedded_tree_proof, - common, - }) => { - let ChildProof { - proof, - child_position, - } = proven_child_proof; - let (child_proof, child_vk) = proof.into(); - let (embedded_proof, embedded_vk) = embedded_tree_proof.into(); - match unproven_child { - Some(child_node) => { - // the node has 2 children, so we use the partial node circuit - let input = 
PartialNodeCircuit { - is_rows_tree_node: common.is_rows_tree_node, - is_left_child: child_position.to_flag(), - sibling_tree_hash: child_node.embedded_tree_hash, - sibling_child_hashes: child_node.child_hashes, - sibling_value: child_node.value, - sibling_min: child_node.min, - sibling_max: child_node.max, - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.partial_node, - [embedded_proof, child_proof], - [&embedded_vk, &child_vk], - input, - )?, - self.partial_node.get_verifier_data().clone(), - ) - } - None => { - // the node has 1 child, so use the circuit for full node with 1 child - let input = FullNodeWithOneChildCircuit { - is_rows_tree_node: common.is_rows_tree_node, - is_left_child: child_position.to_flag(), - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.full_node_one_child, - [embedded_proof, child_proof], - [&embedded_vk, &child_vk], - input, - )?, - self.full_node_one_child.get_verifier_data().clone(), - ) - } - } - } - CircuitInput::SinglePath(SinglePathInput { - left_child, - right_child, - node_info, - subtree_proof, - common, - }) => { - let left_child_exists = left_child.is_some(); - let right_child_exists = right_child.is_some(); - let left_child_data = left_child.unwrap_or_default(); - let right_child_data = right_child.unwrap_or_default(); - - match subtree_proof { - SubProof::Embedded(input_proof) => { - let (proof, vk) = input_proof.into(); - if !(left_child_exists || right_child_exists) { - // leaf node, so call full node circuit for leaf node - ensure!(!common.is_rows_tree_node, "providing single-path input for a rows tree node leaf, call universal circuit instead"); - let input = FullNodeIndexLeafCircuit { - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.full_node_leaf, - [proof], - [&vk], - input, - )?, - self.full_node_leaf.get_verifier_data().clone(), - ) - } else { - // the input proof refers to the embedded tree stored in the node - let input = EmbeddedTreeProvenSinglePathNodeCircuit { - left_child_min: left_child_data.min, - left_child_max: left_child_data.max, - left_child_value: left_child_data.value, - left_tree_hash: left_child_data.embedded_tree_hash, - left_grand_children: left_child_data.child_hashes, - right_child_min: right_child_data.min, - right_child_max: right_child_data.max, - right_child_value: right_child_data.value, - right_tree_hash: right_child_data.embedded_tree_hash, - right_grand_children: right_child_data.child_hashes, - left_child_exists, - right_child_exists, - is_rows_tree_node: common.is_rows_tree_node, - min_query: common.min_query, - max_query: common.max_query, - }; - ( - self.circuit_set.generate_proof( - &self.single_path_embedded_tree, - [proof], - [&vk], - input, - )?, - self.single_path_embedded_tree.get_verifier_data().clone(), - ) - } - } - SubProof::Child(ChildProof { - proof, - child_position, - }) => { - // the input proof refers to a child of the node - let (proof, vk) = proof.into(); - let is_left_child = child_position.to_flag(); - let input = ChildProvenSinglePathNodeCircuit { - value: node_info.value, - subtree_hash: node_info.embedded_tree_hash, - sibling_hash: if is_left_child { - node_info.child_hashes[1] // set the hash of the right child, since proven child is left - } else { - node_info.child_hashes[0] // set the hash of the left child, since proven child is right - }, - is_left_child, - unproven_min: node_info.min, - 
unproven_max: node_info.max, - is_rows_tree_node: common.is_rows_tree_node, - }; - ( - self.circuit_set.generate_proof( - &self.single_path_proven_child, - [proof], - [&vk], - input, - )?, - self.single_path_proven_child.get_verifier_data().clone(), - ) - } - } + CircuitInput::ChunkAggregation(chunk_aggregation_inputs) => { + let ChunkAggregationInputs { + chunk_proofs, + circuit, + } = chunk_aggregation_inputs; + let input_vd = chunk_proofs + .iter() + .map(|p| p.verifier_data()) + .cloned() + .collect_vec(); + let input_proofs = chunk_proofs.map(|p| p.proof); + ProofWithVK::serialize( + &( + self.circuit_set.generate_proof( + &self.aggregation_circuit, + input_proofs, + input_vd.iter().collect_vec().try_into().unwrap(), + circuit, + )?, + self.aggregation_circuit + .circuit_data() + .verifier_only + .clone(), + ) + .into(), + ) } - CircuitInput::NonExistence(NonExistenceInput { - node_info, - left_child_info, - right_child_info, - primary_index_value, - index_ids, - computational_hash, - placeholder_hash, - aggregation_ops, - is_rows_tree_node, - min_query, - max_query, - }) => { - // intermediate node - let left_child_exists = left_child_info.is_some(); - let right_child_exists = right_child_info.is_some(); - let left_child_data = left_child_info.unwrap_or_default(); - let right_child_data = right_child_info.unwrap_or_default(); - let input = NonExistenceInterNodeCircuit { - is_rows_tree_node, - left_child_exists, - right_child_exists, - min_query, - max_query, - value: node_info.value, - index_value: primary_index_value, - left_child_value: left_child_data.value, - left_child_min: left_child_data.min, - left_child_max: left_child_data.max, - right_child_value: right_child_data.value, - right_child_min: right_child_data.min, - right_child_max: right_child_data.max, - index_ids, - ops: aggregation_ops, - subtree_hash: node_info.embedded_tree_hash, - computational_hash, - placeholder_hash, - left_tree_hash: left_child_data.embedded_tree_hash, - left_grand_children: left_child_data.child_hashes, - right_tree_hash: right_child_data.embedded_tree_hash, - right_grand_children: right_child_data.child_hashes, - }; - ( + CircuitInput::NonExistence(non_existence_circuit) => ProofWithVK::serialize( + &( self.circuit_set.generate_proof( - &self.non_existence_intermediate, + &self.non_existence_circuit, [], [], - input, + non_existence_circuit, )?, - self.non_existence_intermediate.get_verifier_data().clone(), + self.non_existence_circuit + .circuit_data() + .verifier_only + .clone(), ) + .into(), + ), + CircuitInput::UniversalCircuit(universal_circuit_input) => { + if let UniversalCircuitInput::QueryNoAgg(input) = universal_circuit_input { + serialize_proof(&self.universal_circuit.generate_proof(&input)?) 
+ } else { + unreachable!("Universal circuit should only be used for queries with no aggregation operations") + } } - }); - - proof.serialize() + } } pub(crate) fn get_circuit_set(&self) -> &RecursiveCircuits { &self.circuit_set } -} - -#[cfg(test)] -mod tests { - use std::cmp::Ordering; - - use alloy::primitives::U256; - use itertools::Itertools; - use mp2_common::{proof::ProofWithVK, types::HashOutput, utils::Fieldable, F}; - use mp2_test::utils::{gen_random_field_hash, gen_random_u256}; - use plonky2::{ - field::types::{PrimeField64, Sample}, - plonk::config::GenericHashOut, - }; - use rand::{thread_rng, Rng}; - - use crate::query::{ - aggregation::{ - ChildPosition, NodeInfo, QueryBoundSource, QueryBounds, QueryHashNonExistenceCircuits, - SubProof, - }, - api::{CircuitInput, Parameters}, - computational_hash_ids::{ - AggregationOperation, ColumnIDs, Operation, PlaceholderIdentifier, - }, - public_inputs::PublicInputs, - universal_circuit::universal_circuit_inputs::{ - BasicOperation, ColumnCell, InputOperand, OutputItem, Placeholders, ResultStructure, - RowCells, - }, - }; - - #[test] - fn test_api() { - // Simple query for testing SELECT SUM(C1 + C3) FROM T WHERE C3 >= 5 AND C1 > 56 AND C1 <= 67 AND C2 > 34 AND C2 <= $1 - let rng = &mut thread_rng(); - const NUM_COLUMNS: usize = 3; - const MAX_NUM_COLUMNS: usize = 20; - const MAX_NUM_PREDICATE_OPS: usize = 20; - const MAX_NUM_RESULT_OPS: usize = 20; - const MAX_NUM_RESULTS: usize = 10; - - let column_ids = ColumnIDs::new( - F::rand().to_canonical_u64(), - F::rand().to_canonical_u64(), - (0..NUM_COLUMNS - 2) - .map(|_| F::rand().to_canonical_u64()) - .collect_vec(), - ); - - let primary_index_id: F = column_ids.primary; - let secondary_index_id: F = column_ids.secondary; - - let min_query_primary = 57; - let max_query_primary = 67; - let min_query_secondary = 35; - let max_query_secondary = 78; - // define Enum to specify whether to generate index values in range or not - enum IndexValueBounds { - InRange, // generate index value within query bounds - Smaller, // generate index value smaller than minimum query bound - Bigger, // generate inde value bigger than maximum query bound - } - // generate a new row with `NUM_COLUMNS` where value of secondary index is within the query bounds - let mut gen_row = |primary_index: usize, secondary_index: IndexValueBounds| { - (0..NUM_COLUMNS) - .map(|i| match i { - 0 => U256::from(primary_index), - 1 => match secondary_index { - IndexValueBounds::InRange => { - U256::from(rng.gen_range(min_query_secondary..max_query_secondary)) - } - IndexValueBounds::Smaller => { - U256::from(rng.gen_range(0..min_query_secondary)) - } - IndexValueBounds::Bigger => { - U256::from(rng.gen_range(0..min_query_secondary)) - } - }, - _ => gen_random_u256(rng), - }) - .collect_vec() - }; - - let predicate_operations = vec![BasicOperation { - first_operand: InputOperand::Column(2), - second_operand: Some(InputOperand::Constant(U256::from(5))), - op: Operation::GreaterThanOrEqOp, - }]; - let result_operations = vec![BasicOperation { - first_operand: InputOperand::Column(0), - second_operand: Some(InputOperand::Column(2)), - op: Operation::AddOp, - }]; - let aggregation_op_ids = vec![AggregationOperation::SumOp.to_id()]; - let output_items = vec![OutputItem::ComputedValue(0)]; - let results = ResultStructure::new_for_query_with_aggregation( - result_operations, - output_items, - aggregation_op_ids.clone(), - ) - .unwrap(); - let first_placeholder_id = PlaceholderIdentifier::Generic(0); - let placeholders = 
Placeholders::from(( - vec![(first_placeholder_id, U256::from(max_query_secondary))], - U256::from(min_query_primary), - U256::from(max_query_primary), - )); - let query_bounds = QueryBounds::new( - &placeholders, - Some(QueryBoundSource::Constant(U256::from(min_query_secondary))), - Some(QueryBoundSource::Placeholder(first_placeholder_id)), - ) - .unwrap(); - - let mut params = Parameters::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >::build(); - - // Test serialization of params - let serialized_params = bincode::serialize(¶ms).unwrap(); - // use deserialized params to generate proofs - params = bincode::deserialize(&serialized_params).unwrap(); - - type Input = CircuitInput< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >; - - // test an index tree with all proven nodes: we assume to have index tree built as follows - // (node identified according to their sorting order): - // 4 - // 0 - // 2 - // 1 3 - - // build a vector of 5 rows with values of index columns within the query bounds. The entries in the - // vector are sorted according to primary index value - let column_values = (min_query_primary..max_query_primary) - .step_by((max_query_primary - min_query_primary) / 5) - .take(5) - .map(|index| gen_row(index, IndexValueBounds::InRange)) - .collect_vec(); - - // generate proof with universal for a row with the `values` provided as input. - // The flag `is_leaf` specifies whether the row is stored in a leaf node of a rows tree - // or not - let gen_universal_circuit_proofs = |values: &[U256], is_leaf: bool| { - let column_cells = values - .iter() - .zip(column_ids.to_vec().iter()) - .map(|(&value, &id)| ColumnCell::new(id.to_canonical_u64(), value)) - .collect_vec(); - let row_cells = RowCells::new( - column_cells[0].clone(), - column_cells[1].clone(), - column_cells[2..].to_vec(), - ); - let input = Input::new_universal_circuit( - &row_cells, - &predicate_operations, - &results, - &placeholders, - is_leaf, - &query_bounds, - ) - .unwrap(); - params.generate_proof(input).unwrap() - }; - - // generate base proofs with universal circuits for each node - let base_proofs = column_values - .iter() - .map(|values| gen_universal_circuit_proofs(values, true)) - .collect_vec(); - - // closure to extract the tree hash from a proof - let get_tree_hash_from_proof = |proof: &[u8]| { - let (proof, _) = ProofWithVK::deserialize(proof).unwrap().into(); - let pis = PublicInputs::::from_slice(&proof.public_inputs); - pis.tree_hash() - }; - - // closure to generate the proof for a leaf node of the index tree, corresponding to the node_index-th row - let gen_leaf_proof_for_node = |node_index: usize| { - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[node_index]); - let node_info = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - None, - None, - column_values[node_index][0], // primary index value for this row - column_values[node_index][0], - column_values[node_index][0], - ); - let tree_hash = node_info.compute_node_hash(primary_index_id); - let subtree_proof = - SubProof::new_embedded_tree_proof(base_proofs[node_index].clone()).unwrap(); - let input = Input::new_single_path( - subtree_proof, - None, - None, - node_info, - false, // index tree node - &query_bounds, - ) - .unwrap(); - let proof = params.generate_proof(input).unwrap(); - // check tree hash is correct - assert_eq!(tree_hash, get_tree_hash_from_proof(&proof)); - proof - }; - - // generate proof for node 1 
of index tree above - let leaf_proof_left = gen_leaf_proof_for_node(1); - - // generate proof for node 3 of index tree above - let leaf_proof_right = gen_leaf_proof_for_node(3); - - // generate proof for node 2 of index tree above - let left_child_hash = get_tree_hash_from_proof(&leaf_proof_left); - let right_child_hash = get_tree_hash_from_proof(&leaf_proof_right); - let input = Input::new_full_node( - leaf_proof_left, - leaf_proof_right, - base_proofs[2].clone(), - false, - &query_bounds, - ) - .unwrap(); - let full_node_proof = params.generate_proof(input).unwrap(); - - // verify hash is correct - let full_node_info = NodeInfo::new( - &HashOutput::try_from(get_tree_hash_from_proof(&base_proofs[2]).to_bytes()).unwrap(), - Some(&HashOutput::try_from(left_child_hash.to_bytes()).unwrap()), - Some(&HashOutput::try_from(right_child_hash.to_bytes()).unwrap()), - column_values[2][0], // primary index value for that row - column_values[1][0], // primary index value for the min node in the left subtree - column_values[3][0], // primary index value for the max node in the right subtree - ); - let full_node_hash = get_tree_hash_from_proof(&full_node_proof); - assert_eq!( - full_node_hash, - full_node_info.compute_node_hash(primary_index_id), - ); - - // generate proof for node 0 of the index tree above - let input = Input::new_partial_node( - full_node_proof, - base_proofs[0].clone(), - None, // there is no left child - ChildPosition::Right, // proven child is the right child of node 0 - false, - &query_bounds, - ) - .unwrap(); - let one_child_node_proof = params.generate_proof(input).unwrap(); - // verify hash is correct - let one_child_node_info = NodeInfo::new( - &HashOutput::try_from(get_tree_hash_from_proof(&base_proofs[0]).to_bytes()).unwrap(), - None, - Some(&HashOutput::try_from(full_node_hash.to_bytes()).unwrap()), - column_values[0][0], - column_values[0][0], - column_values[3][0], - ); - let one_child_node_hash = get_tree_hash_from_proof(&one_child_node_proof); - assert_eq!( - one_child_node_hash, - one_child_node_info.compute_node_hash(primary_index_id) - ); - - // generate proof for root node - let input = Input::new_partial_node( - one_child_node_proof, - base_proofs[4].clone(), - None, // there is no right child - ChildPosition::Left, // proven child is the left child of root node - false, - &query_bounds, - ) - .unwrap(); - let (root_proof, _) = ProofWithVK::deserialize(¶ms.generate_proof(input).unwrap()) - .unwrap() - .into(); - // check some public inputs for root proof - let check_pis = |root_proof_pis: &[F], node_info: NodeInfo, column_values: &[Vec]| { - let pis = PublicInputs::::from_slice(root_proof_pis); - assert_eq!( - pis.tree_hash(), - node_info.compute_node_hash(primary_index_id), - ); - assert_eq!(pis.min_value(), node_info.min,); - assert_eq!(pis.max_value(), node_info.max,); - assert_eq!(pis.min_query_value(), query_bounds.min_query_primary()); - assert_eq!(pis.max_query_value(), query_bounds.max_query_primary()); - assert_eq!( - pis.index_ids().to_vec(), - vec![column_ids.primary, column_ids.secondary,], - ); - // compute output value: SUM(C1 + C3) for all the rows where C3 >= 5 - let (output, overflow, count) = - column_values - .iter() - .fold((U256::ZERO, false, 0u64), |acc, value| { - if value[2] >= U256::from(5) - && value[0] >= query_bounds.min_query_primary() - && value[0] <= query_bounds.max_query_primary() - && value[1] >= query_bounds.min_query_secondary().value - && value[1] <= query_bounds.max_query_secondary().value - { - let (sum, overflow) = 
value[0].overflowing_add(value[2]); - let new_overflow = acc.1 || overflow; - let (new_sum, overflow) = sum.overflowing_add(acc.0); - (new_sum, new_overflow || overflow, acc.2 + 1) - } else { - acc - } - }); - assert_eq!(pis.first_value_as_u256(), output,); - assert_eq!(pis.overflow_flag(), overflow,); - assert_eq!(pis.num_matching_rows(), count.to_field(),); - }; - - let root_node_info = NodeInfo::new( - &HashOutput::try_from(get_tree_hash_from_proof(&base_proofs[4]).to_bytes()).unwrap(), - Some(&HashOutput::try_from(one_child_node_hash.to_bytes()).unwrap()), - None, - column_values[4][0], - column_values[0][0], - column_values[4][0], - ); - - check_pis(&root_proof.public_inputs, root_node_info, &column_values); - - // build an index tree with a mix of proven and unproven nodes. The tree is built as follows: - // 0 - // 8 - // 3 9 - // 2 5 - // 1 4 6 - // 7 - // nodes 3,4,5,6 are in the range specified by the query for the primary index, while the other nodes - // are not - let column_values = [0, min_query_primary / 3, min_query_primary * 2 / 3] - .into_iter() // primary index values for nodes 0,1,2 - .chain( - (min_query_primary..max_query_primary) - .step_by((max_query_primary - min_query_primary) / 4) - .take(4), - ) // primary index values for nodes in the range - .chain([ - max_query_primary * 2, - max_query_primary * 3, - max_query_primary * 4, - ]) // primary index values for nodes 7,8, 9 - .map(|index| gen_row(index, IndexValueBounds::InRange)) - .collect_vec(); - // generate base proofs with universal circuits for each node in the range - const START_NODE_IN_RANGE: usize = 3; - const LAST_NODE_IN_RANGE: usize = 6; - let base_proofs = column_values[START_NODE_IN_RANGE..=LAST_NODE_IN_RANGE] - .iter() - .map(|values| gen_universal_circuit_proofs(values, true)) - .collect_vec(); - - // generate proof for node 4 - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[4 - START_NODE_IN_RANGE]); - let node_info = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - None, - None, - column_values[4][0], - column_values[4][0], - column_values[4][0], - ); - let subtree_proof = - SubProof::new_embedded_tree_proof(base_proofs[4 - START_NODE_IN_RANGE].clone()) - .unwrap(); - let hash_4 = node_info.compute_node_hash(primary_index_id); - let input = - Input::new_single_path(subtree_proof, None, None, node_info, false, &query_bounds) - .unwrap(); - let proof_4 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_4, get_tree_hash_from_proof(&proof_4),); - - // generate proof for node 6 - // compute node data for node 7, which is needed as input to generate the proof - let node_info_7 = NodeInfo::new( - // for the sake of this test, we can use random hash for the embedded tree stored in node 7, since it's not proven; - // in a non-test scenario, we would need to get the actual embedded hash of the node, otherwise the root hash of the - // tree computed in the proofs will be incorrect - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[7][0], - column_values[7][0], - column_values[7][0], - ); - let hash_7 = node_info_7.compute_node_hash(primary_index_id); - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[6 - START_NODE_IN_RANGE]); - let node_info_6 = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - None, - Some(&HashOutput::try_from(hash_7.to_bytes()).unwrap()), - column_values[6][0], - column_values[6][0], - column_values[7][0], - ); - let 
subtree_proof = - SubProof::new_embedded_tree_proof(base_proofs[6 - START_NODE_IN_RANGE].clone()) - .unwrap(); - let hash_6 = node_info_6.compute_node_hash(primary_index_id); - let input = Input::new_single_path( - subtree_proof, - None, - Some(node_info_7), - node_info_6, - false, - &query_bounds, - ) - .unwrap(); - let proof_6 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_6, get_tree_hash_from_proof(&proof_6)); - - // generate proof for node 5 - let input = Input::new_full_node( - proof_4, - proof_6, - base_proofs[5 - START_NODE_IN_RANGE].clone(), - false, - &query_bounds, - ) - .unwrap(); - let proof_5 = params.generate_proof(input).unwrap(); - // check hash - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[5 - START_NODE_IN_RANGE]); - let node_info_5 = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_4.to_bytes()).unwrap()), - Some(&HashOutput::try_from(hash_6.to_bytes()).unwrap()), - column_values[5][0], - column_values[4][0], - column_values[7][0], - ); - let hash_5 = node_info_5.compute_node_hash(primary_index_id); - assert_eq!(hash_5, get_tree_hash_from_proof(&proof_5),); - - // generate proof for node 3 - // compute node data for node 2, which is needed as input to generate the proof - let node_info_2 = NodeInfo::new( - // same as for node_info_7, we can use random hashes for the sake of this test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - Some(&HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap()), - None, - column_values[2][0], - column_values[1][0], - column_values[2][0], - ); - let hash_2 = node_info_2.compute_node_hash(primary_index_id); - let input = Input::new_partial_node( - proof_5, - base_proofs[3 - START_NODE_IN_RANGE].clone(), - Some(node_info_2), - ChildPosition::Right, // proven child is right child - false, - &query_bounds, - ) - .unwrap(); - let proof_3 = params.generate_proof(input).unwrap(); - // check hash - let embedded_tree_hash = get_tree_hash_from_proof(&base_proofs[3 - START_NODE_IN_RANGE]); - let node_info_3 = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_2.to_bytes()).unwrap()), - Some(&HashOutput::try_from(hash_5.to_bytes()).unwrap()), - column_values[3][0], - column_values[1][0], - column_values[7][0], - ); - let hash_3 = node_info_3.compute_node_hash(primary_index_id); - assert_eq!(hash_3, get_tree_hash_from_proof(&proof_3),); - - // generate proof for node 8 - // compute node_info_9, which is needed as input for the proof - let node_info_9 = NodeInfo::new( - // same as for node_info_2, we can use random hashes for the sake of this test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[9][0], - column_values[9][0], - column_values[9][0], - ); - let hash_9 = node_info_9.compute_node_hash(primary_index_id); - let node_info_8 = NodeInfo::new( - // same as for node_info_2, we can use random hashes for the sake of this test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_3.to_bytes()).unwrap()), - Some(&HashOutput::try_from(hash_9.to_bytes()).unwrap()), - column_values[8][0], - column_values[1][0], - column_values[9][0], - ); - let hash_8 = node_info_8.compute_node_hash(primary_index_id); - let subtree_proof = SubProof::new_child_proof( - proof_3, - ChildPosition::Left, // subtree proof refers to the left child of the node - ) - 
.unwrap(); - let input = Input::new_single_path( - subtree_proof, - Some(node_info_3), - Some(node_info_9), - node_info_8, - false, - &query_bounds, - ) - .unwrap(); - let proof_8 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(get_tree_hash_from_proof(&proof_8), hash_8); - println!("generate proof for node 0"); - - // generate proof for node 0 (root) - let node_info_0 = NodeInfo::new( - // same as for node_info_1, we can use random hashes for the sake of this test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - Some(&HashOutput::try_from(hash_8.to_bytes()).unwrap()), - column_values[0][0], - column_values[0][0], - column_values[9][0], - ); - let subtree_proof = SubProof::new_child_proof( - proof_8, - ChildPosition::Right, // subtree proof refers to the right child of the node - ) - .unwrap(); - let input = Input::new_single_path( - subtree_proof, - None, - Some(node_info_8), - node_info_0, - false, - &query_bounds, - ) - .unwrap(); - let (root_proof, _) = ProofWithVK::deserialize(¶ms.generate_proof(input).unwrap()) - .unwrap() - .into(); - - // check some public inputs - check_pis(&root_proof.public_inputs, node_info_0, &column_values); - - // build an index tree with all nodes outside of the primary index range. The tree is built as follows: - // 2 - // 1 3 - // 0 - // where nodes 0 stores an index value smaller than `min_query_primary`, while nodes 1, 2, 3 store index values - // bigger than `max_query_primary` - let column_values = [min_query_primary / 2] - .into_iter() - .chain([ - max_query_primary * 2, - max_query_primary * 3, - max_query_primary * 4, - ]) - .map(|index| gen_row(index, IndexValueBounds::InRange)) - .collect_vec(); - - // generate proof for node 0 with non-existence circuit, since it is outside of the query bounds - let node_info_0 = NodeInfo::new( - // we can use a randomly generated hash for the subtree, for the sake of the test - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[0][0], - column_values[0][0], - column_values[0][0], - ); - let hash_0 = node_info_0.compute_node_hash(primary_index_id); - - // compute hashes associated to query, which are needed as inputs - let query_hashes = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >( - &column_ids, - &predicate_operations, - &results, - &placeholders, - &query_bounds, - false, - ) - .unwrap(); - let input = Input::new_non_existence_input( - node_info_0, - None, - None, - node_info_0.value, - &[ - column_ids.primary.to_canonical_u64(), - column_ids.secondary.to_canonical_u64(), - ], - &[AggregationOperation::SumOp], - query_hashes, - false, - &query_bounds, - &placeholders, - ) - .unwrap(); - let proof_0 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_0, get_tree_hash_from_proof(&proof_0),); - - // get up to the root of the tree with proofs - // generate proof for node 1 - let node_info_1 = NodeInfo::new( - // we can use a random hash for the embedded tree - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_0.to_bytes()).unwrap()), - None, - column_values[1][0], - column_values[0][0], - column_values[1][0], - ); - let hash_1 = node_info_1.compute_node_hash(primary_index_id); - let subtree_proof = SubProof::new_child_proof(proof_0, ChildPosition::Left).unwrap(); - let input = Input::new_single_path( - subtree_proof, - Some(node_info_0), - None, - 
node_info_1, - false, - &query_bounds, - ) - .unwrap(); - let proof_1 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_1, get_tree_hash_from_proof(&proof_1),); - - // generate proof for root node - let node_info_2 = NodeInfo::new( - // we can use a random hash for the embedded tree - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_1.to_bytes()).unwrap()), - None, - column_values[2][0], - column_values[0][0], - column_values[2][0], - ); - let node_info_3 = NodeInfo::new( - // we can use a random hash for the embedded tree - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[3][0], - column_values[3][0], - column_values[3][0], - ); - let subtree_proof = SubProof::new_child_proof(proof_1, ChildPosition::Left).unwrap(); - let input = Input::new_single_path( - subtree_proof, - Some(node_info_1), - Some(node_info_3), - node_info_2, - false, - &query_bounds, - ) - .unwrap(); - let (root_proof, _) = ProofWithVK::deserialize(¶ms.generate_proof(input).unwrap()) - .unwrap() - .into(); - - check_pis(&root_proof.public_inputs, node_info_2, &column_values); - - // generate non-existence proof starting from intermediate node (i.e., node 1) rather than a leaf node - // generate proof with non-existence circuit for node 1 - - // compute hashes associated to query, which are needed as inputs - let query_hashes = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >( - &column_ids, - &predicate_operations, - &results, - &placeholders, - &query_bounds, - false, - ) - .unwrap(); - let input = Input::new_non_existence_input( - node_info_1, - Some(node_info_0), // node 0 is the left child - None, - node_info_1.value, - &[ - column_ids.primary.to_canonical_u64(), - column_ids.secondary.to_canonical_u64(), - ], - &[AggregationOperation::SumOp], - query_hashes, - false, - &query_bounds, - &placeholders, - ) - .unwrap(); - let proof_1 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_1, get_tree_hash_from_proof(&proof_1),); - - // generate proof for root node - let subtree_proof = SubProof::new_child_proof(proof_1, ChildPosition::Left).unwrap(); - let input = Input::new_single_path( - subtree_proof, - Some(node_info_1), - Some(node_info_3), - node_info_2, - false, - &query_bounds, - ) - .unwrap(); - let (root_proof, _) = ProofWithVK::deserialize(¶ms.generate_proof(input).unwrap()) - .unwrap() - .into(); - - check_pis(&root_proof.public_inputs, node_info_2, &column_values); - - // generate a tree with rows tree with more than one node. We generate an index tree with 2 nodes A and B, - // both storing a primary index value within the query bounds. - // Node A stores a rows tree with all entries outside of query bounds for secondary index, while - // node B stores a rows tree with all entries within query bounds for secondary index. 
- // The tree is structured as follows: - // B - // 4 - // 3 5 - // A - // 1 - // 0 2 - let mut column_values = vec![ - gen_row(min_query_primary, IndexValueBounds::Smaller), - gen_row(min_query_primary, IndexValueBounds::Smaller), - gen_row(min_query_primary, IndexValueBounds::Bigger), - gen_row(max_query_primary, IndexValueBounds::InRange), - gen_row(max_query_primary, IndexValueBounds::InRange), - gen_row(max_query_primary, IndexValueBounds::InRange), - ]; - // sort column values according to primary/secondary index values - column_values.sort_by(|a, b| match a[0].cmp(&b[0]) { - Ordering::Less => Ordering::Less, - Ordering::Greater => Ordering::Greater, - Ordering::Equal => a[1].cmp(&b[1]), - }); - - // generate proof for node A rows tree - // generate non-existence proof for node 2, which is the smallest node higher than the maximum query bound, since - // node 1, which is the highest node smaller than the minimum query bound, has 2 children - // (see non-existence circuit docs to see why we don't generate non-existence proofs for nodes with 2 children) - let node_info_2 = NodeInfo::new( - // we can use a random hash for the embedded tree - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[2][1], - column_values[2][1], - column_values[2][1], - ); - let hash_2 = node_info_2.compute_node_hash(secondary_index_id); - - // compute hashes associated to query, which are needed as inputs - let query_hashes = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_RESULTS, - >( - &column_ids, - &predicate_operations, - &results, - &placeholders, - &query_bounds, - true, - ) - .unwrap(); - let input = Input::new_non_existence_input( - node_info_2, - None, - None, - column_values[2][0], // we need to place the primary index value associated to this row - &[ - column_ids.primary.to_canonical_u64(), - column_ids.secondary.to_canonical_u64(), - ], - &[AggregationOperation::SumOp], - query_hashes, - true, - &query_bounds, - &placeholders, - ) - .unwrap(); - let proof_2 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_2, get_tree_hash_from_proof(&proof_2),); - - // generate proof for node 1 (root of rows tree for node A) - let node_info_1 = NodeInfo::new( - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - Some(&HashOutput::try_from(hash_2.to_bytes()).unwrap()), - column_values[1][1], - column_values[0][1], - column_values[2][1], - ); - let node_info_0 = NodeInfo::new( - &HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(), - None, - None, - column_values[0][1], - column_values[0][1], - column_values[0][1], - ); - let hash_1 = node_info_1.compute_node_hash(secondary_index_id); - let subtree_proof = SubProof::new_child_proof(proof_2, ChildPosition::Right).unwrap(); - let input = Input::new_single_path( - subtree_proof, - Some(node_info_0), - Some(node_info_2), - node_info_1, - true, - &query_bounds, - ) - .unwrap(); - let proof_1 = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_1, get_tree_hash_from_proof(&proof_1),); - - // generate proof for node A (leaf of index tree) - let node_info_a = NodeInfo::new( - &HashOutput::try_from(hash_1.to_bytes()).unwrap(), - None, - None, - column_values[0][0], - column_values[0][0], - column_values[0][0], - ); - let hash_a = node_info_a.compute_node_hash(primary_index_id); - let subtree_proof = SubProof::new_embedded_tree_proof(proof_1).unwrap(); - let input = - 
Input::new_single_path(subtree_proof, None, None, node_info_a, false, &query_bounds) - .unwrap(); - let proof_a = params.generate_proof(input).unwrap(); - // check hash - assert_eq!(hash_a, get_tree_hash_from_proof(&proof_a),); - - // generate proof for node B rows tree - // all the nodes are in the range, so we generate proofs for each of the nodes - // generate proof for nodes 3 and 5: they are leaf nodes in the rows tree, so we directly use the universal circuit - let [proof_3, proof_5] = [&column_values[3], &column_values[5]] - .map(|values| gen_universal_circuit_proofs(values, true)); - // node 4 is not a leaf in the rows tree, so instead we need to first generate a proof for the row results using - // the universal circuit, and then we generate the proof for the rows tree node - let row_proof = gen_universal_circuit_proofs(&column_values[4], false); - let hash_3 = get_tree_hash_from_proof(&proof_3); - let hash_5 = get_tree_hash_from_proof(&proof_5); - let embedded_tree_hash = get_tree_hash_from_proof(&row_proof); - let input = Input::new_full_node(proof_3, proof_5, row_proof, true, &query_bounds).unwrap(); - let proof_4 = params.generate_proof(input).unwrap(); - // check hash - let node_info_4 = NodeInfo::new( - &HashOutput::try_from(embedded_tree_hash.to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_3.to_bytes()).unwrap()), - Some(&HashOutput::try_from(hash_5.to_bytes()).unwrap()), - column_values[4][1], - column_values[3][1], - column_values[5][1], - ); - let hash_4 = node_info_4.compute_node_hash(secondary_index_id); - assert_eq!(hash_4, get_tree_hash_from_proof(&proof_4),); - - // generate proof for node B of the index tree (root node) - let node_info_root = NodeInfo::new( - &HashOutput::try_from(hash_4.to_bytes()).unwrap(), - Some(&HashOutput::try_from(hash_a.to_bytes()).unwrap()), - None, - column_values[4][0], - column_values[0][0], - column_values[5][0], - ); - let input = Input::new_partial_node( - proof_a, - proof_4, - None, - ChildPosition::Left, - false, - &query_bounds, - ) - .unwrap(); - let (root_proof, _) = ProofWithVK::deserialize(&params.generate_proof(input).unwrap()) - .unwrap() - .into(); - - check_pis(&root_proof.public_inputs, node_info_root, &column_values); + pub(crate) fn get_universal_circuit( + &self, + ) -> &UniversalQueryCircuitParams< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + OutputNoAggCircuit, + > { + &self.universal_circuit } } diff --git a/verifiable-db/src/query/circuits/chunk_aggregation.rs b/verifiable-db/src/query/circuits/chunk_aggregation.rs new file mode 100644 index 000000000..7bac563e3 --- /dev/null +++ b/verifiable-db/src/query/circuits/chunk_aggregation.rs @@ -0,0 +1,418 @@ +use anyhow::Result; +use std::array; + +use itertools::Itertools; +use mp2_common::{ + proof::ProofWithVK, + public_inputs::PublicInputCommon, + serialization::{ + deserialize_array, deserialize_long_array, serialize_array, serialize_long_array, + }, + u256::CircuitBuilderU256, + utils::ToTargets, + D, F, +}; +use plonky2::{ + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, +}; +use recursion_framework::circuit_builder::CircuitLogicWires; +use serde::{Deserialize, Serialize}; + +use crate::query::{ + pi_len, public_inputs::PublicInputsQueryCircuits, + row_chunk_gadgets::aggregate_chunks::aggregate_chunks, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChunkAggregationWires<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> {
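+ // The circuit always verifies `NUM_CHUNKS` proofs: slots beyond `num_non_dummy_chunks` are neutralized via the `is_non_dummy_chunk` flags below + // (e.g., with `NUM_CHUNKS = 5` and 3 real chunks, `assign` sets the flags to `[true, true, true, false, false]`, so the last two proofs don't contribute to the aggregated outputs)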
#[serde( + serialize_with = "serialize_array", + deserialize_with = "deserialize_array" + )] + /// Boolean flags specifying whether the i-th chunk is non-dummy + is_non_dummy_chunk: [BoolTarget; NUM_CHUNKS], +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChunkAggregationCircuit<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> { + /// Number of non-dummy chunks to be aggregated. Must be at + /// most `NUM_CHUNKS` + pub(crate) num_non_dummy_chunks: usize, +} + +impl<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> + ChunkAggregationCircuit<NUM_CHUNKS, MAX_NUM_RESULTS> +{ + pub(crate) fn build( + b: &mut CircuitBuilder<F, D>, + chunk_proofs: &[PublicInputsQueryCircuits<Target, MAX_NUM_RESULTS>; NUM_CHUNKS], + ) -> ChunkAggregationWires<NUM_CHUNKS, MAX_NUM_RESULTS> + where + [(); MAX_NUM_RESULTS - 1]:, + { + let is_non_dummy_chunk = array::from_fn(|_| b.add_virtual_bool_target_safe()); + + // Enforce the first chunk is non-dummy + b.assert_one(is_non_dummy_chunk[0].target); + + // build `RowChunkDataTarget` for first chunk + let mut row_chunk = chunk_proofs[0].to_row_chunk_target(); + // save query bounds of first chunk to check that they are the same across + // all the aggregated chunks + let min_query_primary = chunk_proofs[0].min_primary_target(); + let max_query_primary = chunk_proofs[0].max_primary_target(); + let min_query_secondary = chunk_proofs[0].min_secondary_target(); + let max_query_secondary = chunk_proofs[0].max_secondary_target(); + // save computational hash and placeholder hash of the first chunk to check + // that they are the same across all the aggregated chunks + let computational_hash = chunk_proofs[0].computational_hash_target(); + let placeholder_hash = chunk_proofs[0].placeholder_hash_target(); + // save identifiers of aggregation operations of the first chunk to check + // that they are the same across all the aggregated chunks + let ops_ids = chunk_proofs[0].operation_ids_target(); + for i in 1..NUM_CHUNKS { + let chunk_proof = &chunk_proofs[i]; + + let current_chunk = chunk_proof.to_dummy_row_chunk_target(b, is_non_dummy_chunk[i]); + row_chunk = aggregate_chunks( + b, + &row_chunk, + &current_chunk, + (&min_query_primary, &max_query_primary), + (&min_query_secondary, &max_query_secondary), + &ops_ids, + &is_non_dummy_chunk[i], + ); + // check the query bounds employed to prove the current chunk are the same + // as all other chunks + b.enforce_equal_u256(&chunk_proof.min_primary_target(), &min_query_primary); + b.enforce_equal_u256(&chunk_proof.max_primary_target(), &max_query_primary); + b.enforce_equal_u256(&chunk_proof.min_secondary_target(), &min_query_secondary); + b.enforce_equal_u256(&chunk_proof.max_secondary_target(), &max_query_secondary); + // check the same computational hash is associated to rows processed + // in all the chunks + b.connect_hashes(chunk_proof.computational_hash_target(), computational_hash); + // check the same placeholder hash is associated to rows processed in + // all the chunks + b.connect_hashes(chunk_proof.placeholder_hash_target(), placeholder_hash); + // check the same set of aggregation operations has been employed + // in all the chunks + chunk_proof + .operation_ids_target() + .into_iter() + .zip_eq(ops_ids) + .for_each(|(current_op, op)| b.connect(current_op, op)); + } + + let overflow_flag = { + let zero = b.zero(); + b.is_not_equal(row_chunk.chunk_outputs.num_overflows, zero) + }; + + PublicInputsQueryCircuits::<Target, MAX_NUM_RESULTS>::new( + &row_chunk.chunk_outputs.tree_hash.to_targets(), + &row_chunk.chunk_outputs.values.to_targets(), + &[row_chunk.chunk_outputs.count], + &ops_ids, + &row_chunk.left_boundary_row.to_targets(), + &row_chunk.right_boundary_row.to_targets(), + &min_query_primary.to_targets(),
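+ // the query bounds, computational hash and placeholder hash were constrained above to match across all chunks, so exposing the first chunk's copies here is sound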
&max_query_primary.to_targets(), + &min_query_secondary.to_targets(), + &max_query_secondary.to_targets(), + &[overflow_flag.target], + &computational_hash.to_targets(), + &placeholder_hash.to_targets(), + ) + .register(b); + + ChunkAggregationWires { is_non_dummy_chunk } + } + + pub(crate) fn assign( + &self, + pw: &mut PartialWitness<F>, + wires: &ChunkAggregationWires<NUM_CHUNKS, MAX_NUM_RESULTS>, + ) { + wires + .is_non_dummy_chunk + .iter() + .enumerate() + .for_each(|(i, wire)| pw.set_bool_target(*wire, i < self.num_non_dummy_chunks)); + } +} + +impl<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> CircuitLogicWires<F, D, NUM_CHUNKS> + for ChunkAggregationWires<NUM_CHUNKS, MAX_NUM_RESULTS> +where + [(); MAX_NUM_RESULTS - 1]:, +{ + type CircuitBuilderParams = (); + + type Inputs = ChunkAggregationCircuit<NUM_CHUNKS, MAX_NUM_RESULTS>; + + const NUM_PUBLIC_INPUTS: usize = pi_len::<MAX_NUM_RESULTS>(); + + fn circuit_logic( + builder: &mut CircuitBuilder<F, D>, + verified_proofs: [&ProofWithPublicInputsTarget<D>; NUM_CHUNKS], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + let pis = verified_proofs + .map(|proof| PublicInputsQueryCircuits::from_slice(&proof.public_inputs)); + ChunkAggregationCircuit::build(builder, &pis) + } + + fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness<F>) -> Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChunkAggregationInputs<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> { + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + pub(crate) chunk_proofs: [ProofWithVK; NUM_CHUNKS], + pub(crate) circuit: ChunkAggregationCircuit<NUM_CHUNKS, MAX_NUM_RESULTS>, +} + +#[cfg(test)] +mod tests { + use std::array; + + use itertools::Itertools; + use mp2_common::{array::ToField, utils::FromFields, C, D, F}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::{ + field::types::Field, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::circuit_builder::CircuitBuilder, + }; + + use crate::{ + query::{ + computational_hash_ids::{AggregationOperation, Identifiers}, + public_inputs::PublicInputsQueryCircuits, + universal_circuit::universal_query_gadget::OutputValues, + utils::tests::aggregate_output_values, + }, + test_utils::random_aggregation_operations, + }; + + use super::{ChunkAggregationCircuit, ChunkAggregationWires}; + + const MAX_NUM_RESULTS: usize = 10; + const NUM_CHUNKS: usize = 5; + + #[derive(Clone, Debug)] + struct TestChunkAggregationWires<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> { + pis: [Vec<Target>; NUM_CHUNKS], + inputs: ChunkAggregationWires<NUM_CHUNKS, MAX_NUM_RESULTS>, + } + + #[derive(Clone, Debug)] + struct TestChunkAggregationCircuit<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> { + pis: [Vec<F>; NUM_CHUNKS], + inputs: ChunkAggregationCircuit<NUM_CHUNKS, MAX_NUM_RESULTS>, + } + + impl<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> + TestChunkAggregationCircuit<NUM_CHUNKS, MAX_NUM_RESULTS> + { + fn new(pis: &[Vec<F>]) -> Self { + assert!( + !pis.is_empty(), + "there should be at least one chunk to prove" + ); + let dummy_pi = pis.last().unwrap(); + let inputs = ChunkAggregationCircuit { + num_non_dummy_chunks: pis.len(), + }; + let pis = array::from_fn(|i| pis.get(i).unwrap_or(dummy_pi).clone()); + Self { pis, inputs } + } + } + + impl<const NUM_CHUNKS: usize, const MAX_NUM_RESULTS: usize> UserCircuit<F, D> + for TestChunkAggregationCircuit<NUM_CHUNKS, MAX_NUM_RESULTS> + where + [(); MAX_NUM_RESULTS - 1]:, + { + type Wires = TestChunkAggregationWires<NUM_CHUNKS, MAX_NUM_RESULTS>; + + fn build(c: &mut CircuitBuilder<F, D>) -> Self::Wires { + let raw_pis = array::from_fn(|_| { + c.add_virtual_targets( + PublicInputsQueryCircuits::<Target, MAX_NUM_RESULTS>::total_len(), + ) + }); + let pis = raw_pis + .iter() + .map(|pi| PublicInputsQueryCircuits::from_slice(pi)) + .collect_vec() + .try_into() + .unwrap(); + let inputs = ChunkAggregationCircuit::build(c, &pis); + + TestChunkAggregationWires { + pis: raw_pis, + inputs, + } + } + + fn prove(&self, pw: &mut PartialWitness<F>, wires: 
&Self::Wires) { + self.inputs.assign(pw, &wires.inputs); + self.pis + .iter() + .zip_eq(&wires.pis) + .for_each(|(values, targets)| pw.set_target_arr(targets, values)); + } + } + + fn test_chunk_aggregation_circuit(first_op_id: bool, dummy_chunks: bool) { + let mut ops = random_aggregation_operations(); + if first_op_id { + ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field() + } + let raw_pis = if dummy_chunks { + // if we test with dummy chunks to be aggregated, we generate `NUM_ACTUAL_CHUNKS <= NUM_CHUNKS` + // inputs, so that the remaining `NUM_CHUNKS - NUM_ACTUAL_CHUNKS` input slots are dummies + const NUM_ACTUAL_CHUNKS: usize = 3; + PublicInputsQueryCircuits::<F, MAX_NUM_RESULTS>::sample_from_ops::<NUM_ACTUAL_CHUNKS>( + &ops, + ) + .to_vec() + } else { + PublicInputsQueryCircuits::<F, MAX_NUM_RESULTS>::sample_from_ops::<NUM_CHUNKS>(&ops) + .to_vec() + }; + + let circuit = TestChunkAggregationCircuit::<NUM_CHUNKS, MAX_NUM_RESULTS>::new(&raw_pis); + + let proof = run_circuit::<F, C, D, _>(circuit); + + let input_pis = raw_pis + .iter() + .map(|pi| PublicInputsQueryCircuits::<F, MAX_NUM_RESULTS>::from_slice(pi)) + .collect_vec(); + + let (expected_outputs, expected_overflow) = { + let outputs = input_pis + .iter() + .map(|pi| OutputValues::<MAX_NUM_RESULTS>::from_fields(pi.to_values_raw())) + .collect_vec(); + let mut num_overflows = input_pis + .iter() + .fold(0, |acc, pi| pi.overflow_flag() as u32 + acc); + let expected_outputs = ops + .into_iter() + .enumerate() + .flat_map(|(i, op)| { + let (out_value, overflows) = aggregate_output_values(i, &outputs, op); + num_overflows += overflows; + out_value + }) + .collect_vec(); + ( + OutputValues::<MAX_NUM_RESULTS>::from_fields(&expected_outputs), + num_overflows != 0, + ) + }; + + let expected_count = input_pis + .iter() + .fold(F::ZERO, |acc, pi| pi.num_matching_rows() + acc); + let expected_left_row = input_pis[0].to_left_row_raw(); + let expected_right_row = input_pis.last().unwrap().to_right_row_raw(); + + let result_pis = + PublicInputsQueryCircuits::<F, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); + + // check public inputs + assert_eq!( + result_pis.tree_hash(), + input_pis[0].tree_hash(), // tree hash is the same for all input_pis + ); + assert_eq!(result_pis.operation_ids(), ops,); + assert_eq!(result_pis.num_matching_rows(), expected_count,); + // check aggregated outputs + if ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field() { + assert_eq!( + result_pis.first_value_as_curve_point(), + expected_outputs.first_value_as_curve_point(), + ) + } else { + assert_eq!( + result_pis.first_value_as_u256(), + expected_outputs.first_value_as_u256(), + ) + }; + (1..MAX_NUM_RESULTS).for_each(|i| { + assert_eq!( + result_pis.value_at_index(i), + expected_outputs.value_at_index(i) + ) + }); + // check boundary rows + assert_eq!(result_pis.to_left_row_raw(), expected_left_row,); + assert_eq!(result_pis.to_right_row_raw(), expected_right_row,); + // check query bounds + assert_eq!( + result_pis.min_primary(), + input_pis[0].min_primary(), // query bounds are all the same in all `input_pis` + ); + assert_eq!( + result_pis.max_primary(), + input_pis[0].max_primary(), // query bounds are all the same in all `input_pis` + ); + assert_eq!( + result_pis.min_secondary(), + input_pis[0].min_secondary(), // query bounds are all the same in all `input_pis` + ); + assert_eq!( + result_pis.max_secondary(), + input_pis[0].max_secondary(), // query bounds are all the same in all `input_pis` + ); + // check overflow error + assert_eq!(result_pis.overflow_flag(), expected_overflow,); + // check computational hash + assert_eq!( + result_pis.computational_hash(), + 
input_pis[0].computational_hash(), // computational hash is the same in all `input_pis` + ); + // check placeholder hash + assert_eq!( + result_pis.placeholder_hash(), + input_pis[0].placeholder_hash(), + ); + } + + #[test] + fn test_chunk_aggregation_no_dummy_chunks() { + test_chunk_aggregation_circuit(false, false); + } + + #[test] + fn test_chunk_aggregation_dummy_chunks() { + test_chunk_aggregation_circuit(false, true); + } + + #[test] + fn test_chunk_aggregation_no_dummy_chunks_id_op() { + test_chunk_aggregation_circuit(true, false); + } + + #[test] + fn test_chunk_aggregation_dummy_chunks_id_op() { + test_chunk_aggregation_circuit(true, true); + } +} diff --git a/verifiable-db/src/query/circuits/mod.rs b/verifiable-db/src/query/circuits/mod.rs new file mode 100644 index 000000000..5dddff4df --- /dev/null +++ b/verifiable-db/src/query/circuits/mod.rs @@ -0,0 +1,340 @@ +pub(crate) mod chunk_aggregation; +pub(crate) mod non_existence; +pub(crate) mod row_chunk_processing; + +#[cfg(test)] +mod tests { + use alloy::primitives::U256; + use itertools::Itertools; + use mp2_common::{ + types::HashOutput, + utils::{FromFields, TryIntoBool}, + F, + }; + use mp2_test::{ + cells_tree::{compute_cells_tree_hash, TestCell}, + utils::gen_random_u256, + }; + use rand::thread_rng; + + use crate::{ + query::{ + computational_hash_ids::AggregationOperation, + merkle_path::tests::build_node, + universal_circuit::{ + universal_circuit_inputs::{ + BasicOperation, Placeholders, ResultStructure, RowCells, + }, + universal_query_gadget::OutputValues, + }, + utils::{NodeInfo, QueryBounds}, + }, + test_utils::gen_values_in_range, + }; + + /// Data structure employed to represent a node of a rows tree in the tests + #[derive(Clone, Debug)] + pub(crate) struct TestRowsTreeNode { + pub(crate) node: NodeInfo, + pub(crate) values: Vec, + } + /// Data structure employed to represent a node of the index tree in the tests + #[derive(Clone, Debug)] + pub(crate) struct TestIndexTreeNode { + pub(crate) node: NodeInfo, + pub(crate) rows_tree: Vec, + } + + /// Build a test index tree structured as follows: + /// 1 + /// 0 2 + /// where 1 and 2 stores values in the primary index query range + /// Then, node 0 stores the following rows tree: + /// A + /// B + /// C + /// With only B being in the secondary index query range + /// Node 1 stores the following rows tree: + /// A + /// B C + /// D + /// Where nodes A and C are in the secondary index query range + /// Node 2 stores the following rows tree: + /// A + /// B + /// C D + /// Where all nodes except for C are in secondary index query range + pub(crate) async fn build_test_tree( + bounds: &QueryBounds, + column_ids: &[F], + ) -> [TestIndexTreeNode; 3] { + // sample primary index values + let rng = &mut thread_rng(); + let [value_0] = gen_values_in_range(rng, U256::ZERO, bounds.min_query_primary()); // value of node 0 must be out of range + let [value_1, value_2] = + gen_values_in_range(rng, bounds.min_query_primary(), bounds.max_query_primary()); + // sample secondary index values for rows tree of node 0 + let [value_0c, value_0a] = + gen_values_in_range(rng, *bounds.max_query_secondary().value(), U256::MAX); + let [value_0b] = gen_values_in_range( + rng, + *bounds.min_query_secondary().value(), + *bounds.max_query_secondary().value(), + ); + // sample secondary index values for rows tree of node 1 + let [value_1b] = + gen_values_in_range(rng, U256::ZERO, *bounds.min_query_secondary().value()); + let [value_1d] = gen_values_in_range(rng, 
*bounds.max_query_secondary().value(), U256::MAX); + let [value_1a, value_1c] = gen_values_in_range( + rng, + *bounds.min_query_secondary().value(), + *bounds.max_query_secondary().value(), + ); + // sample secondary index values for rows tree of node 2 + let [value_2c] = + gen_values_in_range(rng, U256::ZERO, *bounds.min_query_secondary().value()); + let [value_2b, value_2d, value_2a] = gen_values_in_range( + rng, + *bounds.min_query_secondary().value(), + *bounds.max_query_secondary().value(), + ); + let primary_index = column_ids[0]; + let secondary_index = column_ids[1]; + let build_cells = async |primary_index_value: U256, secondary_index_value: U256| { + let rng = &mut thread_rng(); + let (mut cell_values, cells): (Vec<_>, Vec<_>) = column_ids + .iter() + .skip(2) + .map(|id| { + let column_value = gen_random_u256(rng); + (column_value, TestCell::new(column_value, *id)) + }) + .unzip(); + let mut values = vec![primary_index_value, secondary_index_value]; + values.append(&mut cell_values); + let hash = compute_cells_tree_hash(cells).await; + (values, hash) + }; + // build row 0C + let (values, cell_tree_hash) = build_cells(value_0, value_0c).await; + let node_0c = TestRowsTreeNode { + node: build_node( + None, + None, + value_0c, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build row 0B + let (values, cell_tree_hash) = build_cells(value_0, value_0b).await; + let node_0b = TestRowsTreeNode { + node: build_node( + None, + Some(&node_0c.node), + value_0b, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build row 0A + let (values, cell_tree_hash) = build_cells(value_0, value_0a).await; + let node_0a = TestRowsTreeNode { + node: build_node( + Some(&node_0b.node), + None, + value_0a, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build node 0 + let node_0 = TestIndexTreeNode { + node: build_node( + None, + None, + value_0, + HashOutput::from(node_0a.node.compute_node_hash(secondary_index)), + primary_index, + ), + rows_tree: vec![node_0a, node_0b, node_0c], + }; + // build row 2C + let (values, cell_tree_hash) = build_cells(value_2, value_2c).await; + let node_2c = TestRowsTreeNode { + node: build_node( + None, + None, + value_2c, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build row 2D + let (values, cell_tree_hash) = build_cells(value_2, value_2d).await; + let node_2d = TestRowsTreeNode { + node: build_node( + None, + None, + value_2d, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build row 2B + let (values, cell_tree_hash) = build_cells(value_2, value_2b).await; + let node_2b = TestRowsTreeNode { + node: build_node( + Some(&node_2c.node), + Some(&node_2d.node), + value_2b, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build row 2A + let (values, cell_tree_hash) = build_cells(value_2, value_2a).await; + let node_2a = TestRowsTreeNode { + node: build_node( + Some(&node_2b.node), + None, + value_2a, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build node 2 + let node_2 = TestIndexTreeNode { + node: build_node( + None, + None, + value_2, + HashOutput::from(node_2a.node.compute_node_hash(secondary_index)), + primary_index, + ), + rows_tree: vec![node_2a, node_2b, node_2c, node_2d], + }; + // build row 1D + let (values, cell_tree_hash) = build_cells(value_1, value_1d).await; + let node_1d = TestRowsTreeNode { + node: build_node( + None, + None, + value_1d, + 
HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build row 1B + let (values, cell_tree_hash) = build_cells(value_1, value_1b).await; + let node_1b = TestRowsTreeNode { + node: build_node( + None, + None, + value_1b, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build row 1C + let (values, cell_tree_hash) = build_cells(value_1, value_1c).await; + let node_1c = TestRowsTreeNode { + node: build_node( + None, + Some(&node_1d.node), + value_1c, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build row 1A + let (values, cell_tree_hash) = build_cells(value_1, value_1a).await; + let node_1a = TestRowsTreeNode { + node: build_node( + Some(&node_1b.node), + Some(&node_1c.node), + value_1a, + HashOutput::from(cell_tree_hash), + secondary_index, + ), + values, + }; + // build node 1 + let node_1 = TestIndexTreeNode { + node: build_node( + Some(&node_0.node), + Some(&node_2.node), + value_1, + HashOutput::from(node_1a.node.compute_node_hash(secondary_index)), + primary_index, + ), + rows_tree: vec![node_1a, node_1b, node_1c, node_1d], + }; + + [node_0, node_1, node_2] + } + + /// Compute predicate value and output values for a given row with cells `row_cells`. + /// Also return a flag specifying whether arithmetic errors occurred during the computation + pub(crate) fn compute_output_values_for_row<const MAX_NUM_RESULTS: usize>( + row_cells: &RowCells, + predicate_operations: &[BasicOperation], + results: &ResultStructure, + placeholders: &Placeholders, + ) -> (bool, bool, OutputValues<MAX_NUM_RESULTS>) + where + [(); MAX_NUM_RESULTS - 1]:, + { + let column_values = row_cells + .to_cells() + .into_iter() + .map(|cell| cell.value) + .collect_vec(); + let (res, predicate_err) = + BasicOperation::compute_operations(predicate_operations, &column_values, placeholders) + .unwrap(); + let predicate_value = res.last().unwrap().try_into_bool().unwrap(); + + let (res, result_err) = results + .compute_output_values(&column_values, placeholders) + .unwrap(); + + let aggregation_ops = results.aggregation_operations(); + + let output_values = res + .iter() + .zip(aggregation_ops.iter()) + .map(|(value, agg_op)| { + // if predicate_value is satisfied, then the actual output value + // is exposed as public input + if predicate_value { + *value + } else { + // otherwise, we just expose identity values for the given aggregation + // operation to ensure that the current record doesn't affect the + // aggregated result + U256::from_fields( + AggregationOperation::from_fields(&[*agg_op]) + .identity_value() + .as_slice(), + ) + } + }) + .collect_vec(); + ( + predicate_value, + predicate_err | result_err, + OutputValues::<MAX_NUM_RESULTS>::new_aggregation_outputs(&output_values), + ) + } +} diff --git a/verifiable-db/src/query/circuits/non_existence.rs b/verifiable-db/src/query/circuits/non_existence.rs new file mode 100644 index 000000000..18cc90303 --- /dev/null +++ b/verifiable-db/src/query/circuits/non_existence.rs @@ -0,0 +1,738 @@ +use anyhow::Result; +use std::array; + +use alloy::primitives::U256; +use mp2_common::{ + poseidon::empty_poseidon_hash, + public_inputs::PublicInputCommon, + serialization::{deserialize, deserialize_long_array, serialize, serialize_long_array}, + u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, + utils::ToTargets, + D, F, +}; +use plonky2::{ + hash::hash_types::{HashOut, HashOutTarget}, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::{circuit_builder::CircuitBuilder, 
proof::ProofWithPublicInputsTarget}, +}; +use recursion_framework::circuit_builder::CircuitLogicWires; +use serde::{Deserialize, Serialize}; + +use crate::query::{ + api::TreePathInputs, + merkle_path::{ + MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfoTarget, + }, + output_computation::compute_dummy_output_targets, + pi_len, + public_inputs::PublicInputsQueryCircuits, + row_chunk_gadgets::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}, + universal_circuit::{ + ComputationalHash, ComputationalHashTarget, PlaceholderHash, PlaceholderHashTarget, + }, + utils::QueryBounds, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NonExistenceWires +where + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + index_path: MerklePathWithNeighborsTargetInputs, + index_node_value: UInt256Target, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + index_node_subtree_hash: HashOutTarget, + primary_index_id: Target, + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + ops: [Target; MAX_NUM_RESULTS], + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + computational_hash: ComputationalHashTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + placeholder_hash: PlaceholderHashTarget, + min_query_primary: UInt256Target, + max_query_primary: UInt256Target, +} + +/// Circuit employed to prove the non-existence of a node in the index tree with +/// a value in the query range [min_query_primary, max_query_primary] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NonExistenceCircuit +where + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + // path of the index tree node employed to prove non-existence + index_path: MerklePathWithNeighborsGadget, + // Value of the index tree node employed to prove non-existence + index_node_value: U256, + // Hash of the subtree stored in the index tree node employed to + // prove non-existence + index_node_subtree_hash: HashOut, + // Integer identifier of primary index column + primary_index_id: F, + // Set of identifiers of the aggregation operations + // (provided only to be exposed for public input compliance) + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + ops: [F; MAX_NUM_RESULTS], + // Computational hash associated to the query + // (provided only to be exposed for public input compliance) + computational_hash: ComputationalHash, + // Placeholder hash associated to the placeholders employed + // in the query (provided only to be exposed for public + // input compliance) + placeholder_hash: PlaceholderHash, + // lower bound of the query range + min_query_primary: U256, + // upper bound of the query range + max_query_primary: U256, +} + +impl + NonExistenceCircuit +where + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + pub(crate) fn new( + path: &TreePathInputs, + primary_index: F, + aggregation_ops: [F; MAX_NUM_RESULTS], + computational_hash: ComputationalHash, + placeholder_hash: PlaceholderHash, + query_bounds: &QueryBounds, + ) -> Result { + Ok(Self { + index_path: MerklePathWithNeighborsGadget::new( + &path.path, + &path.siblings, + &path.node_info, + path.children, + )?, + index_node_value: path.node_info.value, + index_node_subtree_hash: path.node_info.embedded_tree_hash, + primary_index_id: primary_index, + ops: aggregation_ops, + computational_hash, + placeholder_hash, + min_query_primary: query_bounds.min_query_primary(), + max_query_primary: 
query_bounds.max_query_primary(), + }) + } + + pub(crate) fn build( + b: &mut CircuitBuilder, + ) -> NonExistenceWires { + let index_node_value = b.add_virtual_u256_unsafe(); // unsafe is ok since it's hashed + // in `MerklePathGadgetWithNeighbors` + let [index_node_subtree_hash, computational_hash, placeholder_hash] = + array::from_fn(|_| b.add_virtual_hash()); + let primary_index = b.add_virtual_target(); + let ops = b.add_virtual_target_arr::(); + let [min_query_primary, max_query_primary] = b.add_virtual_u256_arr_unsafe(); // unsafe is ok + // since they are exposed as public inputs + let index_path = MerklePathWithNeighborsGadget::build( + b, + index_node_value.clone(), + index_node_subtree_hash, + primary_index, + ); + // check that index_node_value is out of range + let smaller_than_min = b.is_less_than_u256(&index_node_value, &min_query_primary); + let bigger_than_max = b.is_less_than_u256(&max_query_primary, &index_node_value); + let is_out_of_range = b.or(smaller_than_min, bigger_than_max); + b.assert_one(is_out_of_range.target); + let predecessor_info = &index_path.predecessor_info; + let successor_info = &index_path.successor_info; + // assert NOT predecessor_info.is_found OR predecessor_info.value < min_query_primary + // equivalent to: assert predecessor_info.is_found AND predecessor_info.value < min_query_primary == predecessor_info.is_found + let predecessor_smaller = b.is_less_than_u256(&predecessor_info.value, &min_query_primary); + let predecessor_flag = b.and(predecessor_info.is_found, predecessor_smaller); + b.connect(predecessor_flag.target, predecessor_info.is_found.target); + // assert NOT successor_info.is_found OR successor_info.value > max_query_primary + // equivalent to: assert successor_info.is_found AND successor_info.value > max_query_primary == successor_info.is_found + let successor_bigger = b.is_less_than_u256(&max_query_primary, &successor_info.value); + let successor_flag = b.and(successor_info.is_found, successor_bigger); + b.connect(successor_flag.target, successor_info.is_found.target); + // compute dummy output values + let outputs = compute_dummy_output_targets(b, &ops); + + // generate fake `BoundaryRowNodeInfo` for a fake rows tree node, to satisfy + // the constraints in the revelation circuit + let row_node_data = { + // We simulate that the rows tree node associated to this row is the minimum node in the rows tree, + // which means there is no predecessor + let row_node_predecessor = NeighborInfoTarget::new_dummy_predecessor(b); + // We simulate that the rows tree node associated to this row is also the maximum node in the rows + // tree, which means there is no successor + let row_node_successor = NeighborInfoTarget::new_dummy_successor(b); + BoundaryRowNodeInfoTarget { + end_node_hash: b.constant_hash(*empty_poseidon_hash()), + predecessor_info: row_node_predecessor, + successor_info: row_node_successor, + } + }; + let boundary_row = BoundaryRowDataTarget { + row_node_info: row_node_data, + index_node_info: BoundaryRowNodeInfoTarget::from(&index_path), + }; + + // expose public inputs + let zero = b.zero(); + // query bounds on secondary index needs to be exposed as public inputs, but they + // can be dummy values since they are un-used in this circuit + let min_secondary = b.zero_u256(); + let max_secondary = b.constant_u256(U256::MAX); + PublicInputsQueryCircuits::::new( + &index_path.root.to_targets(), + &outputs, + &[zero], // there are no matching rows + &ops, + &boundary_row.to_targets(), + &boundary_row.to_targets(), + 
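+ // the proven chunk contains no rows, so the same dummy boundary row is exposed as both the left and right boundary of the chunk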
&min_query_primary.to_targets(), + &max_query_primary.to_targets(), + &min_secondary.to_targets(), + &max_secondary.to_targets(), + &[zero], // no arithmetic operations done, so no error occurred + &computational_hash.to_targets(), + &placeholder_hash.to_targets(), + ) + .register(b); + + NonExistenceWires { + index_path: index_path.inputs, + index_node_value, + index_node_subtree_hash, + primary_index_id: primary_index, + ops, + computational_hash, + placeholder_hash, + min_query_primary, + max_query_primary, + } + } + + pub(crate) fn assign( + &self, + pw: &mut PartialWitness, + wires: &NonExistenceWires, + ) { + self.index_path.assign(pw, &wires.index_path); + [ + (self.index_node_value, &wires.index_node_value), + (self.min_query_primary, &wires.min_query_primary), + (self.max_query_primary, &wires.max_query_primary), + ] + .into_iter() + .for_each(|(value, target)| pw.set_u256_target(target, value)); + pw.set_target_arr(&wires.ops, &self.ops); + pw.set_target(wires.primary_index_id, self.primary_index_id); + [ + (self.index_node_subtree_hash, wires.index_node_subtree_hash), + (self.computational_hash, wires.computational_hash), + (self.placeholder_hash, wires.placeholder_hash), + ] + .into_iter() + .for_each(|(value, target)| pw.set_hash_target(target, value)); + } +} + +impl CircuitLogicWires + for NonExistenceWires +where + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + type CircuitBuilderParams = (); + + type Inputs = NonExistenceCircuit; + + const NUM_PUBLIC_INPUTS: usize = pi_len::(); + + fn circuit_logic( + builder: &mut CircuitBuilder, + _verified_proofs: [&ProofWithPublicInputsTarget; 0], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + NonExistenceCircuit::build(builder) + } + + fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> anyhow::Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use alloy::primitives::U256; + use mp2_common::{check_panic, poseidon::empty_poseidon_hash, utils::ToFields, C, D, F}; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::gen_random_field_hash, + }; + use plonky2::{ + field::types::{Field, Sample}, + iop::witness::PartialWitness, + plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputs}, + }; + use rand::thread_rng; + + use crate::{ + query::{ + api::TreePathInputs, + merkle_path::{tests::generate_test_tree, NeighborInfo}, + output_computation::tests::compute_dummy_output_values, + public_inputs::PublicInputsQueryCircuits, + row_chunk_gadgets::{BoundaryRowData, BoundaryRowNodeInfo}, + universal_circuit::universal_circuit_inputs::Placeholders, + utils::{ChildPosition, QueryBounds}, + }, + test_utils::{gen_values_in_range, random_aggregation_operations}, + }; + + use super::{NonExistenceCircuit, NonExistenceWires}; + + const INDEX_TREE_MAX_DEPTH: usize = 15; + const MAX_NUM_RESULTS: usize = 10; + + impl UserCircuit for NonExistenceCircuit { + type Wires = NonExistenceWires; + + fn build(c: &mut CircuitBuilder) -> Self::Wires { + NonExistenceCircuit::build(c) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.assign(pw, wires); + } + } + + #[test] + fn test_non_existence_circuit() { + let primary_index = F::rand(); + let rng = &mut thread_rng(); + let [computational_hash, placeholder_hash] = array::from_fn(|_| gen_random_field_hash()); + let ops = random_aggregation_operations(); + // generate min_query_primary and max_query_primary + let [min_query_primary, max_query_primary] = + gen_values_in_range(rng, 
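+ // the constants below are presumably chosen to keep the sampled bounds away from 0 and U256::MAX, so that strictly out-of-range values can still be sampled on both sides of the range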
U256::from(42), U256::MAX - U256::from(42)); + let query_bounds = QueryBounds::new( + &Placeholders::new_empty(min_query_primary, max_query_primary), + None, + None, + ) + .unwrap(); + // generate a test index tree with all nodes bigger than max_primary + let [node_a, node_b, _node_c, node_d, node_e, _node_f, _node_g] = generate_test_tree( + primary_index, + Some((max_query_primary + U256::from(1), U256::MAX)), + ); + // we prove non-existence employing the minimum node of the tree as the proven node, which is node_e + let path_e = vec![ + (node_d, ChildPosition::Left), + (node_b, ChildPosition::Left), + (node_a, ChildPosition::Left), + ]; + let merkle_path_e = TreePathInputs::new(node_e, path_e, [None, None]); + let circuit = NonExistenceCircuit::new( + &merkle_path_e, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + let proof = run_circuit::(circuit); + + let check_public_inputs = |proof: &ProofWithPublicInputs, + expected_root, + expected_index_node_info: BoundaryRowNodeInfo, + expected_query_bounds: &QueryBounds, + test_name: &str| { + let pis = + PublicInputsQueryCircuits::::from_slice(&proof.public_inputs); + assert_eq!( + pis.tree_hash(), + expected_root, + "failed for test {test_name}", + ); + let expected_outputs = compute_dummy_output_values(&ops); + assert_eq!( + pis.to_values_raw(), + &expected_outputs, + "failed for test {test_name}", + ); + assert_eq!( + pis.num_matching_rows(), + F::ZERO, + "failed for test {test_name}", + ); + let expected_boundary_row = { + // build the same dummy `BoundaryRowNodeInfo` built inside the circuit + let dummy_row_node_info = BoundaryRowNodeInfo { + end_node_hash: *empty_poseidon_hash(), + predecessor_info: NeighborInfo::new_dummy_predecessor(), + successor_info: NeighborInfo::new_dummy_successor(), + }; + BoundaryRowData { + row_node_info: dummy_row_node_info, + index_node_info: expected_index_node_info, + } + } + .to_fields(); + assert_eq!( + pis.to_left_row_raw(), + expected_boundary_row, + "failed for test {test_name}", + ); + assert_eq!( + pis.to_right_row_raw(), + expected_boundary_row, + "failed for test {test_name}", + ); + assert_eq!( + pis.min_primary(), + expected_query_bounds.min_query_primary(), + "failed for test {test_name}", + ); + assert_eq!( + pis.max_primary(), + expected_query_bounds.max_query_primary(), + "failed for test {test_name}", + ); + assert!(!pis.overflow_flag(), "failed for test {test_name}"); + assert_eq!( + pis.computational_hash(), + computational_hash, + "failed for test {test_name}", + ); + assert_eq!( + pis.placeholder_hash(), + placeholder_hash, + "failed for test {test_name}", + ); + }; + let expected_root = node_a.compute_node_hash(primary_index); + let expected_index_node_info = { + // node_e has no predecessor + let predecessor_e = NeighborInfo::new_dummy_predecessor(); + // node_e successor is node_d, which is in the path + let node_d_hash = node_d.compute_node_hash(primary_index); + let successor_e = NeighborInfo::new(node_d.value, Some(node_d_hash)); + BoundaryRowNodeInfo { + end_node_hash: node_e.compute_node_hash(primary_index), + predecessor_info: predecessor_e, + successor_info: successor_e, + } + }; + + check_public_inputs( + &proof, + expected_root, + expected_index_node_info, + &query_bounds, + "all bigger", + ); + + // generate a test index tree with all nodes smaller than min_query_primary + let [node_a, _node_b, node_c, _node_d, _node_e, _node_f, node_g] = generate_test_tree( + primary_index, + Some((U256::ZERO, min_query_primary - 
U256::from(1))), + ); + // we prove non-existence employing the maximum node of the tree as the proven node, which is node_g + let path_g = vec![ + (node_c, ChildPosition::Right), + (node_a, ChildPosition::Right), + ]; + let merkle_path_g = TreePathInputs::new(node_g, path_g, [None, None]); + + let circuit = NonExistenceCircuit::new( + &merkle_path_g, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + let proof = run_circuit::(circuit); + + let expected_index_node_info = { + // node_g predecessor is node_c, which is in the path + let node_c_hash = node_c.compute_node_hash(primary_index); + let predecessor_g = NeighborInfo::new(node_c.value, Some(node_c_hash)); + // node_g has no successor + let successor_g = NeighborInfo::new_dummy_successor(); + BoundaryRowNodeInfo { + end_node_hash: node_g.compute_node_hash(primary_index), + predecessor_info: predecessor_g, + successor_info: successor_g, + } + }; + let expected_root = node_a.compute_node_hash(primary_index); + check_public_inputs( + &proof, + expected_root, + expected_index_node_info, + &query_bounds, + "all smaller", + ); + + // now, we test non-existence over a tree where some nodes are smaller than min_query_primary, and all other nodes are + // bigger than max_query_primary + // We generate a test tree with random values, and then we set min_query_primary and max_query_primary to values which are + // between node_f.value and node_b.value + let ([node_a, node_b, node_c, node_d, node_e, node_f, _node_g], query_bounds) = loop { + let [node_a, node_b, node_c, node_d, node_e, node_f, node_g] = + generate_test_tree(primary_index, None); + if node_b.value.checked_sub(node_f.value).unwrap() > U256::from(2) { + // if there is room between node_f.value and node_b.value, we + // set min_query_primary = node_f.value + 1 and max_query_primary = node_b.value - 1 + let min_query_primary = node_f.value + U256::from(1); + let max_query_primary = node_b.value - U256::from(1); + + break ( + [node_a, node_b, node_c, node_d, node_e, node_f, node_g], + QueryBounds::new( + &Placeholders::new_empty(min_query_primary, max_query_primary), + None, + None, + ) + .unwrap(), + ); + } + // otherwise, we need to re-generate the tree + }; + // in this case, we can use either node_b or node_f to prove non-existence + // prove with node_f + let path_f = vec![ + (node_d, ChildPosition::Right), + (node_b, ChildPosition::Left), + (node_a, ChildPosition::Left), + ]; + let merkle_path_f = TreePathInputs::new(node_f, path_f, [None, None]); + + let circuit = NonExistenceCircuit::new( + &merkle_path_f, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + let proof = run_circuit::(circuit); + let expected_index_node_info = { + // node_f predecessor is node_d, which is in the path + let node_d_hash = node_d.compute_node_hash(primary_index); + let predecessor_f = NeighborInfo::new(node_d.value, Some(node_d_hash)); + // node_f successor is node_b, which is in the path + let node_b_hash = node_b.compute_node_hash(primary_index); + let successor_f = NeighborInfo::new(node_b.value, Some(node_b_hash)); + BoundaryRowNodeInfo { + end_node_hash: node_f.compute_node_hash(primary_index), + predecessor_info: predecessor_f, + successor_info: successor_f, + } + }; + let expected_root = node_a.compute_node_hash(primary_index); + check_public_inputs( + &proof, + expected_root, + expected_index_node_info, + &query_bounds, + "smaller predecessor", + ); + + // we try to prove also with node_b 
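+ // node_b also satisfies the non-existence constraints: its value is above max_query_primary, its predecessor node_f is below min_query_primary, and its successor node_a is above max_query_primary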
+ let path_b = vec![(node_a, ChildPosition::Left)]; + let merkle_path_b = TreePathInputs::new(node_b, path_b, [Some(node_d), None]); + + let circuit = NonExistenceCircuit::new( + &merkle_path_b, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + let proof = run_circuit::<F, C, D, _>(circuit); + let expected_index_node_info = { + // node_b predecessor is node_f, which is not in the path + let predecessor_b = NeighborInfo::new(node_f.value, None); + // node_b successor is node_a, which is in the path + let successor_b = NeighborInfo::new(node_a.value, Some(expected_root)); + BoundaryRowNodeInfo { + end_node_hash: node_b.compute_node_hash(primary_index), + predecessor_info: predecessor_b, + successor_info: successor_b, + } + }; + check_public_inputs( + &proof, + expected_root, + expected_index_node_info, + &query_bounds, + "bigger successor", + ); + + // negative test: check that if there are nodes in the query range, then the circuit fails for each node in + // the tree + // set min_query_primary = node_f.value, max_query_primary = node_a.value + let query_bounds = QueryBounds::new( + &Placeholders::new_empty(node_f.value, node_a.value), + None, + None, + ) + .unwrap(); + // try to generate a proof with node_a + let path_a = vec![]; + let merkle_path_a = TreePathInputs::new(node_a, path_a, [Some(node_b), Some(node_c)]); + + let circuit = NonExistenceCircuit::new( + &merkle_path_a, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + check_panic!( + || run_circuit::<F, C, D, _>(circuit), + "circuit didn't fail for node_a" + ); + + // try to generate a proof with node_b + let circuit = NonExistenceCircuit::new( + &merkle_path_b, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + check_panic!( + || run_circuit::<F, C, D, _>(circuit), + "circuit didn't fail for node_b" + ); + + // try to generate a proof with node_c + let path_c = vec![(node_a, ChildPosition::Right)]; + let merkle_path_c = TreePathInputs::new(node_c, path_c, [None, Some(node_g)]); + + let circuit = NonExistenceCircuit::new( + &merkle_path_c, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + check_panic!( + || run_circuit::<F, C, D, _>(circuit), + "circuit didn't fail for node_c" + ); + + // try to generate a proof with node_d + let path_d = vec![(node_b, ChildPosition::Left), (node_a, ChildPosition::Left)]; + let merkle_path_d = TreePathInputs::new(node_d, path_d, [Some(node_e), Some(node_f)]); + + let circuit = NonExistenceCircuit::new( + &merkle_path_d, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + check_panic!( + || run_circuit::<F, C, D, _>(circuit), + "circuit didn't fail for node_d" + ); + + // try to generate a proof with node_e + let path_e = vec![ + (node_d, ChildPosition::Left), + (node_b, ChildPosition::Left), + (node_a, ChildPosition::Left), + ]; + let merkle_path_e = TreePathInputs::new(node_e, path_e, [None, None]); + + let circuit = NonExistenceCircuit::new( + &merkle_path_e, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + check_panic!( + || run_circuit::<F, C, D, _>(circuit), + "circuit didn't fail for node_e" + ); + + // try to generate a proof with node_f + let circuit = NonExistenceCircuit::new( + &merkle_path_f, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + check_panic!( + || run_circuit::<F, C, D, _>(circuit), + "circuit didn't fail for node_f" + ); + + // try to generate a proof with node_g + let path_g = vec![ + (node_c, ChildPosition::Right), + (node_a, ChildPosition::Right), + ]; + let merkle_path_g = TreePathInputs::new(node_g, path_g, [None, None]); + + let circuit = NonExistenceCircuit::new( + &merkle_path_g, + primary_index, + ops, + computational_hash, + placeholder_hash, + &query_bounds, + ) + .unwrap(); + + check_panic!( + || run_circuit::<F, C, D, _>(circuit), + "circuit didn't fail for node_g" + ); + } +} diff --git a/verifiable-db/src/query/circuits/row_chunk_processing.rs b/verifiable-db/src/query/circuits/row_chunk_processing.rs new file mode 100644 index 000000000..685d45a3e --- /dev/null +++ b/verifiable-db/src/query/circuits/row_chunk_processing.rs @@ -0,0 +1,1576 @@ +use std::iter::repeat; + +use alloy::primitives::U256; +use itertools::Itertools; +use plonky2::{ + iop::{target::Target, witness::PartialWitness}, + plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, +}; +use recursion_framework::circuit_builder::CircuitLogicWires; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::query::{ + computational_hash_ids::ColumnIDs, + pi_len, + public_inputs::PublicInputsQueryCircuits, + row_chunk_gadgets::{ + aggregate_chunks::aggregate_chunks, + row_process_gadget::{RowProcessingGadgetInputWires, RowProcessingGadgetInputs}, + RowChunkDataTarget, + }, + universal_circuit::{ + universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure}, + universal_query_gadget::{ + OutputComponent, UniversalQueryHashInputWires, UniversalQueryHashInputs, + }, + }, + utils::QueryBounds, +}; + +use mp2_common::{ + public_inputs::PublicInputCommon, + serialization::{deserialize_long_array, serialize_long_array}, + utils::ToTargets, + D, F, +}; + +use anyhow::{ensure, Result}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RowChunkProcessingWires< + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent<MAX_NUM_RESULTS>, +> where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + row_inputs: + [RowProcessingGadgetInputWires<ROW_TREE_MAX_DEPTH, INDEX_TREE_MAX_DEPTH, MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS>; + NUM_ROWS], + universal_query_inputs: UniversalQueryHashInputWires< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RowChunkProcessingCircuit< + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent<MAX_NUM_RESULTS>, +> where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + row_inputs: [RowProcessingGadgetInputs< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >; NUM_ROWS], + universal_query_inputs: UniversalQueryHashInputs< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + min_query_primary: U256, + max_query_primary: U256, +} + +impl< + const NUM_ROWS: usize, 
const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent, + > + RowChunkProcessingCircuit< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > +where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, + [(); MAX_NUM_RESULTS - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, +{ + pub(crate) fn new( + row_inputs: Vec< + RowProcessingGadgetInputs< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >, + >, + column_ids: &ColumnIDs, + predicate_operations: &[BasicOperation], + placeholders: &Placeholders, + query_bounds: &QueryBounds, + results: &ResultStructure, + ) -> Result { + let universal_query_inputs = UniversalQueryHashInputs::new( + column_ids, + predicate_operations, + placeholders, + query_bounds, + results, + )?; + + ensure!( + !row_inputs.is_empty(), + "Row chunk circuit input should be at least 1 row" + ); + // dummy row used to pad `row_inputs` to `num_rows` is just copied from the last + // real row provided as input + let dummy_row = row_inputs.last().unwrap().clone_to_dummy_row(); + Ok(Self { + row_inputs: row_inputs + .into_iter() + .chain(repeat(dummy_row.clone())) + .take(NUM_ROWS) + .collect_vec() + .try_into() + .unwrap(), + universal_query_inputs, + min_query_primary: query_bounds.min_query_primary(), + max_query_primary: query_bounds.max_query_primary(), + }) + } + + pub(crate) fn build( + b: &mut CircuitBuilder, + ) -> RowChunkProcessingWires< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > { + let query_input_wires = UniversalQueryHashInputs::build(b); + let first_row_wires = RowProcessingGadgetInputs::build( + b, + &query_input_wires.input_wires, + &query_input_wires.min_secondary, + &query_input_wires.max_secondary, + ); + // enforce first row is non-dummy + b.assert_one( + first_row_wires + .value_wires + .input_wires + .is_non_dummy_row + .target, + ); + + let mut row_inputs = vec![RowProcessingGadgetInputWires::from(&first_row_wires)]; + + let row_chunk: RowChunkDataTarget = first_row_wires.into(); + + let row_chunk = (1..NUM_ROWS).fold(row_chunk, |chunk, _| { + let row_wires = RowProcessingGadgetInputs::build( + b, + &query_input_wires.input_wires, + &query_input_wires.min_secondary, + &query_input_wires.max_secondary, + ); + row_inputs.push(RowProcessingGadgetInputWires::from(&row_wires)); + let is_second_non_dummy = row_wires.value_wires.input_wires.is_non_dummy_row; + let current_chunk: RowChunkDataTarget = row_wires.into(); + aggregate_chunks( + b, + &chunk, + ¤t_chunk, + ( + &query_input_wires.input_wires.min_query_primary, + &query_input_wires.input_wires.max_query_primary, + ), + ( + &query_input_wires.min_secondary, + &query_input_wires.max_secondary, + ), + &query_input_wires.agg_ops_ids, + &is_second_non_dummy, + ) + }); + // compute overflow flag + let overflow = { + let num_overflows = b.add( + query_input_wires.num_bound_overflows, + row_chunk.chunk_outputs.num_overflows, + ); + let zero = b.zero(); + b.is_not_equal(num_overflows, zero) + }; + + PublicInputsQueryCircuits::::new( + &row_chunk.chunk_outputs.tree_hash.to_targets(), + &row_chunk.chunk_outputs.values.to_targets(), + 
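+ // the public input layout registered here matches the one registered by the chunk aggregation circuit, so row-chunk proofs and aggregated-chunk proofs expose the same interface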
&[row_chunk.chunk_outputs.count], + &query_input_wires.agg_ops_ids, + &row_chunk.left_boundary_row.to_targets(), + &row_chunk.right_boundary_row.to_targets(), + &query_input_wires.input_wires.min_query_primary.to_targets(), + &query_input_wires.input_wires.max_query_primary.to_targets(), + &query_input_wires.min_secondary.to_targets(), + &query_input_wires.max_secondary.to_targets(), + &[overflow.target], + &query_input_wires.computational_hash.to_targets(), + &query_input_wires.placeholder_hash.to_targets(), + ) + .register(b); + + RowChunkProcessingWires { + row_inputs: row_inputs.try_into().unwrap(), + universal_query_inputs: query_input_wires.input_wires, + } + } + + pub(crate) fn assign( + &self, + pw: &mut PartialWitness, + wires: &RowChunkProcessingWires< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + ) { + self.row_inputs + .iter() + .zip(&wires.row_inputs) + .for_each(|(value, target)| value.assign(pw, target)); + self.universal_query_inputs + .assign(pw, &wires.universal_query_inputs); + } +} + +impl< + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent + Serialize + DeserializeOwned, + > CircuitLogicWires + for RowChunkProcessingWires< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > +where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, + [(); MAX_NUM_RESULTS - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, +{ + type CircuitBuilderParams = (); + + type Inputs = RowChunkProcessingCircuit< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >; + + const NUM_PUBLIC_INPUTS: usize = pi_len::(); + + fn circuit_logic( + builder: &mut CircuitBuilder, + _verified_proofs: [&ProofWithPublicInputsTarget; 0], + _builder_parameters: Self::CircuitBuilderParams, + ) -> Self { + RowChunkProcessingCircuit::build(builder) + } + + fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { + inputs.assign(pw, self); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{array, iter::once}; + + use alloy::primitives::U256; + use itertools::Itertools; + use mp2_common::{ + array::ToField, + check_panic, + group_hashing::map_to_curve_point, + types::HashOutput, + utils::{FromFields, ToFields, TryIntoBool}, + C, D, F, + }; + use mp2_test::{ + cells_tree::{compute_cells_tree_hash, TestCell}, + circuit::{run_circuit, UserCircuit}, + utils::{gen_random_u256, random_vector}, + }; + use plonky2::{ + field::types::{Field, PrimeField64, Sample}, + plonk::{circuit_builder::CircuitBuilder, config::GenericHashOut}, + }; + use plonky2_ecgfp5::curve::curve::Point; + use rand::thread_rng; + + use crate::query::{ + circuits::{ + row_chunk_processing::{RowChunkProcessingCircuit, UniversalQueryHashInputs}, + tests::{build_test_tree, compute_output_values_for_row}, + }, + computational_hash_ids::{ + AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier, + }, + merkle_path::{MerklePathWithNeighborsGadget, NeighborInfo}, + public_inputs::PublicInputsQueryCircuits, + row_chunk_gadgets::{ + row_process_gadget::RowProcessingGadgetInputs, 
BoundaryRowData, BoundaryRowNodeInfo, + }, + universal_circuit::{ + output_no_aggregation::Circuit as NoAggOutputCircuit, + output_with_aggregation::Circuit as AggOutputCircuit, + universal_circuit_inputs::{ + BasicOperation, ColumnCell, InputOperand, OutputItem, PlaceholderId, Placeholders, + ResultStructure, RowCells, + }, + universal_query_circuit::placeholder_hash, + universal_query_gadget::CurveOrU256, + ComputationalHash, + }, + utils::{tests::aggregate_output_values, ChildPosition, QueryBoundSource, QueryBounds}, + }; + + use super::{OutputComponent, RowChunkProcessingWires}; + + impl< + const NUM_ROWS: usize, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent, + > UserCircuit + for RowChunkProcessingCircuit< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > + where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, + [(); MAX_NUM_RESULTS - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, + { + type Wires = RowChunkProcessingWires< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >; + + fn build(c: &mut CircuitBuilder) -> Self::Wires { + Self::build(c) + } + + fn prove(&self, pw: &mut plonky2::iop::witness::PartialWitness, wires: &Self::Wires) { + self.assign(pw, wires) + } + } + + const NUM_ROWS: usize = 5; + const ROW_TREE_MAX_DEPTH: usize = 10; + const INDEX_TREE_MAX_DEPTH: usize = 15; + const MAX_NUM_COLUMNS: usize = 30; + const MAX_NUM_PREDICATE_OPS: usize = 20; + const MAX_NUM_RESULT_OPS: usize = 30; + const MAX_NUM_RESULTS: usize = 10; + + // SELECT SUM(C1*C2-C2/C3)), AVG(C1*C2), MIN(C1/C4+3), MAX(C4-$1), AVG(C5) FROM T WHERE (C5 > 5 AND C1*C2 <= C3+C4 OR C3 == $2) AND C2 >= 75 AND C2 < $3 AND C1 >= 456 AND C1 <= 6789 + #[tokio::test] + async fn test_query_with_aggregation() { + const NUM_ACTUAL_COLUMNS: usize = 5; + + let rng = &mut thread_rng(); + let column_ids = random_vector::(NUM_ACTUAL_COLUMNS); + let primary_index = F::from_canonical_u64(column_ids[0]); + let secondary_index = F::from_canonical_u64(column_ids[1]); + let column_ids = ColumnIDs::new(column_ids[0], column_ids[1], column_ids[2..].to_vec()); + let min_query_primary = U256::from(456); + let max_query_primary = U256::from(6789); + let min_query_secondary = U256::from(75); + let max_query_secondary = U256::from(97686754); + // define placeholders + let first_placeholder_id = PlaceholderId::Generic(0); + let second_placeholder_id = PlaceholderIdentifier::Generic(1); + let mut placeholders = Placeholders::new_empty(min_query_primary, max_query_primary); + [first_placeholder_id, second_placeholder_id] + .iter() + .for_each(|id| placeholders.insert(*id, gen_random_u256(rng))); + // 3-rd placeholder is max query bound + 1, since the bound is C2 < $3, rather than C2 <= $3 + let third_placeholder_id = PlaceholderId::Generic(2); + placeholders.insert(third_placeholder_id, max_query_secondary + U256::from(1)); + let bounds = QueryBounds::new( + &placeholders, + Some(QueryBoundSource::Constant(min_query_secondary)), + Some( + QueryBoundSource::Operation(BasicOperation { + first_operand: InputOperand::Placeholder(third_placeholder_id), + second_operand: Some(InputOperand::Constant(U256::from(1))), + op: Operation::SubOp, + 
}), // the bound is computed as $3-1 since in the query we specified that C2 < $3, + // while the bound computed in the circuit is expected to represent the maximum value + // possible for C2 (i.e., C2 < $3 => C2 <= $3 - 1) + ), + ) + .unwrap(); + // build predicate operations + let mut predicate_operations = vec![]; + // C5 > 5 + let c5_comparison = BasicOperation { + first_operand: InputOperand::Column(4), + second_operand: Some(InputOperand::Constant(U256::from(5))), + op: Operation::GreaterThanOp, + }; + predicate_operations.push(c5_comparison); + // C1*C2 + let column_prod = BasicOperation { + first_operand: InputOperand::Column(0), + second_operand: Some(InputOperand::Column(1)), + op: Operation::MulOp, + }; + predicate_operations.push(column_prod); + // C3+C4 + let column_add = BasicOperation { + first_operand: InputOperand::Column(2), + second_operand: Some(InputOperand::Column(3)), + op: Operation::AddOp, + }; + predicate_operations.push(column_add); + // C1*C2 <= C3+C4 + let expr_comparison = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &column_prod) + .unwrap(), + ), + second_operand: Some(InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &column_add) + .unwrap(), + )), + op: Operation::LessThanOrEqOp, + }; + predicate_operations.push(expr_comparison); + // C3 == $2 + let placeholder_eq = BasicOperation { + first_operand: InputOperand::Column(2), + second_operand: Some(InputOperand::Placeholder(second_placeholder_id)), + op: Operation::EqOp, + }; + predicate_operations.push(placeholder_eq); + // c5_comparison AND expr_comparison + let and_comparisons = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &c5_comparison) + .unwrap(), + ), + second_operand: Some(InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &expr_comparison) + .unwrap(), + )), + op: Operation::AndOp, + }; + predicate_operations.push(and_comparisons); + // final filtering predicate: and_comparisons OR placeholder_eq + let predicate = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &and_comparisons) + .unwrap(), + ), + second_operand: Some(InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &placeholder_eq) + .unwrap(), + )), + op: Operation::OrOp, + }; + predicate_operations.push(predicate); + + // result computations operations + let mut result_operations = vec![]; + // C1*C2 + let column_prod = BasicOperation { + first_operand: InputOperand::Column(0), + second_operand: Some(InputOperand::Column(1)), + op: Operation::MulOp, + }; + result_operations.push(column_prod); + // C2/C3 + let column_div = BasicOperation { + first_operand: InputOperand::Column(1), + second_operand: Some(InputOperand::Column(2)), + op: Operation::DivOp, + }; + result_operations.push(column_div); + let sub = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&result_operations, &column_prod) + .unwrap(), + ), + second_operand: Some(InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&result_operations, &column_div).unwrap(), + )), + op: Operation::SubOp, + }; + result_operations.push(sub); + // C1/C4 + let column_div_for_min = BasicOperation { + first_operand: InputOperand::Column(0), +
second_operand: Some(InputOperand::Column(3)), + op: Operation::DivOp, + }; + result_operations.push(column_div_for_min); + // C1/C4 + 3 + let add_for_min = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&result_operations, &column_div_for_min) + .unwrap(), + ), + second_operand: Some(InputOperand::Constant(U256::from(3))), + op: Operation::AddOp, + }; + result_operations.push(add_for_min); + // C4 - $1 + let column_placeholder = BasicOperation { + first_operand: InputOperand::Column(3), + second_operand: Some(InputOperand::Placeholder(first_placeholder_id)), + op: Operation::SubOp, + }; + result_operations.push(column_placeholder); + + // output items are all computed values in this query, except for the last item + // which is a column + let output_items = vec![ + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &sub).unwrap(), + ), + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &column_prod) + .unwrap(), + ), + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &add_for_min) + .unwrap(), + ), + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &column_placeholder) + .unwrap(), + ), + OutputItem::Column(4), + ]; + let output_ops: [F; 5] = [ + AggregationOperation::SumOp.to_field(), + AggregationOperation::AvgOp.to_field(), + AggregationOperation::MinOp.to_field(), + AggregationOperation::MaxOp.to_field(), + AggregationOperation::AvgOp.to_field(), + ]; + + let results = ResultStructure::new_for_query_with_aggregation( + result_operations, + output_items, + output_ops + .iter() + .map(|op| op.to_canonical_u64()) + .collect_vec(), + ) + .unwrap(); + + let [node_0, node_1, node_2] = build_test_tree(&bounds, &column_ids.to_vec()).await; + + let to_row_cells = |values: &[U256]| { + let column_cells = values + .iter() + .zip(column_ids.to_vec().iter()) + .map(|(&value, &id)| ColumnCell { value, id }) + .collect_vec(); + RowCells::new( + column_cells[0].clone(), + column_cells[1].clone(), + column_cells[2..].to_vec(), + ) + }; + + // run circuit over 3 consecutive rows: row 1C, row 2B and row 2D + let [node_1a, node_1b, node_1c, node_1d] = node_1 + .rows_tree + .iter() + .map(|n| n.node) + .collect_vec() + .try_into() + .unwrap(); + let path_1c = vec![(node_1a, ChildPosition::Right)]; + let node_1b_hash = HashOutput::from(node_1b.compute_node_hash(secondary_index)); + let siblings_1c = vec![Some(node_1b_hash)]; + let merkle_path_1c = MerklePathWithNeighborsGadget::new( + &path_1c, + &siblings_1c, + &node_1c, + [None, Some(node_1d)], + ) + .unwrap(); + let path_1 = vec![]; + let siblings_1 = vec![]; + let merkle_path_index_1 = MerklePathWithNeighborsGadget::new( + &path_1, + &siblings_1, + &node_1.node, + [Some(node_0.node), Some(node_2.node)], + ) + .unwrap(); + let row_cells_1c = to_row_cells(&node_1.rows_tree[2].values); + let row_1c = + RowProcessingGadgetInputs::new(merkle_path_1c, merkle_path_index_1, &row_cells_1c) + .unwrap(); + + let [node_2a, node_2b, node_2c, node_2d] = node_2 + .rows_tree + .iter() + .map(|n| n.node) + .collect_vec() + .try_into() + .unwrap(); + let path_2d = vec![ + (node_2b, ChildPosition::Right), + (node_2a, ChildPosition::Left), + ]; + let node_2c_hash = HashOutput::from(node_2c.compute_node_hash(secondary_index)); + let siblings_2d = vec![Some(node_2c_hash), None]; + let merkle_path_2d = + MerklePathWithNeighborsGadget::new(&path_2d,
&siblings_2d, &node_2d, [None, None]) + .unwrap(); + let path_2 = vec![(node_1.node, ChildPosition::Right)]; + let node_0_hash = HashOutput::from(node_0.node.compute_node_hash(primary_index)); + let siblings_2 = vec![Some(node_0_hash)]; + let merkle_path_index_2 = + MerklePathWithNeighborsGadget::new(&path_2, &siblings_2, &node_2.node, [None, None]) + .unwrap(); + + let row_cells_2d = to_row_cells(&node_2.rows_tree[3].values); + + let row_2d = + RowProcessingGadgetInputs::new(merkle_path_2d, merkle_path_index_2, &row_cells_2d) + .unwrap(); + + let path_2b = vec![(node_2a, ChildPosition::Left)]; + let siblings_2b = vec![None]; + let merkle_path_2b = MerklePathWithNeighborsGadget::new( + &path_2b, + &siblings_2b, + &node_2b, + [Some(node_2c), Some(node_2d)], + ) + .unwrap(); + + let row_cells_2b = to_row_cells(&node_2.rows_tree[1].values); + + let row_2b = + RowProcessingGadgetInputs::new(merkle_path_2b, merkle_path_index_2, &row_cells_2b) + .unwrap(); + + let circuit = RowChunkProcessingCircuit::< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + AggOutputCircuit, + >::new( + vec![row_1c.clone(), row_2b.clone(), row_2d.clone()], + &column_ids, + &predicate_operations, + &placeholders, + &bounds, + &results, + ) + .unwrap(); + + // compute placeholder hash for `circuit` + let placeholder_hash_ids = UniversalQueryHashInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + AggOutputCircuit, + >::ids_for_placeholder_hash( + &predicate_operations, &results, &placeholders, &bounds + ) + .unwrap(); + let placeholder_hash = + placeholder_hash(&placeholder_hash_ids, &placeholders, &bounds).unwrap(); + + let proof = run_circuit::(circuit); + // check public inputs + let pis = PublicInputsQueryCircuits::::from_slice(&proof.public_inputs); + + let root = node_1.node.compute_node_hash(primary_index); + assert_eq!(root, pis.tree_hash(),); + assert_eq!(&pis.operation_ids()[..output_ops.len()], &output_ops,); + + // closure to compute predicate value and output values for a given row with cells `row_cells`. 
+ // Also return a flag specifying whether arithmetic errors have occurred during the computation or not + + // compute predicate value and output values for each of the 3 rows + let (predicate_value_1c, err_1c, out_values_1c) = + compute_output_values_for_row::( + &row_cells_1c, + &predicate_operations, + &results, + &placeholders, + ); + let (predicate_value_2b, err_2b, out_values_2b) = + compute_output_values_for_row::( + &row_cells_2b, + &predicate_operations, + &results, + &placeholders, + ); + let (predicate_value_2d, err_2d, out_values_2d) = + compute_output_values_for_row::( + &row_cells_2d, + &predicate_operations, + &results, + &placeholders, + ); + + // aggregate out_values of the 3 rows + let (expected_outputs, expected_err) = { + let outputs = [out_values_1c, out_values_2b, out_values_2d]; + let mut num_overflows = 0; + let outputs = output_ops + .into_iter() + .enumerate() + .map(|(i, op)| { + let (out, overflows) = aggregate_output_values(i, &outputs, op); + num_overflows += overflows; + U256::from_fields(CurveOrU256::::from_slice(&out).to_u256_raw()) + }) + .collect_vec(); + (outputs, num_overflows != 0) + }; + + let computational_hash = ComputationalHash::from_bytes( + (&Identifiers::computational_hash_universal_circuit( + &column_ids, + &predicate_operations, + &results, + Some(bounds.min_query_secondary().into()), + Some(bounds.max_query_secondary().into()), + ) + .unwrap()) + .into(), + ); + + // compute expected left boundary row of the proven chunk: should correspond to row_1C + let left_boundary_row = { + // predecessor is node_1A, and it's in the path + let predecessor_info_1c = NeighborInfo::new( + node_1a.value, + Some(node_1a.compute_node_hash(secondary_index)), + ); + // successor is node_1D, and it's not in the path + let successor_info_1c = NeighborInfo::new(node_1d.value, None); + let row_1c_info = BoundaryRowNodeInfo { + end_node_hash: node_1c.compute_node_hash(secondary_index), + predecessor_info: predecessor_info_1c, + successor_info: successor_info_1c, + }; + // predecessor is node_0, and it's not in the path + let predecessor_index_1 = NeighborInfo::new(node_0.node.value, None); + // successor is node_2, and it's not in the path + let successor_index_1 = NeighborInfo::new(node_2.node.value, None); + let index_1_info = BoundaryRowNodeInfo { + end_node_hash: node_1.node.compute_node_hash(primary_index), + predecessor_info: predecessor_index_1, + successor_info: successor_index_1, + }; + BoundaryRowData { + row_node_info: row_1c_info, + index_node_info: index_1_info, + } + }; + // compute expected right boundary row of the proven chunk: should correspond to row_2D + let right_boundary_row = { + // predecessor is node_2B, and it's in the path + let predecessor_2d = NeighborInfo::new( + node_2b.value, + Some(node_2b.compute_node_hash(secondary_index)), + ); + // successor is node_2A, and it's in the path + let successor_2d = NeighborInfo::new( + node_2a.value, + Some(node_2a.compute_node_hash(secondary_index)), + ); + let row_2d_info = BoundaryRowNodeInfo { + end_node_hash: node_2d.compute_node_hash(secondary_index), + predecessor_info: predecessor_2d, + successor_info: successor_2d, + }; + + // predecessor is node 1, and it's in the path + let predecessor_index_2 = NeighborInfo::new( + node_1.node.value, + Some(node_1.node.compute_node_hash(primary_index)), + ); + // no successor + let successor_index_2 = NeighborInfo::new_dummy_successor(); + let index_2_info = BoundaryRowNodeInfo { + end_node_hash: node_2.node.compute_node_hash(primary_index), +
predecessor_info: predecessor_index_2, + successor_info: successor_index_2, + }; + + BoundaryRowData { + row_node_info: row_2d_info, + index_node_info: index_2_info, + } + }; + + assert_eq!(pis.overflow_flag(), err_1c | err_2b | err_2d | expected_err); + assert_eq!( + pis.num_matching_rows(), + F::from_canonical_u8( + predicate_value_1c as u8 + predicate_value_2b as u8 + predicate_value_2d as u8 + ), + ); + assert_eq!(pis.first_value_as_u256(), expected_outputs[0],); + assert_eq!( + expected_outputs[1..], + pis.values()[..expected_outputs.len() - 1], + ); + // check boundary rows + assert_eq!(pis.to_left_row_raw(), &left_boundary_row.to_fields(),); + assert_eq!(pis.to_right_row_raw(), &right_boundary_row.to_fields(),); + + assert_eq!(pis.min_primary(), min_query_primary,); + assert_eq!(pis.max_primary(), max_query_primary,); + assert_eq!(pis.min_secondary(), min_query_secondary,); + assert_eq!(pis.max_secondary(), max_query_secondary,); + assert_eq!(pis.computational_hash(), computational_hash,); + assert_eq!(pis.placeholder_hash(), placeholder_hash,); + + // negative test: check that we cannot add an out of range row to the proven rows. + // We try to add row 2C to the proven rows + let path_2c = vec![ + (node_2b, ChildPosition::Left), + (node_2a, ChildPosition::Left), + ]; + let node_2d_hash = HashOutput::from(node_2d.compute_node_hash(secondary_index)); + let siblings_2c = vec![Some(node_2d_hash), None]; + let merkle_path_2c = + MerklePathWithNeighborsGadget::new(&path_2c, &siblings_2c, &node_2c, [None, None]) + .unwrap(); + + let row_cells_2c = to_row_cells(&node_2.rows_tree[2].values); + + let row_2c = + RowProcessingGadgetInputs::new(merkle_path_2c, merkle_path_index_2, &row_cells_2c) + .unwrap(); + + let circuit = RowChunkProcessingCircuit::< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + AggOutputCircuit, + >::new( + vec![row_1c, row_2c, row_2b, row_2d], + &column_ids, + &predicate_operations, + &placeholders, + &bounds, + &results, + ) + .unwrap(); + + check_panic!( + || run_circuit::(circuit), + "circuit didn't fail when aggregating row with secondary index out of range" + ) + } + + #[tokio::test] + // SELECT C1*C2 < 45, (C3+C7)/C4, C7, (C5-C6)%C1, C1/C5 - $1 FROM T WHERE ((NOT C3 != 42) OR C1*C2 <= C4/C6-C7 XOR C5 < $2) AND C2 >= $3 AND C2 < 44 AND C1 >= 523 AND C1 <= 657 + async fn test_query_without_aggregation() { + const NUM_ACTUAL_COLUMNS: usize = 7; + + let rng = &mut thread_rng(); + let column_ids = random_vector::(NUM_ACTUAL_COLUMNS); + let primary_index = F::from_canonical_u64(column_ids[0]); + let secondary_index = F::from_canonical_u64(column_ids[1]); + let column_ids = ColumnIDs::new(column_ids[0], column_ids[1], column_ids[2..].to_vec()); + let min_query_primary = U256::from(523); + let max_query_primary = U256::from(657); + let min_query_secondary = U256::from(42); + let max_query_secondary = U256::from(43); + // define placeholders + let first_placeholder_id = PlaceholderId::Generic(0); + let second_placeholder_id = PlaceholderIdentifier::Generic(1); + let mut placeholders = Placeholders::new_empty(min_query_primary, max_query_primary); + [first_placeholder_id, second_placeholder_id] + .iter() + .for_each(|id| placeholders.insert(*id, gen_random_u256(rng))); + // 3rd placeholder is the min query bound + let third_placeholder_id = PlaceholderId::Generic(2); + placeholders.insert(third_placeholder_id, min_query_secondary); + let query_bounds = QueryBounds::new( +
&placeholders, + Some(QueryBoundSource::Placeholder(third_placeholder_id)), + Some(QueryBoundSource::Constant(max_query_secondary)), + ) + .unwrap(); + + // build predicate operations + let mut predicate_operations = vec![]; + // C3 != 42 + let c5_comparison = BasicOperation { + first_operand: InputOperand::Column(2), + second_operand: Some(InputOperand::Constant(U256::from(42))), + op: Operation::NeOp, + }; + predicate_operations.push(c5_comparison); + // C1*C2 + let column_prod = BasicOperation { + first_operand: InputOperand::Column(0), + second_operand: Some(InputOperand::Column(1)), + op: Operation::MulOp, + }; + predicate_operations.push(column_prod); + // C4/C6 + let column_div = BasicOperation { + first_operand: InputOperand::Column(3), + second_operand: Some(InputOperand::Column(5)), + op: Operation::DivOp, + }; + predicate_operations.push(column_div); + // C4/C6 - C7 + let expr_add = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &column_div) + .unwrap(), + ), + second_operand: Some(InputOperand::Column(6)), + op: Operation::SubOp, + }; + predicate_operations.push(expr_add); + // C1*C2 <= C4/C6 - C7 + let expr_comparison = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &column_prod) + .unwrap(), + ), + second_operand: Some(InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &expr_add) + .unwrap(), + )), + op: Operation::LessThanOrEqOp, + }; + predicate_operations.push(expr_comparison); + // C5 < $2 + let placeholder_cmp = BasicOperation { + first_operand: InputOperand::Column(4), + second_operand: Some(InputOperand::Placeholder(second_placeholder_id)), + op: Operation::LessThanOp, + }; + predicate_operations.push(placeholder_cmp); + // NOT c5_comparison + let not_c5 = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &c5_comparison) + .unwrap(), + ), + second_operand: None, + op: Operation::NotOp, + }; + predicate_operations.push(not_c5); + // NOT c5_comparison OR expr_comparison + let or_comparisons = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, ¬_c5).unwrap(), + ), + second_operand: Some(InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &expr_comparison) + .unwrap(), + )), + op: Operation::OrOp, + }; + predicate_operations.push(or_comparisons); + // final filtering predicate: or_comparisons XOR placeholder_cmp + let predicate = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &or_comparisons) + .unwrap(), + ), + second_operand: Some(InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&predicate_operations, &placeholder_cmp) + .unwrap(), + )), + op: Operation::XorOp, + }; + predicate_operations.push(predicate); + + // result computations operations + let mut result_operations = vec![]; + // C1*C2 + let column_prod = BasicOperation { + first_operand: InputOperand::Column(0), + second_operand: Some(InputOperand::Column(1)), + op: Operation::MulOp, + }; + result_operations.push(column_prod); + // C1*C2 < 45 + let column_cmp = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&result_operations, &column_prod) + 
.unwrap(), + ), + second_operand: Some(InputOperand::Constant(U256::from(45))), + op: Operation::LessThanOp, + }; + result_operations.push(column_cmp); + // C3+C7 + let column_add = BasicOperation { + first_operand: InputOperand::Column(2), + second_operand: Some(InputOperand::Column(6)), + op: Operation::AddOp, + }; + result_operations.push(column_add); + // (C3+C7)/C4 + let expr_div = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&result_operations, &column_add).unwrap(), + ), + second_operand: Some(InputOperand::Column(3)), + op: Operation::DivOp, + }; + result_operations.push(expr_div); + // C5 - C6 + let column_sub = BasicOperation { + first_operand: InputOperand::Column(4), + second_operand: Some(InputOperand::Column(5)), + op: Operation::SubOp, + }; + result_operations.push(column_sub); + // (C5 - C6) % C1 + let column_mod = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&result_operations, &column_sub).unwrap(), + ), + second_operand: Some(InputOperand::Column(0)), + op: Operation::ModOp, + }; + result_operations.push(column_mod); + // C1/C5 + let column_div = BasicOperation { + first_operand: InputOperand::Column(0), + second_operand: Some(InputOperand::Column(4)), + op: Operation::DivOp, + }; + result_operations.push(column_div); + // C1/C5 - $1 + let sub_placeholder = BasicOperation { + first_operand: InputOperand::PreviousValue( + BasicOperation::locate_previous_operation(&result_operations, &column_div).unwrap(), + ), + second_operand: Some(InputOperand::Placeholder(first_placeholder_id)), + op: Operation::SubOp, + }; + result_operations.push(sub_placeholder); + + let output_items = vec![ + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &column_cmp).unwrap(), + ), + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &expr_div).unwrap(), + ), + OutputItem::Column(6), + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &column_mod).unwrap(), + ), + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &sub_placeholder) + .unwrap(), + ), + ]; + let output_ids = vec![F::rand(); output_items.len()]; + let results = ResultStructure::new_for_query_no_aggregation( + result_operations, + output_items, + output_ids + .iter() + .map(|id| id.to_canonical_u64()) + .collect_vec(), + false, + ) + .unwrap(); + + let [node_0, node_1, node_2] = build_test_tree(&query_bounds, &column_ids.to_vec()).await; + + let to_row_cells = |values: &[U256]| { + let column_cells = values + .iter() + .zip(column_ids.to_vec().iter()) + .map(|(&value, &id)| ColumnCell { value, id }) + .collect_vec(); + RowCells::new( + column_cells[0].clone(), + column_cells[1].clone(), + column_cells[2..].to_vec(), + ) + }; + + // run circuit over 4 consecutive rows: row 1A, row 1C, row 2B and row 2D + let [node_1a, node_1b, node_1c, node_1d] = node_1 + .rows_tree + .iter() + .map(|n| n.node) + .collect_vec() + .try_into() + .unwrap(); + let path_1a = vec![]; + let siblings_1a = vec![]; + let merkle_path_1a = MerklePathWithNeighborsGadget::new( + &path_1a, + &siblings_1a, + &node_1a, + [Some(node_1b), Some(node_1c)], + ) + .unwrap(); + let path_1 = vec![]; + let siblings_1 = vec![]; + let merkle_path_index_1 = MerklePathWithNeighborsGadget::new( + &path_1, + &siblings_1, + &node_1.node, + [Some(node_0.node), Some(node_2.node)], + ) + .unwrap(); + + 
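// row 1A is the root of node_1's rows tree (empty path; its children are nodes 1B and 1C)
+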
let row_cells_1a = to_row_cells(&node_1.rows_tree[0].values); + let row_1a = + RowProcessingGadgetInputs::new(merkle_path_1a, merkle_path_index_1, &row_cells_1a) + .unwrap(); + + let path_1c = vec![(node_1a, ChildPosition::Right)]; + let node_1b_hash = HashOutput::from(node_1b.compute_node_hash(secondary_index)); + let siblings_1c = vec![Some(node_1b_hash)]; + let merkle_path_1c = MerklePathWithNeighborsGadget::new( + &path_1c, + &siblings_1c, + &node_1c, + [None, Some(node_1d)], + ) + .unwrap(); + + let row_cells_1c = to_row_cells(&node_1.rows_tree[2].values); + let row_1c = + RowProcessingGadgetInputs::new(merkle_path_1c, merkle_path_index_1, &row_cells_1c) + .unwrap(); + + let [node_2a, node_2b, node_2c, node_2d] = node_2 + .rows_tree + .iter() + .map(|n| n.node) + .collect_vec() + .try_into() + .unwrap(); + let path_2d = vec![ + (node_2b, ChildPosition::Right), + (node_2a, ChildPosition::Left), + ]; + let node_2c_hash = HashOutput::from(node_2c.compute_node_hash(secondary_index)); + let siblings_2d = vec![Some(node_2c_hash), None]; + let merkle_path_2d = + MerklePathWithNeighborsGadget::new(&path_2d, &siblings_2d, &node_2d, [None, None]) + .unwrap(); + let path_2 = vec![(node_1.node, ChildPosition::Right)]; + let node_0_hash = HashOutput::from(node_0.node.compute_node_hash(primary_index)); + let siblings_2 = vec![Some(node_0_hash)]; + let merkle_path_index_2 = + MerklePathWithNeighborsGadget::new(&path_2, &siblings_2, &node_2.node, [None, None]) + .unwrap(); + + let row_cells_2d = to_row_cells(&node_2.rows_tree[3].values); + + let row_2d = + RowProcessingGadgetInputs::new(merkle_path_2d, merkle_path_index_2, &row_cells_2d) + .unwrap(); + + let path_2b = vec![(node_2a, ChildPosition::Left)]; + let siblings_2b = vec![None]; + let merkle_path_2b = MerklePathWithNeighborsGadget::new( + &path_2b, + &siblings_2b, + &node_2b, + [Some(node_2c), Some(node_2d)], + ) + .unwrap(); + + let row_cells_2b = to_row_cells(&node_2.rows_tree[1].values); + + let row_2b = + RowProcessingGadgetInputs::new(merkle_path_2b, merkle_path_index_2, &row_cells_2b) + .unwrap(); + + let circuit = RowChunkProcessingCircuit::< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + NoAggOutputCircuit, + >::new( + vec![ + row_1a.clone(), + row_1c.clone(), + row_2b.clone(), + row_2d.clone(), + ], + &column_ids, + &predicate_operations, + &placeholders, + &query_bounds, + &results, + ) + .unwrap(); + + // compute placeholder hash for `circuit` + let placeholder_hash_ids = UniversalQueryHashInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + NoAggOutputCircuit, + >::ids_for_placeholder_hash( + &predicate_operations, + &results, + &placeholders, + &query_bounds, + ) + .unwrap(); + let placeholder_hash = + placeholder_hash(&placeholder_hash_ids, &placeholders, &query_bounds).unwrap(); + + let proof = run_circuit::(circuit); + // check public inputs + let pis = PublicInputsQueryCircuits::::from_slice(&proof.public_inputs); + + let root = node_1.node.compute_node_hash(primary_index); + assert_eq!(root, pis.tree_hash(),); + assert_eq!( + >::to_field(&AggregationOperation::IdOp), + pis.operation_ids()[0] + ); + // aggregation operations in the other MAX_NUM_RESULTS - 1 slots are dummy ones, as in queries + // without aggregation we accumulate all the results in the first output value, + // and so we don't care about the other ones + assert_eq!( + [>::to_field(&AggregationOperation::default()); +
MAX_NUM_RESULTS - 1], + pis.operation_ids()[1..] + ); + + // closure to compute predicate value and accumulator of output values for a given row with cells `row_cells`. + // Also return a flag specifying whether arithmetic errors have occurred during the computation or not + let compute_output_values = async |row_cells: &RowCells| { + let column_values = row_cells + .to_cells() + .into_iter() + .map(|cell| cell.value) + .collect_vec(); + let (res, predicate_err) = BasicOperation::compute_operations( + &predicate_operations, + &column_values, + &placeholders, + ) + .unwrap(); + let predicate_value = res.last().unwrap().try_into_bool().unwrap(); + + let (res, result_err) = results + .compute_output_values(&column_values, &placeholders) + .unwrap(); + let out_cells = res + .iter() + .zip(output_ids.iter()) + .map(|(value, id)| TestCell::new(*value, *id)) + .collect_vec(); + let output_acc = if predicate_value { + // if predicate value is satisfied, then we expose the accumulator of all the output values + // to be returned for the current row + map_to_curve_point( + &once(out_cells[0].id) + .chain(out_cells[0].value.to_fields()) + .chain(once( + out_cells.get(1).map(|cell| cell.id).unwrap_or_default(), + )) + .chain( + out_cells + .get(1) + .map(|cell| cell.value) + .unwrap_or_default() + .to_fields(), + ) + .chain( + compute_cells_tree_hash( + out_cells.get(2..).unwrap_or_default().to_vec(), + ) + .await + .to_vec(), + ) + .collect_vec(), + ) + } else { + // otherwise, we expose the neutral point to ensure that the results for + // the current record are not included in the accumulator of all the results + // of the query + Point::NEUTRAL + }; + (predicate_value, predicate_err | result_err, output_acc) + }; + + // compute predicate value and accumulator of output values for each of the 4 rows being proven + let (predicate_value_1a, err_1a, acc_1a) = compute_output_values(&row_cells_1a).await; + let (predicate_value_1c, err_1c, acc_1c) = compute_output_values(&row_cells_1c).await; + let (predicate_value_2b, err_2b, acc_2b) = compute_output_values(&row_cells_2b).await; + let (predicate_value_2d, err_2d, acc_2d) = compute_output_values(&row_cells_2d).await; + + let computational_hash = ComputationalHash::from_bytes( + (&Identifiers::computational_hash_universal_circuit( + &column_ids, + &predicate_operations, + &results, + Some(query_bounds.min_query_secondary().into()), + Some(query_bounds.max_query_secondary().into()), + ) + .unwrap()) + .into(), + ); + // compute expected left boundary row of the proven chunk: should correspond to row_1A + let left_boundary_row = { + // predecessor is node_1B, and it's not in the path + let predecessor_info_1a = NeighborInfo::new(node_1b.value, None); + // successor is node_1C, and it's not in the path + let successor_info_1a = NeighborInfo::new(node_1c.value, None); + let row_1a_info = BoundaryRowNodeInfo { + end_node_hash: node_1a.compute_node_hash(secondary_index), + predecessor_info: predecessor_info_1a, + successor_info: successor_info_1a, + }; + // predecessor is node_0, and it's not in the path + let predecessor_index_1 = NeighborInfo::new(node_0.node.value, None); + // successor is node_2, and it's not in the path + let successor_index_1 = NeighborInfo::new(node_2.node.value, None); + let index_1_info = BoundaryRowNodeInfo { + end_node_hash: node_1.node.compute_node_hash(primary_index), + predecessor_info: predecessor_index_1, + successor_info: successor_index_1, + }; + BoundaryRowData { + row_node_info: row_1a_info, + index_node_info: index_1_info,
+ } + }; + // compute expected right boundary row of the proven chunk: should correspond to row_2D + let right_boundary_row = { + // predecessor is node_2B, and it's in the path + let predecessor_2d = NeighborInfo::new( + node_2b.value, + Some(node_2b.compute_node_hash(secondary_index)), + ); + // successor is node_2A, and it's in the path + let successor_2d = NeighborInfo::new( + node_2a.value, + Some(node_2a.compute_node_hash(secondary_index)), + ); + let row_2d_info = BoundaryRowNodeInfo { + end_node_hash: node_2d.compute_node_hash(secondary_index), + predecessor_info: predecessor_2d, + successor_info: successor_2d, + }; + + // predecessor is node 1, and it's in the path + let predecessor_index_2 = NeighborInfo::new( + node_1.node.value, + Some(node_1.node.compute_node_hash(primary_index)), + ); + // no successor + let successor_index_2 = NeighborInfo::new_dummy_successor(); + let index_2_info = BoundaryRowNodeInfo { + end_node_hash: node_2.node.compute_node_hash(primary_index), + predecessor_info: predecessor_index_2, + successor_info: successor_index_2, + }; + + BoundaryRowData { + row_node_info: row_2d_info, + index_node_info: index_2_info, + } + }; + + assert_eq!(pis.overflow_flag(), err_1a | err_1c | err_2b | err_2d,); + assert_eq!( + pis.num_matching_rows(), + F::from_canonical_u8( + predicate_value_1a as u8 + + predicate_value_1c as u8 + + predicate_value_2b as u8 + + predicate_value_2d as u8 + ), + ); + assert_eq!( + pis.first_value_as_curve_point(), + (acc_1a + acc_1c + acc_2b + acc_2d).to_weierstrass(), + ); + // The other MAX_NUM_RESULTS -1 output values are dummy ones, as in queries + // without aggregation we accumulate all the results in the first output value, + // and so we don't care about the other ones + assert_eq!(array::from_fn(|_| U256::ZERO), pis.values()); + // check boundary rows + assert_eq!(pis.to_left_row_raw(), &left_boundary_row.to_fields(),); + assert_eq!(pis.to_right_row_raw(), &right_boundary_row.to_fields(),); + + assert_eq!(pis.min_primary(), min_query_primary,); + assert_eq!(pis.max_primary(), max_query_primary,); + assert_eq!(pis.min_secondary(), min_query_secondary,); + assert_eq!(pis.max_secondary(), max_query_secondary,); + assert_eq!(pis.computational_hash(), computational_hash,); + assert_eq!(pis.placeholder_hash(), placeholder_hash,); + + // negative test: check that we cannot add nodes in the index tree outside of the range. 
We try to add + // row 0B to the set of proven rows + let [node_0a, node_0b, node_0c] = node_0 + .rows_tree + .iter() + .map(|n| n.node) + .collect_vec() + .try_into() + .unwrap(); + let path_0b = vec![(node_0a, ChildPosition::Left)]; + let siblings_0b = vec![None]; + let merkle_path_0b = MerklePathWithNeighborsGadget::new( + &path_0b, + &siblings_0b, + &node_0b, + [None, Some(node_0c)], + ) + .unwrap(); + let path_0 = vec![(node_1.node, ChildPosition::Left)]; + let node_2_hash = HashOutput::from(node_2.node.compute_node_hash(primary_index)); + let siblings_0 = vec![Some(node_2_hash)]; + let merkle_path_index_2 = + MerklePathWithNeighborsGadget::new(&path_0, &siblings_0, &node_0.node, [None, None]) + .unwrap(); + + let row_cells_0b = to_row_cells(&node_0.rows_tree[1].values); + let row_0b = + RowProcessingGadgetInputs::new(merkle_path_0b, merkle_path_index_2, &row_cells_0b) + .unwrap(); + + let circuit = RowChunkProcessingCircuit::< + NUM_ROWS, + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + NoAggOutputCircuit, + >::new( + vec![row_0b, row_1a, row_1c, row_2b, row_2d], + &column_ids, + &predicate_operations, + &placeholders, + &query_bounds, + &results, + ) + .unwrap(); + + check_panic!( + || run_circuit::(circuit), + "circuit didn't fail when aggregating row with primary index out of range" + ) + } +} diff --git a/verifiable-db/src/query/computational_hash_ids.rs b/verifiable-db/src/query/computational_hash_ids.rs index 42c135a2c..73a1e1be1 100644 --- a/verifiable-db/src/query/computational_hash_ids.rs +++ b/verifiable-db/src/query/computational_hash_ids.rs @@ -13,7 +13,7 @@ use mp2_common::{ poseidon::{empty_poseidon_hash, H}, types::{CBuilder, HashOutput}, u256::UInt256Target, - utils::{Fieldable, FromFields, SelectHashBuilder, ToFields, ToTargets}, + utils::{Fieldable, FromFields, HashBuilder, ToFields, ToTargets}, CHasher, F, }; use plonky2::{ @@ -31,14 +31,14 @@ use serde::{Deserialize, Serialize}; use crate::revelation::placeholders_check::placeholder_ids_hash; use super::{ - aggregation::QueryBoundSource, universal_circuit::{ universal_circuit_inputs::{ BasicOperation, InputOperand, OutputItem, PlaceholderIdsSet, ResultStructure, }, - universal_query_circuit::QueryBound, + universal_query_gadget::QueryBound, ComputationalHash, ComputationalHashTarget, }, + utils::QueryBoundSource, }; pub enum Identifiers { @@ -234,7 +234,7 @@ impl ToField for Identifiers { } } /// Data structure to provide identifiers of columns of a table to compute computational hash -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct ColumnIDs { pub(crate) primary: F, pub(crate) secondary: F, @@ -250,12 +250,31 @@ impl ColumnIDs { } } + pub fn primary_column(&self) -> u64 { + self.primary.to_canonical_u64() + } + + pub fn secondary_column(&self) -> u64 { + self.secondary.to_canonical_u64() + } + + pub fn non_indexed_columns(&self) -> Vec { + self.rest + .iter() + .map(|id| id.to_canonical_u64()) + .collect_vec() + } + pub(crate) fn to_vec(&self) -> Vec { [self.primary, self.secondary] .into_iter() .chain(self.rest.clone()) .collect_vec() } + + pub(crate) fn num_columns(&self) -> usize { + self.rest.len() + 2 + } } #[derive(Clone, Debug, Copy, Default)] diff --git a/verifiable-db/src/query/merkle_path.rs b/verifiable-db/src/query/merkle_path.rs index c442aa51e..571cc8bb2 100644 --- a/verifiable-db/src/query/merkle_path.rs +++ b/verifiable-db/src/query/merkle_path.rs 
@@ -9,24 +9,28 @@ use mp2_common::{ hash::hash_maybe_first, poseidon::empty_poseidon_hash, serialization::{ - deserialize_array, deserialize_long_array, serialize_array, serialize_long_array, + circuit_data_serialization::SerializableRichField, deserialize, deserialize_array, + deserialize_long_array, serialize, serialize_array, serialize_long_array, }, - types::HashOutput, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{SelectHashBuilder, ToTargets}, + types::{CBuilder, HashOutput}, + u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256, NUM_LIMBS}, + utils::{FromFields, FromTargets, HashBuilder, SelectTarget, ToFields, ToTargets, TryIntoBool}, D, F, }; +use mp2_test::utils::gen_random_field_hash; use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget}, + field::types::Field, + hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, plonk::{circuit_builder::CircuitBuilder, config::GenericHashOut}, }; +use rand::Rng; use serde::{Deserialize, Serialize}; -use super::aggregation::{ChildPosition, NodeInfo}; +use super::utils::{ChildPosition, NodeInfo, NodeInfoTarget}; #[derive(Clone, Debug, Serialize, Deserialize)] /// Input wires for Merkle path verification gadget @@ -73,6 +77,46 @@ where )] is_real_node: [BoolTarget; MAX_DEPTH - 1], } +#[derive(Clone, Debug, Serialize, Deserialize)] +/// Input wires related to the data of the end node whose membership in the tree +/// is proven with `MerklePathWithNeighborsGadget`. +pub struct EndNodeInputs { + // minimum of the end node. It is necessary to recompute the hash of the node + // inside the circuit + node_min: UInt256Target, + // maximum of the end node. It is necessary to recompute the hash of the node + // inside the circuit + node_max: UInt256Target, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + // Flag specifying whether the end node has a left child + left_child_exists: BoolTarget, + // The data about the left child of the node, which might be necessary to + // extract the value of the predecessor of the end node + left_child_info: NodeInfoTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + // Flag specifying whether the end node has a right child + right_child_exists: BoolTarget, + // The data about the right child of the node, which might be necessary to + // extract the value of the successor of the end node + right_child_info: NodeInfoTarget, +} + +impl EndNodeInputs { + pub(crate) fn build(b: &mut CBuilder) -> Self { + let [node_min, node_max] = b.add_virtual_u256_arr_unsafe(); + let [left_child_exists, right_child_exists] = + array::from_fn(|_| b.add_virtual_bool_target_safe()); + + Self { + node_min, + node_max, + left_child_exists, + left_child_info: NodeInfoTarget::build_unsafe(b), + right_child_exists, + right_child_info: NodeInfoTarget::build_unsafe(b), + } + } +} #[derive(Clone, Debug)] /// Set of input/output wires built by merkle path verification gadget @@ -84,6 +128,115 @@ where /// Recomputed root for the Merkle path pub(crate) root: HashOutTarget, } +#[derive(Clone, Debug)] +/// Target containing data about a neighbor of a node (neighbor can be +/// either the predecessor or the successor of a node) +pub struct NeighborInfoTarget { + /// Boolean flag specifying whether the node has the given neighbor + pub(crate) is_found: BoolTarget, + /// Boolean flag specifying whether the neighbor is in the path from the + /// given node up to the 
root + pub(crate) is_in_path: BoolTarget, + /// Value of the neighbor (if the neighbor exists, otherwise a dummy value can be employed) + pub(crate) value: UInt256Target, + /// Hash of the neighbor node (if the neighbor exists, otherwise a dummy value can be employed) + pub(crate) hash: HashOutTarget, +} + +impl NeighborInfoTarget { + pub(crate) fn new_dummy_predecessor(b: &mut CircuitBuilder) -> Self { + Self { + is_found: b._false(), + is_in_path: b._true(), // the circuit still looks at the predecessor in the path + value: b.zero_u256(), + hash: b.constant_hash(*empty_poseidon_hash()), + } + } + + pub(crate) fn new_dummy_successor(b: &mut CircuitBuilder) -> Self { + Self { + is_found: b._false(), + is_in_path: b._true(), // the circuit still looks at the successor in the path + value: b.constant_u256(U256::MAX), + hash: b.constant_hash(*empty_poseidon_hash()), + } + } +} + +impl ToTargets for NeighborInfoTarget { + fn to_targets(&self) -> Vec { + once(self.is_found.target) + .chain(once(self.is_in_path.target)) + .chain(self.value.to_targets()) + .chain(self.hash.to_targets()) + .collect() + } +} + +impl FromTargets for NeighborInfoTarget { + const NUM_TARGETS: usize = 2 + NUM_LIMBS + NUM_HASH_OUT_ELTS; + + fn from_targets(t: &[Target]) -> Self { + Self { + is_found: BoolTarget::new_unsafe(t[0]), + is_in_path: BoolTarget::new_unsafe(t[1]), + value: UInt256Target::from_targets(&t[2..]), + hash: HashOutTarget::from_targets(&t[2 + NUM_LIMBS..]), + } + } +} + +impl SelectTarget for NeighborInfoTarget { + fn select, const D: usize>( + b: &mut CircuitBuilder, + cond: &BoolTarget, + first: &Self, + second: &Self, + ) -> Self { + Self { + is_found: BoolTarget::new_unsafe(b.select( + *cond, + first.is_found.target, + second.is_found.target, + )), + is_in_path: BoolTarget::new_unsafe(b.select( + *cond, + first.is_in_path.target, + second.is_in_path.target, + )), + value: b.select_u256(*cond, &first.value, &second.value), + hash: b.select_hash(*cond, &first.hash, &second.hash), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +/// Set of input wires for the merkle path with neighbors gadget +pub struct MerklePathWithNeighborsTargetInputs +where + [(); MAX_DEPTH - 1]:, +{ + pub(crate) path_inputs: MerklePathTargetInputs, + pub(crate) end_node_inputs: EndNodeInputs, +} + +#[derive(Clone, Debug)] +#[allow(dead_code)] +/// Set of input/output wires built by merkle path with neighbors gadget +pub struct MerklePathWithNeighborsTarget +where + [(); MAX_DEPTH - 1]:, +{ + pub(crate) inputs: MerklePathWithNeighborsTargetInputs, + /// Recomputed root for the Merkle path + pub(crate) root: HashOutTarget, + /// Hash of the node at the end of the path + pub(crate) end_node_hash: HashOutTarget, + /// Info about the predecessor of the node at the end of the path + pub(crate) predecessor_info: NeighborInfoTarget, + /// Info about the successor of the node at the end of the path + pub(crate) successor_info: NeighborInfoTarget, +} #[derive(Clone, Copy, Debug, Serialize, Deserialize)] pub struct MerklePathGadget where [(); MAX_DEPTH - 1]:, { @@ -185,7 +338,7 @@ where siblings .get(i) .and_then(|sibling| { - sibling.and_then(|node_hash| Some(HashOut::from_bytes(node_hash.as_ref()))) + sibling.map(|node_hash| HashOut::from_bytes((&node_hash).into())) }) .unwrap_or(*empty_poseidon_hash()) }); @@ -201,15 +354,33 @@ }) } - /// Build wires for `MerklePathGadget`. The requrested inputs are: - /// - `start_node`: The hash of the first node in the path + /// Build wires for `MerklePathGadget`.
The required inputs are: + /// - `end_node`: The hash of the first node in the path /// - `index_id`: Integer identifier of the index column to be placed in the hash /// of the nodes of the path pub fn build( b: &mut CircuitBuilder, - start_node: HashOutTarget, + end_node: HashOutTarget, index_id: Target, ) -> MerklePathTarget { + let (inputs, path) = Self::build_path(b, end_node, index_id); + + MerklePathTarget { + inputs, + root: *path.last().unwrap(), + } + } + + /// Gadget to compute the hashes of all the nodes in the path from `end_node` to the root of + /// a Merkle-tree + fn build_path( + b: &mut CircuitBuilder, + end_node: HashOutTarget, + index_id: Target, + ) -> ( + MerklePathTargetInputs, + [HashOutTarget; MAX_DEPTH - 1], + ) { let is_left_child = array::from_fn(|_| b.add_virtual_bool_target_unsafe()); let [sibling_hash, embedded_tree_hash] = [0, 1].map(|_| array::from_fn(|_| b.add_virtual_hash())); @@ -217,8 +388,8 @@ where |_| b.add_virtual_u256_arr_unsafe(), // unsafe should be ok since we just need to hash them ); let is_real_node = array::from_fn(|_| b.add_virtual_bool_target_safe()); - - let mut final_hash = start_node; + let mut final_hash = end_node; + let mut path_nodes = vec![]; for i in 0..MAX_DEPTH - 1 { let rest = node_min[i] .to_targets() @@ -236,20 +407,25 @@ rest.as_slice(), )); final_hash = b.select_hash(is_real_node[i], &node_hash, &final_hash); + path_nodes.push(final_hash); } - MerklePathTarget { - inputs: MerklePathTargetInputs { - is_left_child, - sibling_hash, - node_min, - node_max, - node_value, - embedded_tree_hash, - is_real_node, - }, - root: final_hash, + let inputs = MerklePathTargetInputs { + is_left_child, + sibling_hash, + node_min, + node_max, + node_value, + embedded_tree_hash, + is_real_node, + }; + + // ensure there is always one node in the path even if `MAX_DEPTH=1` + if path_nodes.is_empty() { + path_nodes.push(end_node); } + + (inputs, path_nodes.try_into().unwrap()) } pub fn assign(&self, pw: &mut PartialWitness, wires: &MerklePathTargetInputs) { @@ -288,29 +464,328 @@ } } +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct MerklePathWithNeighborsGadget +where + [(); MAX_DEPTH - 1]:, +{ + path_gadget: MerklePathGadget, + // minimum value of the node whose membership in the tree is + // being proven with this gadget (referred to as `end_node`). + // It is necessary to recompute the hash of the end node + end_node_min: U256, + // maximum value of the end node whose membership in the tree is + // being proven with this gadget (referred to as `end_node`). + // It is necessary to recompute the hash of the end node + end_node_max: U256, + // Data about the children of the end node whose membership in the + // tree is being proven with this gadget (referred to as `end_node`). + // Children data might be necessary to compute the value of the + // predecessor/successor of the end node + end_node_children: [Option; 2], +} + +impl MerklePathWithNeighborsGadget +where + [(); MAX_DEPTH - 1]:, +{ + /// Build a new instance of `Self`, representing the path from `end_node` to the root. + /// Such path is provided as input, together with the siblings of the nodes in such + /// path, if any. The method also requires the data about the children of `end_node`, + /// if any.
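+ /// `path` is expected to start from the parent of `end_node` and proceed up to the root, each entry pairing an ancestor with the `ChildPosition` of the preceding node as its child (see the test paths below, e.g. `path_2d`).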
+ pub fn new( + path: &[(NodeInfo, ChildPosition)], + siblings: &[Option], + end_node: &NodeInfo, + end_node_children: [Option; 2], + ) -> Result { + let path_gadget = MerklePathGadget::new(path, siblings)?; + Ok(Self { + path_gadget, + end_node_min: end_node.min, + end_node_max: end_node.max, + end_node_children, + }) + } + + /// Build wires for `MerklePathWithNeighborsGadget`. The required inputs are: + /// - `end_node_value`: Value stored in the first node in the path + /// - `end_node_tree_hash`: Hash of the embedded tree stored in the first node in the path + /// - `index_id`: Integer identifier of the index column to be placed in the hash + /// of the nodes of the path + pub fn build( + b: &mut CircuitBuilder, + end_node_value: UInt256Target, + end_node_tree_hash: HashOutTarget, + index_id: Target, + ) -> MerklePathWithNeighborsTarget { + let end_node_info = EndNodeInputs::build(b); + // compute end node hash + let left_child_hash = end_node_info.left_child_info.compute_node_hash(b, index_id); + let right_child_hash = end_node_info + .right_child_info + .compute_node_hash(b, index_id); + let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let left_child_hash = b.select_hash( + end_node_info.left_child_exists, + &left_child_hash, + &empty_hash, + ); + let right_child_hash = b.select_hash( + end_node_info.right_child_exists, + &right_child_hash, + &empty_hash, + ); + let end_node = NodeInfoTarget { + embedded_tree_hash: end_node_tree_hash, + child_hashes: [left_child_hash, right_child_hash], + value: end_node_value, + min: end_node_info.node_min.clone(), + max: end_node_info.node_max.clone(), + }; + let end_node_hash = end_node.compute_node_hash(b, index_id); + let (inputs, path) = MerklePathGadget::build_path(b, end_node_hash, index_id); + // we need to initialize predecessor and successor data + let (mut predecessor_info, mut successor_info) = { + // the predecessor of end_node is an ancestor of end_node iff end_node has no left child + let is_predecessor_in_path = b.not(end_node_info.left_child_exists); + let zero_u256 = b.zero_u256(); + let max_u256 = b.constant_u256(U256::MAX); + // Initialize value of predecessor node of end_node to a dummy value if the predecessor node + // will be found in the path; otherwise, the predecessor_value is the maximum value in + // the subtree rooted in the left child of end_node + let predecessor_value = b.select_u256( + is_predecessor_in_path, + &zero_u256, + &end_node_info.left_child_info.max, + ); + // the predecessor value is already found if end_node has a left child + let predecessor_found = end_node_info.left_child_exists; + // Initialize predecessor node hash to a dummy value + let predecessor_hash = b.constant_hash(*empty_poseidon_hash()); + // build predecessor info + let predecessor_info = NeighborInfoTarget { + is_found: predecessor_found, + is_in_path: is_predecessor_in_path, + value: predecessor_value, + hash: predecessor_hash, + }; + + // the successor of end_node is an ancestor of end_node iff end_node has no right child + let is_successor_in_path = b.not(end_node_info.right_child_exists); + // Initialize value of successor node of end_node to a dummy value if the successor node + // will be found in the path; otherwise, successor_value is the minimum value in + // the subtree rooted in the right child of end_node + let successor_value = b.select_u256( + is_successor_in_path, + &max_u256, // set dummy value of successor to `U256::MAX`, as it allows satisfying the constraints of + // `are_consecutive_nodes` gadget in case the node has no
successor in the tree + &end_node_info.right_child_info.min, + ); + // the successor value is already found if end_node has a right child + let successor_found = end_node_info.right_child_exists; + // Initialize successor node hash to a dummy value + let successor_hash = b.constant_hash(*empty_poseidon_hash()); + // build successor info + let successor_info = NeighborInfoTarget { + is_found: successor_found, + is_in_path: is_successor_in_path, + value: successor_value, + hash: successor_hash, + }; + (predecessor_info, successor_info) + }; + + #[allow(clippy::needless_range_loop)] + for i in 0..MAX_DEPTH - 1 { + // we need to look for the predecessor + let is_right_child = b.not(inputs.is_left_child[i]); + /* First, we determine if the current node is the predecessor */ + let mut is_current_node_predecessor = b.not(predecessor_info.is_found); // current node cannot + // be the predecessor if predecessor has already been found + is_current_node_predecessor = + b.and(is_current_node_predecessor, inputs.is_real_node[i]); // current node + // cannot be the predecessor if it's not a real node + is_current_node_predecessor = b.and(is_current_node_predecessor, is_right_child); // current node + // is the predecessor if the previous node in the path is its right child + // we update predecessor_info.hash if current node is the predecessor + predecessor_info.hash = b.select_hash( + is_current_node_predecessor, + &path[i], + &predecessor_info.hash, + ); + // we update predecessor_info.value if current node is the predecessor + predecessor_info.value = b.select_u256( + is_current_node_predecessor, + &inputs.node_value[i], + &predecessor_info.value, + ); + // set predecessor_info.is_found if current node is the predecessor + predecessor_info.is_found = + b.or(predecessor_info.is_found, is_current_node_predecessor); + + // we need to look for the successor + /* First, we determine if the current node is the successor */ + let mut is_current_node_successor = b.not(successor_info.is_found); // current node cannot + // be the successor if successor has already been found + is_current_node_successor = b.and(is_current_node_successor, inputs.is_real_node[i]); // current node + // cannot be the successor if it's not a real node + is_current_node_successor = b.and(is_current_node_successor, inputs.is_left_child[i]); // current node + // is the successor if the previous node in the path is its left child + // we update successor_info.hash if current node is the successor + successor_info.hash = + b.select_hash(is_current_node_successor, &path[i], &successor_info.hash); + // we update successor_info.value if current node is the successor + successor_info.value = b.select_u256( + is_current_node_successor, + &inputs.node_value[i], + &successor_info.value, + ); + // set successor_info.is_found if current node is the successor + successor_info.is_found = b.or(successor_info.is_found, is_current_node_successor); + } + + MerklePathWithNeighborsTarget { + inputs: MerklePathWithNeighborsTargetInputs { + path_inputs: inputs, + end_node_inputs: end_node_info, + }, + root: *path.last().unwrap(), + end_node_hash, + predecessor_info, + successor_info, + } + } + + pub fn assign( + &self, + pw: &mut PartialWitness, + wires: &MerklePathWithNeighborsTargetInputs, + ) { + self.path_gadget.assign(pw, &wires.path_inputs); + pw.set_u256_target_arr( + &[ + wires.end_node_inputs.node_min.clone(), + wires.end_node_inputs.node_max.clone(), + ], + &[self.end_node_min, self.end_node_max], + ); + pw.set_bool_target( + 
wires.end_node_inputs.left_child_exists, + self.end_node_children[0].is_some(), + ); + pw.set_bool_target( + wires.end_node_inputs.right_child_exists, + self.end_node_children[1].is_some(), + ); + let left_child_info = self.end_node_children[0].unwrap_or_default(); + let right_child_info = self.end_node_children[1].unwrap_or_default(); + wires + .end_node_inputs + .left_child_info + .set_target(pw, &left_child_info); + wires + .end_node_inputs + .right_child_info + .set_target(pw, &right_child_info); + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct NeighborInfo { + pub(crate) is_found: bool, + pub(crate) is_in_path: bool, + pub(crate) value: U256, + pub(crate) hash: HashOut, +} + +impl FromFields for NeighborInfo { + fn from_fields(t: &[F]) -> Self { + assert!(t.len() >= NeighborInfoTarget::NUM_TARGETS); + Self { + is_found: t[0].try_into_bool().unwrap(), + is_in_path: t[1].try_into_bool().unwrap(), + value: U256::from_fields(&t[2..2 + NUM_LIMBS]), + hash: HashOut::from_vec(t[2 + NUM_LIMBS..NeighborInfoTarget::NUM_TARGETS].to_vec()), + } + } +} + +impl ToFields for NeighborInfo { + fn to_fields(&self) -> Vec { + [F::from_bool(self.is_found), F::from_bool(self.is_in_path)] + .into_iter() + .chain(self.value.to_fields()) + .chain(self.hash.to_fields()) + .collect() + } +} + +impl NeighborInfo { + // Initialize `Self` for the predecessor/successor of a node. `value` + // must be the value of the predecessor/successor, while `hash` must + // be its hash. If `hash` is `None`, it is assumed that the + // predecessor/successor is not located in the path of the node + pub(crate) fn new(value: U256, hash: Option>) -> Self { + Self { + is_found: true, + is_in_path: hash.is_some(), + value, + hash: hash.unwrap_or(*empty_poseidon_hash()), + } + } + /// Generate at random data about the successor/predecessor of a node. 
The generated + /// predecessor/successor must have the `value` provided as input; + /// the existence of the generated predecessor/successor depends on the `is_found` input: + /// - if `is_found` is `None`, then the existence of the generated predecessor/successor + /// is chosen at random + /// - otherwise, the generated predecessor/successor will be marked as found if and only if + /// the flag wrapped by `is_found` is `true` + pub(crate) fn sample(rng: &mut R, value: U256, is_found: Option) -> Self { + NeighborInfo { + is_found: is_found.unwrap_or(rng.gen()), + is_in_path: rng.gen(), + value, + hash: gen_random_field_hash(), + } + } +} + #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::array; - use mp2_common::{types::HashOutput, utils::ToTargets, C, D, F}; + use alloy::primitives::U256; + use mp2_common::{ + poseidon::empty_poseidon_hash, + types::HashOutput, + u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, + utils::{FromFields, FromTargets, ToTargets}, + C, D, F, + }; use mp2_test::{ circuit::{run_circuit, UserCircuit}, utils::{gen_random_field_hash, gen_random_u256}, }; use plonky2::{ field::types::Sample, - hash::hash_types::HashOutTarget, + hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, iop::{ target::Target, witness::{PartialWitness, WitnessWrite}, }, - plonk::{circuit_builder::CircuitBuilder, config::GenericHashOut}, + plonk::{ + circuit_builder::CircuitBuilder, config::GenericHashOut, proof::ProofWithPublicInputs, + }, }; use rand::thread_rng; - use crate::query::aggregation::{ChildPosition, NodeInfo}; + use crate::query::utils::{ChildPosition, NodeInfo}; - use super::{MerklePathGadget, MerklePathTargetInputs}; + use super::{ + MerklePathGadget, MerklePathTargetInputs, MerklePathWithNeighborsGadget, + MerklePathWithNeighborsTargetInputs, NeighborInfo, NeighborInfoTarget, + }; #[derive(Clone, Debug)] struct TestMerklePathGadget @@ -318,7 +793,7 @@ mod tests { [(); MAX_DEPTH - 1]:, { merkle_path_inputs: MerklePathGadget, - start_node: NodeInfo, + end_node: NodeInfo, index_id: F, } @@ -330,72 +805,179 @@ mod tests { fn build(c: &mut CircuitBuilder) -> Self::Wires { let index_id = c.add_virtual_target(); - let start_node = c.add_virtual_hash(); - let merkle_path_wires = MerklePathGadget::build(c, start_node, index_id); + let end_node = c.add_virtual_hash(); + let merkle_path_wires = MerklePathGadget::build(c, end_node, index_id); c.register_public_inputs(&merkle_path_wires.root.to_targets()); - (merkle_path_wires.inputs, start_node, index_id) + (merkle_path_wires.inputs, end_node, index_id) } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { self.merkle_path_inputs.assign(pw, &wires.0); - pw.set_hash_target(wires.1, self.start_node.compute_node_hash(self.index_id)); + pw.set_hash_target(wires.1, self.end_node.compute_node_hash(self.index_id)); pw.set_target(wires.2, self.index_id); } } - #[test] - fn test_merkle_path() { - // Test a Merkle-path on the following Merkle-tree - // A - // B C - // D G - // E F + impl NeighborInfo { + // Initialize `Self` for a node with no predecessor + pub(crate) fn new_dummy_predecessor() -> Self { + Self { + is_found: false, + is_in_path: true, // the circuit still looks at the predecessor in the path + value: U256::ZERO, + hash: *empty_poseidon_hash(), + } + } - // first, build the Merkle-tree + // Initialize `Self` for a node with no successor + pub(crate) fn new_dummy_successor() -> Self { + Self { + is_found: false, + is_in_path: true, // the circuit still looks at the successor in the path + 
value: U256::MAX, + hash: *empty_poseidon_hash(), + } + } + } + + #[derive(Clone, Debug)] + struct TestMerklePathWithNeighborsGadget + where + [(); MAX_DEPTH - 1]:, + { + merkle_path_inputs: MerklePathWithNeighborsGadget, + end_node: NodeInfo, + index_id: F, + } + + impl UserCircuit for TestMerklePathWithNeighborsGadget + where + [(); MAX_DEPTH - 1]:, + { + type Wires = ( + MerklePathWithNeighborsTargetInputs, + HashOutTarget, + UInt256Target, + Target, + ); + + fn build(c: &mut CircuitBuilder) -> Self::Wires { + let index_id = c.add_virtual_target(); + let end_node_tree_hash = c.add_virtual_hash(); + let end_node_value = c.add_virtual_u256_unsafe(); + let merkle_path_wires = MerklePathWithNeighborsGadget::build( + c, + end_node_value.clone(), + end_node_tree_hash, + index_id, + ); + + c.register_public_inputs(&merkle_path_wires.root.to_targets()); + c.register_public_inputs(&merkle_path_wires.end_node_hash.to_targets()); + c.register_public_inputs(&merkle_path_wires.predecessor_info.to_targets()); + c.register_public_inputs(&merkle_path_wires.successor_info.to_targets()); + + ( + merkle_path_wires.inputs, + end_node_tree_hash, + end_node_value, + index_id, + ) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.merkle_path_inputs.assign(pw, &wires.0); + pw.set_hash_target(wires.1, self.end_node.embedded_tree_hash); + pw.set_u256_target(&wires.2, self.end_node.value); + pw.set_target(wires.3, self.index_id); + } + } + + // method to build a `NodeInfo` for a node from the provided inputs + pub(crate) fn build_node( + left_child: Option<&NodeInfo>, + right_child: Option<&NodeInfo>, + node_value: U256, + embedded_tree_hash: HashOutput, + index_id: F, + ) -> NodeInfo { + let node_min = if let Some(node) = &left_child { + node.min + } else { + node_value + }; + let node_max = if let Some(node) = &right_child { + node.max + } else { + node_value + }; + let left_child = left_child + .map(|node| HashOutput::try_from(node.compute_node_hash(index_id).to_bytes()).unwrap()); + let right_child = right_child + .map(|node| HashOutput::try_from(node.compute_node_hash(index_id).to_bytes()).unwrap()); + NodeInfo::new( + &embedded_tree_hash, + left_child.as_ref(), + right_child.as_ref(), + node_value, + node_min, + node_max, + ) + } + + /// Build the following Merkle-tree to be employed in tests, using + /// the `index_id` provided as input to compute the hash of the nodes + /// A + /// B C + /// D G + /// E F + pub(crate) fn generate_test_tree( + index_id: F, + value_range: Option<(U256, U256)>, + ) -> [NodeInfo; 7] { let rng = &mut thread_rng(); - let index_id = F::rand(); // closure to generate a random node of the tree from the 2 children, if any - let mut random_node = - |left_child: Option<&HashOutput>, right_child: Option<&HashOutput>| -> NodeInfo { - let embedded_tree_hash = - HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); - let [node_min, node_max, node_value] = array::from_fn(|_| gen_random_u256(rng)); - NodeInfo::new( - &embedded_tree_hash, - left_child, - right_child, - node_value, - node_min, - node_max, - ) - }; + let random_node = |left_child: Option<&NodeInfo>, + right_child: Option<&NodeInfo>, + node_value: U256| + -> NodeInfo { + let embedded_tree_hash = + HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); + build_node( + left_child, + right_child, + node_value, + embedded_tree_hash, + index_id, + ) + }; + let mut values: [U256; 7] = array::from_fn(|_| gen_random_u256(rng)); + if let Some((min_range, max_range)) = 
value_range { + // trim random values to the range specified as input + values.iter_mut().for_each(|value| { + *value = min_range + *value % (max_range - min_range + U256::from(1)) + }); + } + values.sort(); + let node_e = random_node(None, None, values[0]); // it's a leaf node, so no children + let node_f = random_node(None, None, values[2]); + let node_g = random_node(None, None, values[6]); + let node_d = random_node(Some(&node_e), Some(&node_f), values[1]); + let node_b = random_node(Some(&node_d), None, values[3]); + let node_c = random_node(None, Some(&node_g), values[5]); + let node_a = random_node(Some(&node_b), Some(&node_c), values[4]); + [node_a, node_b, node_c, node_d, node_e, node_f, node_g] + } - let node_e = random_node(None, None); // it's a leaf node, so no children - let node_f = random_node(None, None); - let node_g = random_node(None, None); - let node_e_hash = - HashOutput::try_from(node_e.compute_node_hash(index_id).to_bytes()).unwrap(); - let node_d = random_node( - Some(&node_e_hash), - Some(&HashOutput::try_from(node_f.compute_node_hash(index_id).to_bytes()).unwrap()), - ); - let node_b = random_node( - Some(&HashOutput::try_from(node_d.compute_node_hash(index_id).to_bytes()).unwrap()), - None, - ); - let node_c = random_node( - None, - Some(&HashOutput::try_from(node_g.compute_node_hash(index_id).to_bytes()).unwrap()), - ); - let node_b_hash = - HashOutput::try_from(node_b.compute_node_hash(index_id).to_bytes()).unwrap(); - let node_c_hash = - HashOutput::try_from(node_c.compute_node_hash(index_id).to_bytes()).unwrap(); - let node_a = random_node(Some(&node_b_hash), Some(&node_c_hash)); + #[test] + fn test_merkle_path() { + // first, build the Merkle-tree + let index_id = F::rand(); + let [node_a, node_b, node_c, node_d, node_e, node_f, node_g] = + generate_test_tree(index_id, None); let root = node_a.compute_node_hash(index_id); - // verify Merkle-path related to leaf F const MAX_DEPTH: usize = 10; let path = vec![ @@ -403,12 +985,14 @@ mod tests { (node_b, ChildPosition::Left), (node_a, ChildPosition::Left), ]; + let node_e_hash = HashOutput::from(node_e.compute_node_hash(index_id)); + let node_c_hash = HashOutput::from(node_c.compute_node_hash(index_id)); let siblings = vec![Some(node_e_hash), None, Some(node_c_hash)]; let merkle_path_inputs = MerklePathGadget::::new(&path, &siblings).unwrap(); let circuit = TestMerklePathGadget:: { merkle_path_inputs, - start_node: node_f, + end_node: node_f, index_id, }; @@ -421,11 +1005,12 @@ mod tests { (node_c, ChildPosition::Right), (node_a, ChildPosition::Right), ]; + let node_b_hash = HashOutput::from(node_b.compute_node_hash(index_id)); let siblings = vec![None, Some(node_b_hash)]; let merkle_path_inputs = MerklePathGadget::::new(&path, &siblings).unwrap(); let circuit = TestMerklePathGadget:: { merkle_path_inputs, - start_node: node_g, + end_node: node_g, index_id, }; @@ -439,7 +1024,7 @@ mod tests { let merkle_path_inputs = MerklePathGadget::::new(&path, &siblings).unwrap(); let circuit = TestMerklePathGadget:: { merkle_path_inputs, - start_node: node_d, + end_node: node_d, index_id, }; @@ -447,4 +1032,222 @@ mod tests { // check that the re-computed root is correct assert_eq!(proof.public_inputs, root.to_vec()); } + + #[test] + fn test_merkle_path_with_neighbors() { + // first, build the Merkle-tree + let index_id = F::rand(); + let [node_a, node_b, node_c, node_d, node_e, node_f, node_g] = + generate_test_tree(index_id, None); + let root = node_a.compute_node_hash(index_id); + // verify Merkle-path related to leaf 
F + const MAX_DEPTH: usize = 10; + let path = vec![ + (node_d, ChildPosition::Right), // we start from the ancestor of the start node of the path + (node_b, ChildPosition::Left), + (node_a, ChildPosition::Left), + ]; + let node_e_hash = HashOutput::from(node_e.compute_node_hash(index_id)); + let node_c_hash = HashOutput::from(node_c.compute_node_hash(index_id)); + let siblings = vec![Some(node_e_hash), None, Some(node_c_hash)]; + let merkle_path_inputs = MerklePathWithNeighborsGadget::::new( + &path, + &siblings, + &node_f, + [None, None], // it's a leaf node + ) + .unwrap(); + + let circuit = TestMerklePathWithNeighborsGadget:: { + merkle_path_inputs, + end_node: node_f, + index_id, + }; + + let proof = run_circuit(circuit); + + // closure to check correctness of public inputs + let check_public_inputs = |proof: ProofWithPublicInputs, + node: &NodeInfo, + node_name: &str, + predecessor_info, + successor_info| { + // check that the re-computed root is correct + assert_eq!( + proof.public_inputs[..NUM_HASH_OUT_ELTS], + root.to_vec(), + "failed for node {node_name}" + ); + // check that the hash of node_F is correct + let node_hash = node.compute_node_hash(index_id); + assert_eq!( + proof.public_inputs[NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS], + node_hash.elements, + "failed for node {node_name}" + ); + // check predecessor info extracted in the circuit + assert_eq!( + NeighborInfo::from_fields(&proof.public_inputs[2 * NUM_HASH_OUT_ELTS..]), + predecessor_info, + "failed for node {node_name}" + ); + // check successor info extracted in the circuit + assert_eq!( + NeighborInfo::from_fields( + &proof.public_inputs[2 * NUM_HASH_OUT_ELTS + NeighborInfoTarget::NUM_TARGETS..] + ), + successor_info, + "failed for node {node_name}" + ); + }; + // build predecessor and successor info for node_F + // predecessor should be node_D + let node_d_hash = node_d.compute_node_hash(index_id); + let predecessor_info = NeighborInfo::new(node_d.value, Some(node_d_hash)); + // successor should be node_B + let node_b_hash = node_b.compute_node_hash(index_id); + let successor_info = NeighborInfo::new(node_b.value, Some(node_b_hash)); + check_public_inputs(proof, &node_f, "node F", predecessor_info, successor_info); + + // verify Merkle-path related to leaf E + let path = vec![ + (node_d, ChildPosition::Left), // we start from the ancestor of the start node of the path + (node_b, ChildPosition::Left), + (node_a, ChildPosition::Left), + ]; + let node_f_hash = HashOutput::from(node_f.compute_node_hash(index_id)); + let siblings = vec![Some(node_f_hash), None, Some(node_c_hash)]; + let merkle_path_inputs = MerklePathWithNeighborsGadget::::new( + &path, + &siblings, + &node_e, + [None, None], // it's a leaf node + ) + .unwrap(); + + let circuit = TestMerklePathWithNeighborsGadget:: { + merkle_path_inputs, + end_node: node_e, + index_id, + }; + + let proof = run_circuit(circuit); + + // build predecessor and successor info for node_E + // There should be no predecessor + let predecessor_info = NeighborInfo::new_dummy_predecessor(); + // successor should be node_D + let successor_info = NeighborInfo::new(node_d.value, Some(node_d_hash)); + check_public_inputs(proof, &node_e, "node E", predecessor_info, successor_info); + + // verify Merkle-path related to node D + let path = vec![(node_b, ChildPosition::Left), (node_a, ChildPosition::Left)]; + let siblings = vec![None, Some(node_c_hash)]; + let merkle_path_inputs = MerklePathWithNeighborsGadget::::new( + &path, + &siblings, + &node_d, + [Some(node_e), Some(node_f)], + ) + 
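+        // editorial note: `new` presumably fails only if the provided path, siblings and
+        // children are inconsistent with `MAX_DEPTH`, so unwrapping is fine in these tests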
.unwrap();
+
+        let circuit = TestMerklePathWithNeighborsGadget::<MAX_DEPTH> {
+            merkle_path_inputs,
+            end_node: node_d,
+            index_id,
+        };
+
+        let proof = run_circuit(circuit);
+
+        // build predecessor and successor info for node_D
+        // predecessor should be node_E, but it's not in the path
+        let predecessor_info = NeighborInfo::new(node_e.value, None);
+        // successor should be node_F, but it's not in the path
+        let successor_info = NeighborInfo::new(node_f.value, None);
+        check_public_inputs(proof, &node_d, "node D", predecessor_info, successor_info);
+
+        // verify Merkle-path related to node B
+        let path = vec![(node_a, ChildPosition::Left)];
+        let siblings = vec![Some(node_c_hash)];
+        let merkle_path_inputs = MerklePathWithNeighborsGadget::<MAX_DEPTH>::new(
+            &path,
+            &siblings,
+            &node_b,
+            [Some(node_d), None], // Node D is the left child
+        )
+        .unwrap();
+
+        let circuit = TestMerklePathWithNeighborsGadget::<MAX_DEPTH> {
+            merkle_path_inputs,
+            end_node: node_b,
+            index_id,
+        };
+
+        let proof = run_circuit(circuit);
+
+        // build predecessor and successor info for node_B
+        // predecessor should be node_F, but it's not in the path
+        let predecessor_info = NeighborInfo::new(node_f.value, None);
+        // successor should be node_A
+        let successor_info = NeighborInfo::new(node_a.value, Some(root));
+        check_public_inputs(proof, &node_b, "node B", predecessor_info, successor_info);
+
+        // verify Merkle-path related to leaf G
+        let path = vec![
+            (node_c, ChildPosition::Right),
+            (node_a, ChildPosition::Right),
+        ];
+        let siblings = vec![None, Some(HashOutput::from(node_b_hash))];
+        let merkle_path_inputs = MerklePathWithNeighborsGadget::<MAX_DEPTH>::new(
+            &path,
+            &siblings,
+            &node_g,
+            [None, None], // it's a leaf node
+        )
+        .unwrap();
+
+        let circuit = TestMerklePathWithNeighborsGadget::<MAX_DEPTH> {
+            merkle_path_inputs,
+            end_node: node_g,
+            index_id,
+        };
+
+        let proof = run_circuit(circuit);
+
+        // build predecessor and successor info for node_G
+        // predecessor should be node_C
+        let predecessor_info = NeighborInfo::new(
+            node_c.value,
+            Some(HashOut::from_bytes((&node_c_hash).into())),
+        );
+        // There should be no successor
+        let successor_info = NeighborInfo::new_dummy_successor();
+        check_public_inputs(proof, &node_g, "node G", predecessor_info, successor_info);
+
+        // verify Merkle-path related to root node A
+        let path = vec![];
+        let siblings = vec![];
+        let merkle_path_inputs = MerklePathWithNeighborsGadget::<MAX_DEPTH>::new(
+            &path,
+            &siblings,
+            &node_a,
+            [Some(node_b), Some(node_c)], // it's the root node, with both children present
+        )
+        .unwrap();
+
+        let circuit = TestMerklePathWithNeighborsGadget::<MAX_DEPTH> {
+            merkle_path_inputs,
+            end_node: node_a,
+            index_id,
+        };
+
+        let proof = run_circuit(circuit);
+
+        // build predecessor and successor info for node_A
+        // predecessor should be node_B, but it's not in the path
+        let predecessor_info = NeighborInfo::new(node_b.value, None);
+        // successor should be node_C, but it's not in the path
+        let successor_info = NeighborInfo::new(node_c.value, None);
+        check_public_inputs(proof, &node_a, "node A", predecessor_info, successor_info);
+    }
 }
diff --git a/verifiable-db/src/query/mod.rs b/verifiable-db/src/query/mod.rs
index b07c5e45b..e20a6987f 100644
--- a/verifiable-db/src/query/mod.rs
+++ b/verifiable-db/src/query/mod.rs
@@ -1,13 +1,16 @@
-use mp2_common::F;
-use public_inputs::PublicInputs;
+use plonky2::iop::target::Target;
+use public_inputs::PublicInputsQueryCircuits;

-pub mod aggregation;
 pub mod api;
+pub(crate) mod circuits;
 pub mod computational_hash_ids;
 pub mod merkle_path;
+pub(crate) mod output_computation;
 pub mod
public_inputs; +pub(crate) mod row_chunk_gadgets; pub mod universal_circuit; +pub mod utils; pub const fn pi_len() -> usize { - PublicInputs::::total_len() + PublicInputsQueryCircuits::::total_len() } diff --git a/verifiable-db/src/query/aggregation/output_computation.rs b/verifiable-db/src/query/output_computation.rs similarity index 57% rename from verifiable-db/src/query/aggregation/output_computation.rs rename to verifiable-db/src/query/output_computation.rs index ca5fcd159..70b5f232c 100644 --- a/verifiable-db/src/query/aggregation/output_computation.rs +++ b/verifiable-db/src/query/output_computation.rs @@ -2,7 +2,7 @@ use crate::query::{ computational_hash_ids::{AggregationOperation, Identifiers}, - public_inputs::PublicInputs, + universal_circuit::universal_query_gadget::{CurveOrU256Target, OutputValuesTarget}, }; use alloy::primitives::U256; use mp2_common::{ @@ -14,7 +14,7 @@ use mp2_common::{ }; use plonky2::iop::target::Target; use plonky2_crypto::u32::arithmetic_u32::CircuitBuilderU32; -use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; +use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; /// Compute the dummy targets for each of the `S` values to be returned as output. pub(crate) fn compute_dummy_output_targets( @@ -43,9 +43,7 @@ pub(crate) fn compute_dummy_output_targets( is_op_id, curve_zero, // Pad the current output to `CURVE_TARGET_LEN` for the first item. - CurveTarget::from_targets(&PublicInputs::<_, S>::pad_slice_to_curve_len( - &output, - )), + CurveOrU256Target::from_targets(&output).as_curve_target(), ) .to_targets(); } @@ -56,112 +54,115 @@ pub(crate) fn compute_dummy_output_targets( outputs } -/// Compute the node output item at the specified index by the proofs, -/// and return the output item with the overflow number. -pub(crate) fn compute_output_item( - b: &mut CBuilder, - i: usize, - proofs: &[&PublicInputs], -) -> (Vec, Target) +impl OutputValuesTarget where - [(); S - 1]:, + [(); MAX_NUM_RESULTS - 1]:, { - let zero = b.zero(); - let u32_zero = b.zero_u32(); - let u256_zero = b.zero_u256(); + /// Aggregate the i-th output values in `outputs` according to the aggregation operation specified in + /// `op`. It returns the targets representing the aggregated output and a target yielding the number + /// of overflows occurred during aggregation + pub(crate) fn aggregate_outputs( + b: &mut CBuilder, + outputs: &[Self], + op: Target, + i: usize, + ) -> (Vec, Target) { + let zero = b.zero(); + let u32_zero = b.zero_u32(); + let u256_zero = b.zero_u256(); + + let out0 = &outputs[0]; + + let [op_id, op_min, op_max, op_sum, op_avg] = [ + AggregationOperation::IdOp, + AggregationOperation::MinOp, + AggregationOperation::MaxOp, + AggregationOperation::SumOp, + AggregationOperation::AvgOp, + ] + .map(|op| b.constant(Identifiers::AggregationOperations(op).to_field())); - let proof0 = &proofs[0]; - let op = proof0.operation_ids_target()[i]; - - let [op_id, op_min, op_max, op_sum, op_avg] = [ - AggregationOperation::IdOp, - AggregationOperation::MinOp, - AggregationOperation::MaxOp, - AggregationOperation::SumOp, - AggregationOperation::AvgOp, - ] - .map(|op| b.constant(Identifiers::AggregationOperations(op).to_field())); - - let is_op_id = b.is_equal(op, op_id); - let is_op_min = b.is_equal(op, op_min); - let is_op_max = b.is_equal(op, op_max); - let is_op_sum = b.is_equal(op, op_sum); - let is_op_avg = b.is_equal(op, op_avg); - - // Check that the all proofs are employing the same aggregation operation. - proofs[1..] 
- .iter() - .for_each(|p| b.connect(p.operation_ids_target()[i], op)); - - // Compute the SUM, MIN and MAX values. - let mut sum_overflow = zero; - let mut sum_value = proof0.value_target_at_index(i); - if i == 0 { - // If it's the first proof and the operation is ID, the value is a curve point, - // which each field may be out of range of an Uint32 (to combine an Uint256). - sum_value = b.select_u256(is_op_id, &u256_zero, &sum_value); - } - let mut min_value = sum_value.clone(); - let mut max_value = sum_value.clone(); - for p in proofs[1..].iter() { - // Get the current proof value. - let mut value = p.value_target_at_index(i); + let is_op_id = b.is_equal(op, op_id); + let is_op_min = b.is_equal(op, op_min); + let is_op_max = b.is_equal(op, op_max); + let is_op_sum = b.is_equal(op, op_sum); + let is_op_avg = b.is_equal(op, op_avg); + + // Compute the SUM, MIN and MAX values. + let mut sum_overflow = zero; + let mut sum_value = out0.value_target_at_index(i); if i == 0 { // If it's the first proof and the operation is ID, the value is a curve point, // which each field may be out of range of an Uint32 (to combine an Uint256). - value = b.select_u256(is_op_id, &u256_zero, &value); - }; - - // Compute the SUM value and the overflow. - let (addition, overflow) = b.add_u256(&sum_value, &value); - sum_value = addition; - sum_overflow = b.add(sum_overflow, overflow.0); - - // Compute the MIN and MAX values. - let (_, borrow) = b.sub_u256(&value, &min_value); - let not_less_than = b.is_equal(borrow.0, u32_zero.0); - min_value = b.select_u256(not_less_than, &min_value, &value); - let (_, borrow) = b.sub_u256(&value, &max_value); - let not_less_than = b.is_equal(borrow.0, u32_zero.0); - max_value = b.select_u256(not_less_than, &value, &max_value); - } + sum_value = b.select_u256(is_op_id, &u256_zero, &sum_value); + } + let mut min_value = sum_value.clone(); + let mut max_value = sum_value.clone(); + for p in outputs[1..].iter() { + // Get the current proof value. + let mut value = p.value_target_at_index(i); + if i == 0 { + // If it's the first proof and the operation is ID, the value is a curve point, + // which each field may be out of range of an Uint32 (to combine an Uint256). + value = b.select_u256(is_op_id, &u256_zero, &value); + }; + + // Compute the SUM value and the overflow. + let (addition, overflow) = b.add_u256(&sum_value, &value); + sum_value = addition; + sum_overflow = b.add(sum_overflow, overflow.0); + + // Compute the MIN and MAX values. + let (_, borrow) = b.sub_u256(&value, &min_value); + let not_less_than = b.is_equal(borrow.0, u32_zero.0); + min_value = b.select_u256(not_less_than, &min_value, &value); + let (_, borrow) = b.sub_u256(&value, &max_value); + let not_less_than = b.is_equal(borrow.0, u32_zero.0); + max_value = b.select_u256(not_less_than, &value, &max_value); + } - // Compute the output item. - let output = b.select_u256(is_op_min, &min_value, &sum_value); - let output = b.select_u256(is_op_max, &max_value, &output); - let mut output = output.to_targets(); + // Compute the output item. + let output = b.select_u256(is_op_min, &min_value, &sum_value); + let output = b.select_u256(is_op_max, &max_value, &output); + let mut output = output.to_targets(); - if i == 0 { - // We always accumulate order-agnostic digest of the proofs for the first item. 
- let points: Vec<_> = proofs - .iter() - .map(|p| p.first_value_as_curve_target()) - .collect(); - let digest = b.add_curve_point(&points); - let a = b.curve_select( - is_op_id, - digest, - // Pad the current output to `CURVE_TARGET_LEN` for the first item. - CurveTarget::from_targets(&PublicInputs::<_, S>::pad_slice_to_curve_len(&output)), - ); - output = a.to_targets(); - } + if i == 0 { + // We always accumulate order-agnostic digest of the proofs for the first item. + let points: Vec<_> = outputs + .iter() + .map(|out| out.first_output.as_curve_target()) + .collect(); + let digest = b.add_curve_point(&points); + let a = b.curve_select( + is_op_id, + digest, + // Pad the current output to `CURVE_TARGET_LEN` for the first item. + CurveOrU256Target::from_targets(&output).as_curve_target(), + ); + output = a.to_targets(); + } - // Set the overflow if the operation is SUM or AVG: - // overflow = op == SUM OR op == AVG ? sum_overflow : 0 - let is_op_sum_or_avg = b.or(is_op_sum, is_op_avg); - let overflow = b.mul(is_op_sum_or_avg.target, sum_overflow); + // Set the overflow if the operation is SUM or AVG: + // overflow = op == SUM OR op == AVG ? sum_overflow : 0 + let is_op_sum_or_avg = b.or(is_op_sum, is_op_avg); + let overflow = b.mul(is_op_sum_or_avg.target, sum_overflow); - (output, overflow) + (output, overflow) + } } #[cfg(test)] pub(crate) mod tests { use super::*; use crate::{ - query::{aggregation::tests::compute_output_item_value, pi_len}, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, + query::{ + pi_len, public_inputs::PublicInputsQueryCircuits, + universal_circuit::universal_query_gadget::CurveOrU256, + utils::tests::compute_output_item_value, + }, + test_utils::random_aggregation_operations, }; + use itertools::Itertools; use mp2_common::{types::CURVE_TARGET_LEN, u256::NUM_LIMBS, utils::ToFields, C, D, F}; use mp2_test::circuit::{run_circuit, UserCircuit}; use plonky2::{ @@ -171,6 +172,32 @@ pub(crate) mod tests { use plonky2_ecgfp5::curve::curve::Point; use std::array; + /// Compute the node output item at the specified index by the proofs, + /// and return the output item with the overflow number. + pub(crate) fn compute_output_item( + b: &mut CBuilder, + i: usize, + proofs: &[&PublicInputsQueryCircuits], + ) -> (Vec, Target) + where + [(); S - 1]:, + { + let proof0 = &proofs[0]; + let op = proof0.operation_ids_target()[i]; + + // Check that the all proofs are employing the same aggregation operation. + proofs[1..] + .iter() + .for_each(|p| b.connect(p.operation_ids_target()[i], op)); + + let outputs = proofs + .iter() + .map(|p| OutputValuesTarget::from_targets(p.to_values_raw())) + .collect_vec(); + + OutputValuesTarget::aggregate_outputs(b, &outputs, op, i) + } + /// Compute the dummy values for each of the `S` values to be returned as output. /// It's the test function corresponding to `compute_dummy_output_targets`. pub(crate) fn compute_dummy_output_values(ops: &[F; S]) -> Vec { @@ -192,7 +219,7 @@ pub(crate) mod tests { Point::NEUTRAL.to_fields() } else { // Pad the current output to `CURVE_TARGET_LEN` for the first item. - PublicInputs::<_, S>::pad_slice_to_curve_len(&output) + CurveOrU256::from_slice(&output).to_vec() }; } outputs.append(&mut output); @@ -244,7 +271,8 @@ pub(crate) mod tests { }); // Build the public inputs. 
- let pis = [0; PROOF_NUM].map(|i| PublicInputs::::from_slice(&proofs[i])); + let pis = [0; PROOF_NUM] + .map(|i| PublicInputsQueryCircuits::::from_slice(&proofs[i])); let pis = [0; PROOF_NUM].map(|i| &pis[i]); // Check if the outputs as expected. @@ -281,7 +309,8 @@ pub(crate) mod tests { [(); S - 1]:, { fn new(proofs: [Vec; PROOF_NUM]) -> Self { - let pis = [0; PROOF_NUM].map(|i| PublicInputs::::from_slice(&proofs[i])); + let pis = + [0; PROOF_NUM].map(|i| PublicInputsQueryCircuits::::from_slice(&proofs[i])); let pis = [0; PROOF_NUM].map(|i| &pis[i]); let exp_outputs = array::from_fn(|i| { @@ -306,7 +335,7 @@ pub(crate) mod tests { let ops: [_; S] = random_aggregation_operations(); // Build the input proofs. - let inputs = random_aggregation_public_inputs(&ops); + let inputs = PublicInputsQueryCircuits::::sample_from_ops(&ops); // Construct the test circuit. let test_circuit = TestOutputComputationCircuit::::new(inputs); @@ -327,7 +356,7 @@ pub(crate) mod tests { ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); // Build the input proofs. - let inputs = random_aggregation_public_inputs(&ops); + let inputs = PublicInputsQueryCircuits::::sample_from_ops(&ops); // Construct the test circuit. let test_circuit = TestOutputComputationCircuit::::new(inputs); diff --git a/verifiable-db/src/query/public_inputs.rs b/verifiable-db/src/query/public_inputs.rs index f43204b18..3f37294c3 100644 --- a/verifiable-db/src/query/public_inputs.rs +++ b/verifiable-db/src/query/public_inputs.rs @@ -4,17 +4,26 @@ use alloy::primitives::U256; use itertools::Itertools; use mp2_common::{ public_inputs::{PublicInputCommon, PublicInputRange}, - types::{CBuilder, CURVE_TARGET_LEN}, - u256::{UInt256Target, NUM_LIMBS}, + types::CBuilder, + u256::UInt256Target, utils::{FromFields, FromTargets, TryIntoBool}, F, }; use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + hash::hash_types::{HashOut, HashOutTarget}, iop::target::{BoolTarget, Target}, }; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; +use crate::query::{ + output_computation::compute_dummy_output_targets, + universal_circuit::universal_query_gadget::{ + CurveOrU256Target, OutputValues, OutputValuesTarget, UniversalQueryOutputWires, + }, +}; + +use super::row_chunk_gadgets::{BoundaryRowDataTarget, RowChunkDataTarget}; + /// Query circuits public inputs pub enum QueryPublicInputs { /// `H`: Hash of the tree @@ -29,22 +38,56 @@ pub enum QueryPublicInputs { /// `ops` : `[F; S]` Set of identifiers of the aggregation operations for each of the `S` items found in `V` /// (like "SUM", "MIN", "MAX", "COUNT" operations) OpIds, - /// `I` : `u256` value of the indexed column for the given node (meaningful only for rows tree nodes) - IndexValue, - /// `min` : `u256` Minimum value of the indexed column among all the records stored in the subtree rooted - /// in the current node; values of secondary indexed column are employed for rows tree nodes, - /// while values of primary indexed column are employed for index tree nodes - MinValue, - /// `max`` : Maximum value of the indexed column among all the records stored in the subtree rooted - /// in the current node; values of secondary indexed column are employed for rows tree nodes, - /// while values of primary indexed column are employed for index tree nodes - MaxValue, - /// `index_ids`` : `[2]F` Identifiers of indexed columns - IndexIds, - /// `MIN_I`: `u256` Lower bound of the range of indexed column values specified in the 
query
-    MinQuery,
-    /// `MAX_I`: `u256` Upper bound of the range of indexed column values specified in the query
-    MaxQuery,
+    /// Data associated to the left boundary row of the row chunk being proven
+    LeftBoundaryRow,
+    /// Data associated to the right boundary row of the row chunk being proven
+    RightBoundaryRow,
+    /// `MIN_primary`: `u256` Lower bound of the range of primary indexed column values specified in the query
+    MinPrimary,
+    /// `MAX_primary`: `u256` Upper bound of the range of primary indexed column values specified in the query
+    MaxPrimary,
+    /// `MIN_secondary`: `u256` Lower bound of the range of secondary indexed column values specified in the query
+    MinSecondary,
+    /// `MAX_secondary`: `u256` Upper bound of the range of secondary indexed column values specified in the query
+    MaxSecondary,
+    /// `overflow` : `bool` Flag specifying whether an overflow error has occurred in arithmetic
+    Overflow,
+    /// `C`: computational hash
+    ComputationalHash,
+    /// `H_p` : placeholder hash
+    PlaceholderHash,
+}
+
+/// Public inputs for the universal query circuit. They are mostly the same as `QueryPublicInputs`; the only
+/// difference is that the query range on the secondary index is replaced by the values of the indexed columns
+/// for the row being proven
+pub enum QueryPublicInputsUniversalCircuit {
+    /// `H`: Hash of the tree
+    TreeHash,
+    /// `V`: Set of `S` values representing the cumulative results of the query, where `S` is a parameter
+    /// specifying the maximum number of cumulative results we support;
+    /// the first value could be either a `u256` or a `CurveTarget`, depending on the query, and so we always
+    /// represent this value with `CURVE_TARGET_LEN` elements; all the other `S-1` values are always `u256`
+    OutputValues,
+    /// `count`: `F` Number of matching records in the query
+    NumMatching,
+    /// `ops` : `[F; S]` Set of identifiers of the aggregation operations for each of the `S` items found in `V`
+    /// (like "SUM", "MIN", "MAX", "COUNT" operations)
+    OpIds,
+    /// Data associated to the left boundary row of the row chunk being proven; it is dummy in case of the
+    /// universal query circuit, being employed just to re-use the same public input layout
+    LeftBoundaryRow,
+    /// Data associated to the right boundary row of the row chunk being proven; it is dummy in case of the
+    /// universal query circuit, being employed just to re-use the same public input layout
+    RightBoundaryRow,
+    /// `MIN_primary`: `u256` Lower bound of the range of primary indexed column values specified in the query
+    MinPrimary,
+    /// `MAX_primary`: `u256` Upper bound of the range of primary indexed column values specified in the query
+    MaxPrimary,
+    /// Value of secondary indexed column for the row being proven
+    SecondaryIndexValue,
+    /// Value of primary indexed column for the row being proven
+    PrimaryIndexValue,
     /// `overflow` : `bool` Flag specifying whether an overflow error has occurred in arithmetic
     Overflow,
     /// `C`: computational hash
@@ -53,18 +96,62 @@ pub enum QueryPublicInputs {
     PlaceholderHash,
 }

+impl From<QueryPublicInputsUniversalCircuit> for QueryPublicInputs {
+    fn from(value: QueryPublicInputsUniversalCircuit) -> Self {
+        match value {
+            QueryPublicInputsUniversalCircuit::TreeHash => QueryPublicInputs::TreeHash,
+            QueryPublicInputsUniversalCircuit::OutputValues => QueryPublicInputs::OutputValues,
+            QueryPublicInputsUniversalCircuit::NumMatching => QueryPublicInputs::NumMatching,
+            QueryPublicInputsUniversalCircuit::OpIds => QueryPublicInputs::OpIds,
+            QueryPublicInputsUniversalCircuit::LeftBoundaryRow => {
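+                // editorial note: the boundary-row and primary-bound variants map to the slots
+                // with the same name, while `SecondaryIndexValue` and `PrimaryIndexValue` below
+                // reuse the `MinSecondary` and `MaxSecondary` slots, respectively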
QueryPublicInputs::LeftBoundaryRow + } + QueryPublicInputsUniversalCircuit::RightBoundaryRow => { + QueryPublicInputs::RightBoundaryRow + } + QueryPublicInputsUniversalCircuit::MinPrimary => QueryPublicInputs::MinPrimary, + QueryPublicInputsUniversalCircuit::MaxPrimary => QueryPublicInputs::MaxPrimary, + QueryPublicInputsUniversalCircuit::SecondaryIndexValue => { + QueryPublicInputs::MinSecondary + } + QueryPublicInputsUniversalCircuit::PrimaryIndexValue => QueryPublicInputs::MaxSecondary, + QueryPublicInputsUniversalCircuit::Overflow => QueryPublicInputs::Overflow, + QueryPublicInputsUniversalCircuit::ComputationalHash => { + QueryPublicInputs::ComputationalHash + } + QueryPublicInputsUniversalCircuit::PlaceholderHash => { + QueryPublicInputs::PlaceholderHash + } + } + } +} +/// Public inputs for query circuits +pub type PublicInputsQueryCircuits<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, false>; +/// Public inputs only for universal query circuit +pub type PublicInputsUniversalCircuit<'a, T, const S: usize> = PublicInputsFactory<'a, T, S, true>; + +/// This is the data structure employed for both public inputs of generic query circuits +/// and for public inputs of the universal circuit. Since the 2 public inputs are the +/// same, except for the semantic of 2 U256 elements, they can be represented by the +/// same data structure. The `UNIVERSAL_CIRCUIT` const generic is employed to +/// define 2 type aliases: 1 for public inputs of generic query circuits, and 1 for +/// public inputs of universal query circuit. The methods being common between the +/// 2 public inputs are implemented for this data structure, while the methods that +/// are specific to each public input type are implemented for the corresponding alias. +/// In this way, the methods implemented for the type alias define the correct semantics +/// of each of the items in both types of public inputs. 
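+///
+/// A minimal usage sketch (editorial example, marked `ignore` since it assumes raw public
+/// input slices `raw` of the right length and a concrete result width `S = 10`):
+/// ```ignore
+/// // generic query circuits: the two secondary slots hold the query bounds
+/// let pis = PublicInputsQueryCircuits::<F, 10>::from_slice(&raw);
+/// let (min_s, max_s) = (pis.min_secondary(), pis.max_secondary());
+/// // universal circuit: the same two slots hold the row's index column values
+/// let pis = PublicInputsUniversalCircuit::<F, 10>::from_slice(&raw);
+/// let (sec_value, prim_value) = (pis.secondary_index_value(), pis.primary_index_value());
+/// ```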
#[derive(Clone, Debug)] -pub struct PublicInputs<'a, T, const S: usize> { +pub struct PublicInputsFactory<'a, T, const S: usize, const UNIVERSAL_CIRCUIT: bool> { h: &'a [T], v: &'a [T], ops: &'a [T], count: &'a T, - i: &'a [T], - min: &'a [T], - max: &'a [T], - ids: &'a [T], - min_q: &'a [T], - max_q: &'a [T], + left_row: &'a [T], + right_row: &'a [T], + min_p: &'a [T], + max_p: &'a [T], + min_s: &'a [T], + max_s: &'a [T], overflow: &'a T, ch: &'a [T], ph: &'a [T], @@ -72,53 +159,55 @@ pub struct PublicInputs<'a, T, const S: usize> { const NUM_PUBLIC_INPUTS: usize = QueryPublicInputs::PlaceholderHash as usize + 1; -impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { +impl<'a, T: Clone, const S: usize, const UNIVERSAL_CIRCUIT: bool> + PublicInputsFactory<'a, T, S, UNIVERSAL_CIRCUIT> +{ const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ - Self::to_range(QueryPublicInputs::TreeHash), - Self::to_range(QueryPublicInputs::OutputValues), - Self::to_range(QueryPublicInputs::NumMatching), - Self::to_range(QueryPublicInputs::OpIds), - Self::to_range(QueryPublicInputs::IndexValue), - Self::to_range(QueryPublicInputs::MinValue), - Self::to_range(QueryPublicInputs::MaxValue), - Self::to_range(QueryPublicInputs::IndexIds), - Self::to_range(QueryPublicInputs::MinQuery), - Self::to_range(QueryPublicInputs::MaxQuery), - Self::to_range(QueryPublicInputs::Overflow), - Self::to_range(QueryPublicInputs::ComputationalHash), - Self::to_range(QueryPublicInputs::PlaceholderHash), + Self::to_range_internal(QueryPublicInputs::TreeHash), + Self::to_range_internal(QueryPublicInputs::OutputValues), + Self::to_range_internal(QueryPublicInputs::NumMatching), + Self::to_range_internal(QueryPublicInputs::OpIds), + Self::to_range_internal(QueryPublicInputs::LeftBoundaryRow), + Self::to_range_internal(QueryPublicInputs::RightBoundaryRow), + Self::to_range_internal(QueryPublicInputs::MinPrimary), + Self::to_range_internal(QueryPublicInputs::MaxPrimary), + Self::to_range_internal(QueryPublicInputs::MinSecondary), + Self::to_range_internal(QueryPublicInputs::MaxSecondary), + Self::to_range_internal(QueryPublicInputs::Overflow), + Self::to_range_internal(QueryPublicInputs::ComputationalHash), + Self::to_range_internal(QueryPublicInputs::PlaceholderHash), ]; const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ // Tree hash - NUM_HASH_OUT_ELTS, + HashOutTarget::NUM_TARGETS, // Output values - CURVE_TARGET_LEN + NUM_LIMBS * (S - 1), + CurveTarget::NUM_TARGETS + UInt256Target::NUM_TARGETS * (S - 1), // Number of matching records 1, // Operation identifiers S, - // Index column value - NUM_LIMBS, - // Minimum indexed column value - NUM_LIMBS, - // Maximum indexed column value - NUM_LIMBS, - // Indexed column IDs - 2, - // Lower bound for indexed column specified in query - NUM_LIMBS, - // Upper bound for indexed column specified in query - NUM_LIMBS, + // Left boundary row + BoundaryRowDataTarget::NUM_TARGETS, + // Right boundary row + BoundaryRowDataTarget::NUM_TARGETS, + // Min primary index + UInt256Target::NUM_TARGETS, + // Max primary index + UInt256Target::NUM_TARGETS, + // Min secondary index + UInt256Target::NUM_TARGETS, + // Max secondary index + UInt256Target::NUM_TARGETS, // Overflow flag 1, // Computational hash - NUM_HASH_OUT_ELTS, + HashOutTarget::NUM_TARGETS, // Placeholder hash - NUM_HASH_OUT_ELTS, + HashOutTarget::NUM_TARGETS, ]; - pub const fn to_range(query_pi: QueryPublicInputs) -> PublicInputRange { + const fn to_range_internal(query_pi: QueryPublicInputs) -> PublicInputRange { let mut i = 0; let mut 
offset = 0; let pi_pos = query_pi as usize; @@ -129,8 +218,12 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { offset..offset + Self::SIZES[pi_pos] } + pub fn to_range>(query_pi: Q) -> PublicInputRange { + Self::to_range_internal(query_pi.into()) + } + pub(crate) const fn total_len() -> usize { - Self::to_range(QueryPublicInputs::PlaceholderHash).end + Self::to_range_internal(QueryPublicInputs::PlaceholderHash).end } pub(crate) fn to_hash_raw(&self) -> &[T] { @@ -149,28 +242,28 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { self.ops } - pub(crate) fn to_index_value_raw(&self) -> &[T] { - self.i + pub(crate) fn to_left_row_raw(&self) -> &[T] { + self.left_row } - pub(crate) fn to_min_value_raw(&self) -> &[T] { - self.min + pub(crate) fn to_right_row_raw(&self) -> &[T] { + self.right_row } - pub(crate) fn to_max_value_raw(&self) -> &[T] { - self.max + pub(crate) fn to_min_primary_raw(&self) -> &[T] { + self.min_p } - pub(crate) fn to_index_ids_raw(&self) -> &[T] { - self.ids + pub(crate) fn to_max_primary_raw(&self) -> &[T] { + self.max_p } - pub(crate) fn to_min_query_raw(&self) -> &[T] { - self.min_q + pub(crate) fn to_min_secondary_raw(&self) -> &[T] { + self.min_s } - pub(crate) fn to_max_query_raw(&self) -> &[T] { - self.max_q + pub(crate) fn to_max_secondary_raw(&self) -> &[T] { + self.max_s } pub(crate) fn to_overflow_raw(&self) -> &T { @@ -185,22 +278,6 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { self.ph } - /// Pad the input slice `t` to `CURVE_TARGET_LEN`; this method should be employed - /// to ensure that the slice representing the first output value has always the - /// expected length - pub(crate) fn pad_slice_to_curve_len(t: &[T]) -> Vec { - let mut result = t.to_vec(); - assert!(CURVE_TARGET_LEN >= result.len()); - let diff = CURVE_TARGET_LEN - result.len(); - result.extend_from_slice(vec![result[0].clone(); diff].as_slice()); - result - } - - /// Remove the padding introduced by `pad_slice_to_curve_len` - pub(crate) fn truncate_slice_to_u256_raw(t: &[T]) -> &[T] { - &t[..NUM_LIMBS] - } - pub fn from_slice(input: &'a [T]) -> Self { assert!( input.len() >= Self::total_len(), @@ -212,29 +289,30 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { v: &input[Self::PI_RANGES[1].clone()], count: &input[Self::PI_RANGES[2].clone()][0], ops: &input[Self::PI_RANGES[3].clone()], - i: &input[Self::PI_RANGES[4].clone()], - min: &input[Self::PI_RANGES[5].clone()], - max: &input[Self::PI_RANGES[6].clone()], - ids: &input[Self::PI_RANGES[7].clone()], - min_q: &input[Self::PI_RANGES[8].clone()], - max_q: &input[Self::PI_RANGES[9].clone()], + left_row: &input[Self::PI_RANGES[4].clone()], + right_row: &input[Self::PI_RANGES[5].clone()], + min_p: &input[Self::PI_RANGES[6].clone()], + max_p: &input[Self::PI_RANGES[7].clone()], + min_s: &input[Self::PI_RANGES[8].clone()], + max_s: &input[Self::PI_RANGES[9].clone()], overflow: &input[Self::PI_RANGES[10].clone()][0], ch: &input[Self::PI_RANGES[11].clone()], ph: &input[Self::PI_RANGES[12].clone()], } } + #[allow(clippy::too_many_arguments)] pub fn new( h: &'a [T], v: &'a [T], count: &'a [T], ops: &'a [T], - i: &'a [T], - min: &'a [T], - max: &'a [T], - ids: &'a [T], - min_q: &'a [T], - max_q: &'a [T], + left_row: &'a [T], + right_row: &'a [T], + min_p: &'a [T], + max_p: &'a [T], + min_s: &'a [T], + max_s: &'a [T], overflow: &'a [T], ch: &'a [T], ph: &'a [T], @@ -244,12 +322,12 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { v, count: &count[0], ops, - i, - min, - max, - ids, - 
min_q, - max_q, + left_row, + right_row, + min_p, + max_p, + min_s, + max_s, overflow: &overflow[0], ch, ph, @@ -262,12 +340,12 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { .chain(self.v.iter()) .chain(once(self.count)) .chain(self.ops.iter()) - .chain(self.i.iter()) - .chain(self.min.iter()) - .chain(self.max.iter()) - .chain(self.ids.iter()) - .chain(self.min_q.iter()) - .chain(self.max_q.iter()) + .chain(self.left_row.iter()) + .chain(self.right_row.iter()) + .chain(self.min_p.iter()) + .chain(self.max_p.iter()) + .chain(self.min_s.iter()) + .chain(self.max_s.iter()) .chain(once(self.overflow)) .chain(self.ch.iter()) .chain(self.ph.iter()) @@ -276,7 +354,9 @@ impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { } } -impl PublicInputCommon for PublicInputs<'_, Target, S> { +impl PublicInputCommon + for PublicInputsFactory<'_, Target, S, UNIVERSAL_CIRCUIT> +{ const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; fn register_args(&self, cb: &mut CBuilder) { @@ -284,43 +364,39 @@ impl PublicInputCommon for PublicInputs<'_, Target, S> { cb.register_public_inputs(self.v); cb.register_public_input(*self.count); cb.register_public_inputs(self.ops); - cb.register_public_inputs(self.i); - cb.register_public_inputs(self.min); - cb.register_public_inputs(self.max); - cb.register_public_inputs(self.ids); - cb.register_public_inputs(self.min_q); - cb.register_public_inputs(self.max_q); + cb.register_public_inputs(self.left_row); + cb.register_public_inputs(self.right_row); + cb.register_public_inputs(self.min_p); + cb.register_public_inputs(self.max_p); + cb.register_public_inputs(self.min_s); + cb.register_public_inputs(self.max_s); cb.register_public_input(*self.overflow); cb.register_public_inputs(self.ch); cb.register_public_inputs(self.ph); } } -impl PublicInputs<'_, Target, S> { +impl + PublicInputsFactory<'_, Target, S, UNIVERSAL_CIRCUIT> +{ pub fn tree_hash_target(&self) -> HashOutTarget { HashOutTarget::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length } /// Return the first output value as a `CurveTarget` pub fn first_value_as_curve_target(&self) -> CurveTarget { let targets = self.to_values_raw(); - CurveTarget::from_targets(&targets[..CURVE_TARGET_LEN]) + CurveOrU256Target::from_targets(targets).as_curve_target() } /// Return the first output value as a `UInt256Target` pub fn first_value_as_u256_target(&self) -> UInt256Target { - let targets = Self::truncate_slice_to_u256_raw(self.to_values_raw()); - UInt256Target::from_targets(targets) + let targets = self.to_values_raw(); + CurveOrU256Target::from_targets(targets).as_u256_target() } /// Return the `UInt256` targets for the last `S-1` values pub fn values_target(&self) -> [UInt256Target; S - 1] { - let targets = &self.to_values_raw()[CURVE_TARGET_LEN..]; - targets - .chunks(NUM_LIMBS) - .map(UInt256Target::from_targets) - .collect_vec() - .try_into() - .unwrap() + OutputValuesTarget::from_targets(self.to_values_raw()).other_outputs } /// Return the value as a `UInt256Target` at the specified index @@ -328,11 +404,7 @@ impl PublicInputs<'_, Target, S> { where [(); S - 1]:, { - if i == 0 { - self.first_value_as_u256_target() - } else { - self.values_target()[i - 1].clone() - } + OutputValuesTarget::from_targets(self.to_values_raw()).value_target_at_index(i) } pub fn num_matching_rows_target(&self) -> Target { @@ -343,64 +415,124 @@ impl PublicInputs<'_, Target, S> { self.to_ops_raw().try_into().unwrap() } - pub fn index_value_target(&self) -> UInt256Target { - 
UInt256Target::from_targets(self.to_index_value_raw()) + pub fn min_primary_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_min_primary_raw()) } - pub fn min_value_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_min_value_raw()) + pub fn max_primary_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_max_primary_raw()) } - pub fn max_value_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_max_value_raw()) + pub fn overflow_flag_target(&self) -> BoolTarget { + BoolTarget::new_unsafe(*self.to_overflow_raw()) } - pub fn index_ids_target(&self) -> [Target; 2] { - self.to_index_ids_raw().try_into().unwrap() + pub fn computational_hash_target(&self) -> HashOutTarget { + HashOutTarget::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length } - pub fn min_query_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_min_query_raw()) + pub fn placeholder_hash_target(&self) -> HashOutTarget { + HashOutTarget::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length } +} - pub fn max_query_target(&self) -> UInt256Target { - UInt256Target::from_targets(self.to_max_query_raw()) +impl PublicInputsQueryCircuits<'_, Target, S> { + pub(crate) fn left_boundary_row_target(&self) -> BoundaryRowDataTarget { + BoundaryRowDataTarget::from_targets(self.to_left_row_raw()) } - pub fn overflow_flag_target(&self) -> BoolTarget { - BoolTarget::new_unsafe(*self.to_overflow_raw()) + pub(crate) fn right_boundary_row_target(&self) -> BoundaryRowDataTarget { + BoundaryRowDataTarget::from_targets(self.to_right_row_raw()) } - pub fn computational_hash_target(&self) -> HashOutTarget { - HashOutTarget::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + pub(crate) fn to_row_chunk_target(&self) -> RowChunkDataTarget + where + [(); S - 1]:, + { + RowChunkDataTarget:: { + left_boundary_row: self.left_boundary_row_target(), + right_boundary_row: self.right_boundary_row_target(), + chunk_outputs: UniversalQueryOutputWires { + tree_hash: self.tree_hash_target(), + values: OutputValuesTarget::from_targets(self.to_values_raw()), + count: self.num_matching_rows_target(), + num_overflows: self.overflow_flag_target().target, + }, + } } - pub fn placeholder_hash_target(&self) -> HashOutTarget { - HashOutTarget::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + /// Build an instance of `RowChunkDataTarget` from `self`; if `is_non_dummy_chunk` is + /// `false`, then build an instance of `RowChunkDataTarget` for a dummy chunk + pub(crate) fn to_dummy_row_chunk_target( + &self, + b: &mut CBuilder, + is_non_dummy_chunk: BoolTarget, + ) -> RowChunkDataTarget + where + [(); S - 1]:, + { + let dummy_values = compute_dummy_output_targets(b, &self.operation_ids_target()); + let output_values = self + .to_values_raw() + .iter() + .zip_eq(&dummy_values) + .map(|(&value, &dummy_value)| b.select(is_non_dummy_chunk, value, dummy_value)) + .collect_vec(); + + RowChunkDataTarget:: { + left_boundary_row: self.left_boundary_row_target(), + right_boundary_row: self.right_boundary_row_target(), + chunk_outputs: UniversalQueryOutputWires { + tree_hash: self.tree_hash_target(), + values: OutputValuesTarget::from_targets(&output_values), + // `count` is zeroed if chunk is dummy + count: b.mul(self.num_matching_rows_target(), 
is_non_dummy_chunk.target), + num_overflows: self.overflow_flag_target().target, + }, + } + } + + pub fn min_secondary_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_min_secondary_raw()) + } + + pub fn max_secondary_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_max_secondary_raw()) + } +} + +impl PublicInputsUniversalCircuit<'_, Target, S> { + pub fn secondary_index_value_target(&self) -> UInt256Target { + // secondary index value is found in `self.min_s` for + // `PublicInputsUniversalCircuit` + UInt256Target::from_targets(self.min_s) + } + + pub fn primary_index_value_target(&self) -> UInt256Target { + // primary index value is found in `self.max_s` for + // `PublicInputsUniversalCircuit` + UInt256Target::from_targets(self.max_s) } } -impl PublicInputs<'_, F, S> { +impl PublicInputsFactory<'_, F, S, UNIVERSAL_CIRCUIT> +where + [(); S - 1]:, +{ pub fn tree_hash(&self) -> HashOut { HashOut::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length } pub fn first_value_as_curve_point(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(&self.to_values_raw()[..CURVE_TARGET_LEN]) + OutputValues::::from_fields(self.to_values_raw()).first_value_as_curve_point() } pub fn first_value_as_u256(&self) -> U256 { - let fields = Self::truncate_slice_to_u256_raw(self.to_values_raw()); - U256::from_fields(fields) + OutputValues::::from_fields(self.to_values_raw()).first_value_as_u256() } pub fn values(&self) -> [U256; S - 1] { - self.to_values_raw()[CURVE_TARGET_LEN..] - .chunks(NUM_LIMBS) - .map(U256::from_fields) - .collect_vec() - .try_into() - .unwrap() + OutputValues::::from_fields(self.to_values_raw()).other_outputs } /// Return the value as a UInt256 at the specified index @@ -408,11 +540,7 @@ impl PublicInputs<'_, F, S> { where [(); S - 1]:, { - if i == 0 { - self.first_value_as_u256() - } else { - self.values()[i - 1] - } + OutputValues::::from_fields(self.to_values_raw()).value_at_index(i) } pub fn num_matching_rows(&self) -> F { @@ -423,28 +551,12 @@ impl PublicInputs<'_, F, S> { self.to_ops_raw().try_into().unwrap() } - pub fn index_value(&self) -> U256 { - U256::from_fields(self.to_index_value_raw()) - } - - pub fn min_value(&self) -> U256 { - U256::from_fields(self.to_min_value_raw()) - } - - pub fn max_value(&self) -> U256 { - U256::from_fields(self.to_max_value_raw()) + pub fn min_primary(&self) -> U256 { + U256::from_fields(self.to_min_primary_raw()) } - pub fn index_ids(&self) -> [F; 2] { - self.to_index_ids_raw().try_into().unwrap() - } - - pub fn min_query_value(&self) -> U256 { - U256::from_fields(self.to_min_query_raw()) - } - - pub fn max_query_value(&self) -> U256 { - U256::from_fields(self.to_max_query_raw()) + pub fn max_primary(&self) -> U256 { + U256::from_fields(self.to_max_primary_raw()) } pub fn overflow_flag(&self) -> bool { @@ -462,9 +574,32 @@ impl PublicInputs<'_, F, S> { } } -#[cfg(test)] -mod tests { +impl PublicInputsQueryCircuits<'_, F, S> { + pub fn min_secondary(&self) -> U256 { + U256::from_fields(self.to_min_secondary_raw()) + } + + pub fn max_secondary(&self) -> U256 { + U256::from_fields(self.to_max_secondary_raw()) + } +} + +impl PublicInputsUniversalCircuit<'_, F, S> { + pub fn secondary_index_value(&self) -> U256 { + // secondary index value is found in `self.min_s` for + // `PublicInputsUniversalCircuit` + U256::from_fields(self.min_s) + } + pub fn primary_index_value(&self) -> U256 { + // primary index value is found in `self.max_s` for + // 
`PublicInputsUniversalCircuit` + U256::from_fields(self.max_s) + } +} + +#[cfg(test)] +pub(crate) mod tests { use mp2_common::{public_inputs::PublicInputCommon, utils::ToFields, C, D, F}; use mp2_test::{ circuit::{run_circuit, UserCircuit}, @@ -478,9 +613,7 @@ mod tests { plonk::circuit_builder::CircuitBuilder, }; - use crate::query::public_inputs::QueryPublicInputs; - - use super::PublicInputs; + use super::{PublicInputsQueryCircuits, QueryPublicInputs}; const S: usize = 10; #[derive(Clone, Debug)] @@ -492,8 +625,10 @@ mod tests { type Wires = Vec; fn build(c: &mut CircuitBuilder) -> Self::Wires { - let targets = c.add_virtual_target_arr::<{ PublicInputs::::total_len() }>(); - let pi_targets = PublicInputs::::from_slice(targets.as_slice()); + let targets = c + .add_virtual_target_arr::<{ PublicInputsQueryCircuits::::total_len() }>( + ); + let pi_targets = PublicInputsQueryCircuits::::from_slice(targets.as_slice()); pi_targets.register_args(c); pi_targets.to_vec() } @@ -504,60 +639,65 @@ mod tests { } #[test] - fn test_query_public_inputs() { - let pis_raw: Vec = random_vector::(PublicInputs::::total_len()).to_fields(); - let pis = PublicInputs::::from_slice(pis_raw.as_slice()); + fn test_batching_query_public_inputs() { + let pis_raw: Vec = + random_vector::(PublicInputsQueryCircuits::::total_len()).to_fields(); + let pis = PublicInputsQueryCircuits::::from_slice(pis_raw.as_slice()); // check public inputs are constructed correctly assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::TreeHash)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::TreeHash)], pis.to_hash_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::OutputValues)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::OutputValues)], pis.to_values_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::NumMatching)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::NumMatching)], &[*pis.to_count_raw()], ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::OpIds)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::OpIds)], pis.to_ops_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::IndexValue)], - pis.to_index_value_raw(), + &pis_raw + [PublicInputsQueryCircuits::::to_range(QueryPublicInputs::LeftBoundaryRow)], + pis.to_left_row_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinValue)], - pis.to_min_value_raw(), + &pis_raw + [PublicInputsQueryCircuits::::to_range(QueryPublicInputs::RightBoundaryRow)], + pis.to_right_row_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxValue)], - pis.to_max_value_raw(), + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::MinPrimary)], + pis.to_min_primary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MinQuery)], - pis.to_min_query_raw(), + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::MaxPrimary)], + pis.to_max_primary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::MaxQuery)], - pis.to_max_query_raw(), + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::MinSecondary)], + pis.to_min_secondary_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::IndexIds)], - pis.to_index_ids_raw(), + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::MaxSecondary)], + pis.to_max_secondary_raw(), ); assert_eq!( - 
&pis_raw[PublicInputs::::to_range(QueryPublicInputs::Overflow)], + &pis_raw[PublicInputsQueryCircuits::::to_range(QueryPublicInputs::Overflow)], &[*pis.to_overflow_raw()], ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::ComputationalHash)], + &pis_raw + [PublicInputsQueryCircuits::::to_range(QueryPublicInputs::ComputationalHash)], pis.to_computational_hash_raw(), ); assert_eq!( - &pis_raw[PublicInputs::::to_range(QueryPublicInputs::PlaceholderHash)], + &pis_raw + [PublicInputsQueryCircuits::::to_range(QueryPublicInputs::PlaceholderHash)], pis.to_placeholder_hash_raw(), ); // use public inputs in circuit diff --git a/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs b/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs new file mode 100644 index 000000000..8cd1d1ec4 --- /dev/null +++ b/verifiable-db/src/query/row_chunk_gadgets/aggregate_chunks.rs @@ -0,0 +1,770 @@ +use mp2_common::{ + types::CBuilder, + u256::UInt256Target, + utils::{FromTargets, SelectTarget}, +}; +use plonky2::iop::target::{BoolTarget, Target}; + +use crate::query::universal_circuit::universal_query_gadget::{ + OutputValuesTarget, UniversalQueryOutputWires, +}; + +use super::{consecutive_rows::are_consecutive_rows, BoundaryRowDataTarget, RowChunkDataTarget}; + +/// This method aggregates the 2 chunks `first` and `second`, also checking +/// that they are consecutive. The returned aggregated chunk will +/// correspond to first if `is_second_dummy` flag is true +#[allow(dead_code)] // only in this PR +pub(crate) fn aggregate_chunks( + b: &mut CBuilder, + first: &RowChunkDataTarget, + second: &RowChunkDataTarget, + primary_query_bounds: (&UInt256Target, &UInt256Target), + secondary_query_bounds: (&UInt256Target, &UInt256Target), + ops: &[Target; MAX_NUM_RESULTS], + is_second_non_dummy: &BoolTarget, +) -> RowChunkDataTarget +where + [(); MAX_NUM_RESULTS - 1]:, +{ + let (min_query_primary, max_query_primary) = primary_query_bounds; + let (min_query_secondary, max_query_secondary) = secondary_query_bounds; + let _true = b._true(); + // check that right boundary row of chunk1 and left boundary row of chunk2 + // are consecutive + let are_consecutive = are_consecutive_rows( + b, + &first.right_boundary_row, + &second.left_boundary_row, + min_query_primary, + max_query_primary, + min_query_secondary, + max_query_secondary, + ); + // assert that the 2 chunks are consecutive only if the second one is not dummy + let are_consecutive = b.and(are_consecutive, *is_second_non_dummy); + b.connect(are_consecutive.target, is_second_non_dummy.target); + + // check the same root of the index tree is employed in both chunks to prove + // membership of rows in the chunks + b.connect_hashes( + first.chunk_outputs.tree_hash, + second.chunk_outputs.tree_hash, + ); + // sum the number of matching rows of the 2 chunks + let count = b.add(first.chunk_outputs.count, second.chunk_outputs.count); + + // aggregate output values. 
+ + // aggregate output values. Note that we can aggregate outputs also if chunk2 is + dummy, since the universal query gadget guarantees that dummy rows' output + values won't affect the final output values + let mut output_values = vec![]; + let values = [ + first.chunk_outputs.values.clone(), + second.chunk_outputs.values.clone(), + ]; + + let mut num_overflows = b.add( + first.chunk_outputs.num_overflows, + second.chunk_outputs.num_overflows, + ); + for (i, op) in ops.iter().enumerate() { + let (output, overflows) = OutputValuesTarget::aggregate_outputs(b, &values, *op, i); + output_values.extend_from_slice(&output); + num_overflows = b.add(num_overflows, overflows); + } + + RowChunkDataTarget { + left_boundary_row: first.left_boundary_row.clone(), + right_boundary_row: // if `is_second_non_dummy`, then the right boundary row of the aggregated chunk will + // be the right boundary row of second chunk, otherwise we keep right boundary row of first chunk for the + // aggregated chunk + BoundaryRowDataTarget::select( + b, + is_second_non_dummy, + &second.right_boundary_row, + &first.right_boundary_row, + ), + chunk_outputs: UniversalQueryOutputWires { + tree_hash: second.chunk_outputs.tree_hash, // we check it's the same between the 2 chunks + values: OutputValuesTarget::from_targets(&output_values), + count, + num_overflows, + }, + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use alloy::primitives::U256; + use itertools::Itertools; + use mp2_common::{ + array::ToField, + check_panic, + types::{CBuilder, HashOutput}, + u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, + utils::{FromFields, ToFields, ToTargets}, + C, D, F, + }; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::gen_random_u256, + }; + use plonky2::{ + field::types::{Field, PrimeField64, Sample}, + hash::hash_types::{HashOut, HashOutTarget}, + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::{circuit_builder::CircuitBuilder, config::GenericHashOut}, + }; + use rand::thread_rng; + + use crate::{ + query::{ + computational_hash_ids::{AggregationOperation, Identifiers}, + merkle_path::{ + tests::{build_node, generate_test_tree}, + MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, NeighborInfo, + }, + public_inputs::PublicInputsQueryCircuits, + row_chunk_gadgets::{ + tests::RowChunkData, BoundaryRowData, BoundaryRowDataTarget, BoundaryRowNodeInfo, + BoundaryRowNodeInfoTarget, RowChunkDataTarget, + }, + universal_circuit::universal_query_gadget::{ + OutputValues, OutputValuesTarget, UniversalQueryOutputWires, + }, + utils::{tests::aggregate_output_values, ChildPosition, NodeInfo}, + }, + test_utils::random_aggregation_operations, + }; + + use super::aggregate_chunks; + + const MAX_NUM_RESULTS: usize = 10; + const ROW_TREE_MAX_DEPTH: usize = 10; + const INDEX_TREE_MAX_DEPTH: usize = 3; + + /// Data structure for the input wires necessary to compute the `RowChunkData` associated + /// to a row chunk being tested + #[derive(Clone, Debug)] + struct RowChunkDataInputTarget { + left_boundary_row_path: MerklePathWithNeighborsTargetInputs, + left_boundary_index_path: MerklePathWithNeighborsTargetInputs, + left_boundary_row_value: UInt256Target, + left_boundary_row_subtree_hash: HashOutTarget, + left_boundary_index_value: UInt256Target, + right_boundary_row_path: MerklePathWithNeighborsTargetInputs, + right_boundary_index_path: MerklePathWithNeighborsTargetInputs, + right_boundary_row_value: UInt256Target, + right_boundary_row_subtree_hash: HashOutTarget, +
right_boundary_index_value: UInt256Target, + chunk_count: Target, + chunk_num_overflows: Target, + chunk_output_values: OutputValuesTarget, + } + + /// Data structure for input values necessary to compute the `RowChunkData` associated + /// to a row chunk being tested + #[derive(Clone, Debug)] + struct RowChunkDataInput { + left_boundary_row_path: MerklePathWithNeighborsGadget, + left_boundary_row_node: NodeInfo, + left_boundary_index_path: MerklePathWithNeighborsGadget, + left_boundary_index_node: NodeInfo, + right_boundary_row_path: MerklePathWithNeighborsGadget, + right_boundary_row_node: NodeInfo, + right_boundary_index_path: MerklePathWithNeighborsGadget, + right_boundary_index_node: NodeInfo, + chunk_count: F, + chunk_num_overflows: F, + chunk_output_values: OutputValues, + } + + impl RowChunkDataInput { + fn build( + b: &mut CBuilder, + primary_index_id: Target, + secondary_index_id: Target, + ) -> (RowChunkDataInputTarget, RowChunkDataTarget) { + let [left_boundary_row_value, left_boundary_index_value, right_boundary_row_value, right_boundary_index_value] = + b.add_virtual_u256_arr_unsafe(); + let [left_boundary_row_subtree_hash, right_boundary_row_subtree_hash] = + array::from_fn(|_| b.add_virtual_hash()); + let left_boundary_row_path = MerklePathWithNeighborsGadget::build( + b, + left_boundary_row_value.clone(), + left_boundary_row_subtree_hash, + secondary_index_id, + ); + let left_boundary_index_path = MerklePathWithNeighborsGadget::build( + b, + left_boundary_index_value.clone(), + left_boundary_row_path.root, + primary_index_id, + ); + let right_boundary_row_path = MerklePathWithNeighborsGadget::build( + b, + right_boundary_row_value.clone(), + right_boundary_row_subtree_hash, + secondary_index_id, + ); + let right_boundary_index_path = MerklePathWithNeighborsGadget::build( + b, + right_boundary_index_value.clone(), + right_boundary_row_path.root, + primary_index_id, + ); + + // Enforce that both boundary rows belong to the same tree + b.connect_hashes( + left_boundary_index_path.root, + right_boundary_index_path.root, + ); + + let left_boundary_row_info = BoundaryRowNodeInfoTarget::from(&left_boundary_row_path); + let left_boundary_index_info = + BoundaryRowNodeInfoTarget::from(&left_boundary_index_path); + let right_boundary_row_info = BoundaryRowNodeInfoTarget::from(&right_boundary_row_path); + let right_boundary_index_info = + BoundaryRowNodeInfoTarget::from(&right_boundary_index_path); + + let chunk_inputs = RowChunkDataInputTarget { + left_boundary_row_path: left_boundary_row_path.inputs, + left_boundary_index_path: left_boundary_index_path.inputs, + left_boundary_row_value, + left_boundary_row_subtree_hash, + left_boundary_index_value, + right_boundary_row_path: right_boundary_row_path.inputs, + right_boundary_index_path: right_boundary_index_path.inputs, + right_boundary_row_value, + right_boundary_row_subtree_hash, + right_boundary_index_value, + chunk_count: b.add_virtual_target(), + chunk_num_overflows: b.add_virtual_target(), + chunk_output_values: OutputValuesTarget::build(b), + }; + + let row_chunk = RowChunkDataTarget { + left_boundary_row: BoundaryRowDataTarget { + row_node_info: left_boundary_row_info, + index_node_info: left_boundary_index_info, + }, + right_boundary_row: BoundaryRowDataTarget { + row_node_info: right_boundary_row_info, + index_node_info: right_boundary_index_info, + }, + chunk_outputs: UniversalQueryOutputWires { + tree_hash: right_boundary_index_path.root, + values: chunk_inputs.chunk_output_values.clone(), + count: 
chunk_inputs.chunk_count, + num_overflows: chunk_inputs.chunk_num_overflows, + }, + }; + + (chunk_inputs, row_chunk) + } + + fn assign(&self, pw: &mut PartialWitness, wires: &RowChunkDataInputTarget) { + self.left_boundary_row_path + .assign(pw, &wires.left_boundary_row_path); + self.left_boundary_index_path + .assign(pw, &wires.left_boundary_index_path); + self.right_boundary_row_path + .assign(pw, &wires.right_boundary_row_path); + self.right_boundary_index_path + .assign(pw, &wires.right_boundary_index_path); + [ + ( + &wires.left_boundary_row_value, + self.left_boundary_row_node.value, + ), + ( + &wires.left_boundary_index_value, + self.left_boundary_index_node.value, + ), + ( + &wires.right_boundary_row_value, + self.right_boundary_row_node.value, + ), + ( + &wires.right_boundary_index_value, + self.right_boundary_index_node.value, + ), + ] + .into_iter() + .for_each(|(t, v)| pw.set_u256_target(t, v)); + [ + ( + wires.left_boundary_row_subtree_hash, + self.left_boundary_row_node.embedded_tree_hash, + ), + ( + wires.right_boundary_row_subtree_hash, + self.right_boundary_row_node.embedded_tree_hash, + ), + ] + .into_iter() + .for_each(|(t, v)| pw.set_hash_target(t, v)); + [ + (wires.chunk_count, self.chunk_count), + (wires.chunk_num_overflows, self.chunk_num_overflows), + ] + .into_iter() + .for_each(|(t, v)| pw.set_target(t, v)); + wires + .chunk_output_values + .set_target(pw, &self.chunk_output_values); + } + } + + #[derive(Clone, Debug)] + struct TestAggregateChunkWires { + first: RowChunkDataInputTarget, + second: RowChunkDataInputTarget, + min_query_primary: UInt256Target, + max_query_primary: UInt256Target, + min_query_secondary: UInt256Target, + max_query_secondary: UInt256Target, + primary_index_id: Target, + secondary_index_id: Target, + ops: [Target; MAX_NUM_RESULTS], + is_second_non_dummy: BoolTarget, + } + #[derive(Clone, Debug)] + struct TestAggregateChunks { + first: RowChunkDataInput, + second: RowChunkDataInput, + min_query_primary: Option, + max_query_primary: Option, + min_query_secondary: Option, + max_query_secondary: Option, + primary_index_id: F, + secondary_index_id: F, + ops: [F; MAX_NUM_RESULTS], + is_second_dummy: bool, + } + + impl UserCircuit for TestAggregateChunks { + type Wires = TestAggregateChunkWires; + + fn build(c: &mut CircuitBuilder) -> Self::Wires { + let [primary_index_id, secondary_index_id] = c.add_virtual_target_arr(); + let (first_chunk_inputs, first_chunk_data) = + RowChunkDataInput::build(c, primary_index_id, secondary_index_id); + let (second_chunk_inputs, second_chunk_data) = + RowChunkDataInput::build(c, primary_index_id, secondary_index_id); + let [min_query_primary, max_query_primary, min_query_secondary, max_query_secondary] = + c.add_virtual_u256_arr_unsafe(); + let ops = c.add_virtual_target_arr(); + let is_second_non_dummy = c.add_virtual_bool_target_unsafe(); + let aggregated_chunk = aggregate_chunks( + c, + &first_chunk_data, + &second_chunk_data, + (&min_query_primary, &max_query_primary), + (&min_query_secondary, &max_query_secondary), + &ops, + &is_second_non_dummy, + ); + + c.register_public_inputs(&aggregated_chunk.to_targets()); + + TestAggregateChunkWires { + first: first_chunk_inputs, + second: second_chunk_inputs, + min_query_primary, + max_query_primary, + min_query_secondary, + max_query_secondary, + primary_index_id, + secondary_index_id, + ops, + is_second_non_dummy, + } + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.first.assign(pw, &wires.first); + self.second.assign(pw, 
&wires.second); + [ + ( + &wires.min_query_primary, + self.min_query_primary.unwrap_or(U256::ZERO), + ), + ( + &wires.max_query_primary, + self.max_query_primary.unwrap_or(U256::MAX), + ), + ( + &wires.min_query_secondary, + self.min_query_secondary.unwrap_or(U256::ZERO), + ), + ( + &wires.max_query_secondary, + self.max_query_secondary.unwrap_or(U256::MAX), + ), + ] + .into_iter() + .for_each(|(t, v)| pw.set_u256_target(t, v)); + [ + (wires.primary_index_id, self.primary_index_id), + (wires.secondary_index_id, self.secondary_index_id), + ] + .into_iter() + .chain(wires.ops.into_iter().zip(self.ops)) + .for_each(|(t, v)| pw.set_target(t, v)); + pw.set_bool_target(wires.is_second_non_dummy, !self.is_second_dummy); + } + } + + fn test_aggregate_chunks(ops: [F; MAX_NUM_RESULTS]) { + let [primary_index_id, secondary_index_id] = F::rand_array(); + // generate a single rows tree that will contain the row chunks to be aggregated: no need to + use multiple rows trees in this test, as we already test the `are_consecutive_rows` gadget. + // The generated tree will have the following shape + // A + // B C + // D G + // E F + let [node_a, node_b, node_c, node_d, node_e, node_f, node_g] = + generate_test_tree(secondary_index_id, None); + let rows_tree_root = HashOutput::from(node_a.compute_node_hash(secondary_index_id)); + // build the node of the index tree that stores the rows tree being generated + let rng = &mut thread_rng(); + let index_node = build_node( + None, + None, + gen_random_u256(rng), + rows_tree_root, + primary_index_id, + ); + let root = index_node.compute_node_hash(primary_index_id); + + // generate the output values associated to each chunk + let inputs = PublicInputsQueryCircuits::::sample_from_ops::<2>(&ops); + let [(first_chunk_count, first_chunk_outputs, first_chunk_num_overflows), (second_chunk_count, second_chunk_outputs, second_chunk_num_overflows)] = + inputs + .into_iter() + .map(|input| { + let pis = PublicInputsQueryCircuits::::from_slice( + input.as_slice(), + ); + ( + pis.num_matching_rows(), + OutputValues::from_fields(pis.to_values_raw()), + F::from_canonical_u8(pis.overflow_flag() as u8), + ) + }) + .collect_vec() + .try_into() + .unwrap(); + + // the first row chunk for this test is given by nodes `B`, `D`, `E` and `F`. So left boundary row is `E` and + right boundary row is `B` + let path_e = vec![ + (node_d, ChildPosition::Left), + (node_b, ChildPosition::Left), + (node_a, ChildPosition::Left), + ]; + let node_f_hash = HashOutput::from(node_f.compute_node_hash(secondary_index_id)); + let node_c_hash = HashOutput::from(node_c.compute_node_hash(secondary_index_id)); + let siblings_e = vec![Some(node_f_hash), None, Some(node_c_hash)]; + let merkle_path_inputs_e = MerklePathWithNeighborsGadget::::new( + &path_e, + &siblings_e, + &node_e, + [None, None], // it's a leaf node + ) + .unwrap(); + + let path_b = vec![(node_a, ChildPosition::Left)]; + let siblings_b = vec![Some(node_c_hash)]; + let merkle_path_inputs_b = MerklePathWithNeighborsGadget::::new( + &path_b, + &siblings_b, + &node_b, + [Some(node_d), None], + ) + .unwrap(); + + let index_node_path = vec![]; + let index_node_siblings = vec![]; + let index_node_merkle_path = MerklePathWithNeighborsGadget::::new( + &index_node_path, + &index_node_siblings, + &index_node, + [None, None], + ) + .unwrap(); + let first_chunk = RowChunkDataInput { + left_boundary_row_path: merkle_path_inputs_e, + left_boundary_row_node: node_e, + left_boundary_index_path: index_node_merkle_path, + left_boundary_index_node: index_node, + right_boundary_row_path: merkle_path_inputs_b, + right_boundary_row_node: node_b, + right_boundary_index_path: index_node_merkle_path, + right_boundary_index_node: index_node, + chunk_count: first_chunk_count, + chunk_num_overflows: first_chunk_num_overflows, + chunk_output_values: first_chunk_outputs.clone(), + }; + + // the second row chunk for this test is given by nodes `A`, `C`, and `G`. So left boundary row is `A` and + right boundary row is `G` + let path_a = vec![]; + let siblings_a = vec![]; + let merkle_path_inputs_a = MerklePathWithNeighborsGadget::::new( + &path_a, + &siblings_a, + &node_a, + [Some(node_b), Some(node_c)], + ) + .unwrap(); + + let path_g = vec![ + (node_c, ChildPosition::Right), + (node_a, ChildPosition::Right), + ]; + let node_b_hash = HashOutput::from(node_b.compute_node_hash(secondary_index_id)); + let siblings_g = vec![None, Some(node_b_hash)]; + let merkle_path_inputs_g = MerklePathWithNeighborsGadget::::new( + &path_g, + &siblings_g, + &node_g, + [None, None], + ) + .unwrap(); + + let second_chunk = RowChunkDataInput { + left_boundary_row_path: merkle_path_inputs_a, + left_boundary_row_node: node_a, + left_boundary_index_path: index_node_merkle_path, + left_boundary_index_node: index_node, + right_boundary_row_path: merkle_path_inputs_g, + right_boundary_row_node: node_g, + right_boundary_index_path: index_node_merkle_path, + right_boundary_index_node: index_node, + chunk_count: second_chunk_count, + chunk_num_overflows: second_chunk_num_overflows, + chunk_output_values: second_chunk_outputs.clone(), + }; + + let circuit = TestAggregateChunks { + first: first_chunk.clone(), + second: second_chunk.clone(), + min_query_primary: None, + max_query_primary: None, + min_query_secondary: None, + max_query_secondary: None, + primary_index_id, + secondary_index_id, + ops, + is_second_dummy: false, + }; + + let proof = run_circuit::(circuit); + // compute expected aggregated chunk + let node_e_info = BoundaryRowNodeInfo { + end_node_hash: node_e.compute_node_hash(secondary_index_id), + predecessor_info: NeighborInfo::new_dummy_predecessor(), + successor_info: NeighborInfo::new( + node_d.value, + Some(node_d.compute_node_hash(secondary_index_id)), + ), + }; + let index_node_info = BoundaryRowNodeInfo { + end_node_hash: root, + predecessor_info: NeighborInfo::new_dummy_predecessor(), + successor_info: NeighborInfo::new_dummy_successor(), + }; + let node_g_info = BoundaryRowNodeInfo { + end_node_hash: node_g.compute_node_hash(secondary_index_id), + predecessor_info: NeighborInfo::new( + node_c.value, + Some(HashOut::from_bytes((&node_c_hash).into())), + ), + successor_info: NeighborInfo::new_dummy_successor(), + }; + let (expected_outputs, expected_num_overflows) = { + let outputs = [first_chunk_outputs.clone(), second_chunk_outputs.clone()]; + let mut num_overflows = first_chunk_num_overflows + second_chunk_num_overflows; + let outputs = ops + .into_iter() + .enumerate() + .flat_map(|(i, op)| { + let (out, overflows) = aggregate_output_values(i, &outputs, op); + num_overflows += F::from_canonical_u32(overflows); + out + }) + .collect_vec(); + ( + OutputValues::from_fields(&outputs), + num_overflows.to_canonical_u64(), + ) + }; + let expected_count = (first_chunk_count + second_chunk_count).to_canonical_u64(); + + let expected_chunk = RowChunkData:: { + left_boundary_row: BoundaryRowData { + row_node_info: node_e_info.clone(), + index_node_info: index_node_info.clone(), + }, + right_boundary_row: BoundaryRowData { + row_node_info: node_g_info, + index_node_info: index_node_info.clone(), + }, + chunk_tree_hash: root, + output_values: expected_outputs.clone(), + num_overflows: expected_num_overflows, + count: expected_count, + }; + + assert_eq!(proof.public_inputs, expected_chunk.to_fields()); + + // test with second chunk being dummy; we use a non-consecutive chunk as the dummy one: the row chunk + given by node_G only + let second_chunk = RowChunkDataInput { + left_boundary_row_path: merkle_path_inputs_g, + left_boundary_row_node: node_g, + left_boundary_index_path: index_node_merkle_path, + left_boundary_index_node: index_node, + right_boundary_row_path: merkle_path_inputs_g, + right_boundary_row_node: node_g, + right_boundary_index_path: index_node_merkle_path, + right_boundary_index_node: index_node, + chunk_count: second_chunk_count, + chunk_num_overflows: second_chunk_num_overflows, + chunk_output_values: second_chunk_outputs.clone(), + }; + let circuit = TestAggregateChunks { + first: first_chunk.clone(), + second: second_chunk.clone(), + min_query_primary: None, + max_query_primary: None, + min_query_secondary: None, + max_query_secondary: None, + primary_index_id, + secondary_index_id, + ops, + is_second_dummy: true, // we set the second chunk to dummy + }; + + let proof = run_circuit::(circuit); + // compute expected aggregated chunk + // since we aggregate with a dummy chunk, we expect right boundary row to be the same as the + first chunk, that is node_B + let node_b_info = BoundaryRowNodeInfo { + end_node_hash: HashOut::from_bytes((&node_b_hash).into()), + predecessor_info: NeighborInfo::new(node_f.value, None), + successor_info: NeighborInfo::new( + node_a.value, + Some(HashOut::from_bytes((&rows_tree_root).into())), + ), + }; + let expected_chunk = RowChunkData:: { + left_boundary_row: BoundaryRowData { + row_node_info: node_e_info, + index_node_info: index_node_info.clone(), + }, + right_boundary_row: BoundaryRowData { + row_node_info: node_b_info, + index_node_info: index_node_info.clone(), + }, + chunk_tree_hash: root, + output_values: expected_outputs.clone(), + num_overflows: expected_num_overflows, + count: expected_count, + }; + assert_eq!(proof.public_inputs, expected_chunk.to_fields());
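A note on the expected values computed above: each aggregation operation combines the two chunks' outputs with wrap-around arithmetic, and the overflow counter records how many of those combinations wrapped. The toy sketch below shows that bookkeeping with `u64` in place of `U256`; it assumes SUM-style aggregation and is not the actual `aggregate_output_values` helper.

```rust
/// Toy counterpart of the expected-output computation, with u64 standing in
/// for U256 (illustrative only; the real helper handles several operations).
fn aggregate_sum(first: u64, second: u64) -> (u64, u32) {
    // wrapping add, plus a counter recording whether the addition overflowed
    let (out, overflowed) = first.overflowing_add(second);
    (out, overflowed as u32)
}

fn main() {
    // an overflowing aggregation wraps around and bumps the counter
    assert_eq!(aggregate_sum(u64::MAX, 2), (1, 1));
    // a non-overflowing aggregation leaves the counter untouched
    assert_eq!(aggregate_sum(40, 2), (42, 0));
}
```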
+ + // negative test: check that we cannot aggregate non-consecutive non-dummy chunks + let circuit = TestAggregateChunks { + first: first_chunk.clone(), + second: second_chunk.clone(), + min_query_primary: None, + max_query_primary: None, + min_query_secondary: None, + max_query_secondary: None, + primary_index_id, + secondary_index_id, + ops, + is_second_dummy: false, + }; + + check_panic!( + || run_circuit::(circuit), + "circuit didn't fail when aggregating non-consecutive non-dummy chunks" + ); + + // negative test: check that we cannot aggregate a chunk with a wrong merkle root + // we build the second chunk employing a fake index node + let fake_node = build_node( + None, + None, + gen_random_u256(rng), + rows_tree_root, + primary_index_id, + ); + let fake_node_merkle_path = MerklePathWithNeighborsGadget::::new( + &[], + &[], + &fake_node, + [None, None], + ) + .unwrap(); + let second_chunk = RowChunkDataInput { + left_boundary_row_path: merkle_path_inputs_a, + left_boundary_row_node: node_a, + left_boundary_index_path: fake_node_merkle_path, + left_boundary_index_node: fake_node, + right_boundary_row_path: merkle_path_inputs_g, + right_boundary_row_node: node_g, + right_boundary_index_path: fake_node_merkle_path, + right_boundary_index_node: fake_node, + chunk_count: second_chunk_count, + chunk_num_overflows: second_chunk_num_overflows, + chunk_output_values: second_chunk_outputs.clone(), + }; + + let circuit = TestAggregateChunks { + first: first_chunk.clone(), + second: second_chunk.clone(), + min_query_primary: None, + max_query_primary: None, + min_query_secondary: None, + max_query_secondary: None, + primary_index_id, + secondary_index_id, + ops, + is_second_dummy: false, + }; + + check_panic!( + || run_circuit::(circuit), + "circuit didn't fail when aggregating chunks with different merkle roots" + ); + } + + #[test] + fn test_aggregate_chunks_random_operations() { + let ops = random_aggregation_operations(); + + test_aggregate_chunks(ops); + } + + #[test] + fn test_aggregate_chunks_with_id_operation() { + // Generate the random operations. + let mut ops = random_aggregation_operations(); + + // Set the first operation to ID for testing the digest computation. + ops[0] = Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); + + test_aggregate_chunks(ops); + } +} diff --git a/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs b/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs new file mode 100644 index 000000000..984f6828d --- /dev/null +++ b/verifiable-db/src/query/row_chunk_gadgets/consecutive_rows.rs @@ -0,0 +1,1293 @@ +use mp2_common::{ + types::CBuilder, + u256::{CircuitBuilderU256, UInt256Target}, + utils::HashBuilder, + F, +}; +use plonky2::{field::types::Field, iop::target::BoolTarget}; + +use super::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget}; + +/// This method checks whether two nodes `first` and `second` are consecutive, according +/// to the definition found in the docs +/// (https://www.notion.so/lagrangelabs/Aggregating-Query-Results-with-Individual-Merkle-Paths-10628d1c65a880b1b151d4ac017fa445?pvs=4#10d28d1c65a8804fb11ed5d14fa70ea3) +/// The query bounds provided as inputs refer to either the secondary or primary index, +/// depending on whether the nodes are in a rows tree or in the index tree. +/// The method returns 2 flags: +/// - The first one being true iff the 2 nodes are consecutive +/// - The second one being true iff the successor of first node is found and its value is in the range +/// specified by the query bounds provided as inputs +#[allow(dead_code)] // only in this PR +fn are_consecutive_nodes( + b: &mut CBuilder, + first: &BoundaryRowNodeInfoTarget, + second: &BoundaryRowNodeInfoTarget, + min_query_bound: &UInt256Target, + max_query_bound: &UInt256Target, + are_rows_tree_nodes: bool, +) -> (BoolTarget, BoolTarget) { + let mut are_consecutive = b._true(); + let first_node_successor_value = &first.successor_info.value; + // ensure that we don't prove nodes outside of the range: the successor of the + first node must store a value bigger than `min_query_bound` + let bigger_than_min = b.is_less_or_equal_than_u256(min_query_bound, first_node_successor_value); + are_consecutive = b.and(are_consecutive, bigger_than_min); + // determine whether the successor (if any) of the first node stores a value in the query range or not; + note that, since we previously checked that such value is >= min_query_bound, + we only need to check whether this value is not dummy (i.e., if the successor exists) and if + such value is <= max_query_bound + let smaller_than_max = + b.is_less_or_equal_than_u256(first_node_successor_value, max_query_bound); + let first_node_succ_in_range = b.and(smaller_than_max, first.successor_info.is_found); + // if first_node_succ_in_range is true, and the successor of the first node was found in the path from + such node to the root of the tree, then the hash of successor node will be placed in + `first.successor_info.hash` by `MerklePathWithNeighborsGadget`: therefore, we can check that `second` + is consecutive to `first` by checking that `first.successor_info.hash` is the hash of the second node; + otherwise, we cannot check right now that the 2 nodes are consecutive; we will do it later + let check_are_consecutive = b.and(first_node_succ_in_range, first.successor_info.is_in_path); + let is_second_node_successor = b.hash_eq(&first.successor_info.hash, &second.end_node_hash); + // update are_consecutive as `are_consecutive && is_second_node_successor` if `check_are_consecutive` is true + let new_are_consecutive = b.and(are_consecutive, is_second_node_successor); + are_consecutive = BoolTarget::new_unsafe(b.select( + check_are_consecutive, + new_are_consecutive.target, + are_consecutive.target, + )); + // we now look at the predecessor of second node, matching it with first node in case the + predecessor is found in the path of second node in the tree + let second_node_predecessor_value = &second.predecessor_info.value; + // ensure that we don't prove nodes outside of the range: the predecessor of the second + node must store a value smaller than `max_query_bound` + let smaller_than_max = + b.is_less_or_equal_than_u256(second_node_predecessor_value, max_query_bound); + are_consecutive = b.and(are_consecutive, smaller_than_max); + // determine whether the predecessor (if any) of the second node stores a value in the query range or not; + note that, since we previously checked that such value is <= max_query_bound, + we only need to check whether this value is not dummy (i.e., if the predecessor exists) and if + such value is >= min_query_bound + let bigger_than_min = + b.is_less_or_equal_than_u256(min_query_bound, second_node_predecessor_value); + let second_node_pred_in_range = b.and(bigger_than_min, second.predecessor_info.is_found); + // if second_node_pred_in_range is true, and the predecessor of the second node was found in the path from + such node to the root of the tree, then the hash of predecessor node will be placed in + `second.predecessor_info.hash` by `MerklePathWithNeighborsGadget`: therefore, we can check that `second` + is consecutive to `first` by checking that `second.predecessor_info.hash` is the hash of the first node; + otherwise, we cannot check right now that the 2 nodes are consecutive, and it necessarily means we have + already done it before when checking that the successor of first node was the second node + let check_are_consecutive = b.and( + second_node_pred_in_range, + second.predecessor_info.is_in_path, + ); + let is_second_node_successor = b.hash_eq(&second.predecessor_info.hash, &first.end_node_hash); + // update are_consecutive as `are_consecutive && is_second_node_successor` if `check_are_consecutive` is true + let new_are_consecutive = b.and(are_consecutive, is_second_node_successor); + are_consecutive = BoolTarget::new_unsafe(b.select( + check_are_consecutive, + new_are_consecutive.target, + are_consecutive.target, + )); + + // lastly, check that either successor of first node is located in the path, or the predecessor of second node + is located in the path, which is necessarily true if the 2 nodes are consecutive. Note that we need to enforce + this always if we need to "strictly" prove that 2 nodes are consecutive, which happens in the following cases: + // - if nodes are in the index tree + // - if nodes are in a rows tree, but `first_node_succ_in_range` is true. Indeed, if the successor of first node + is out of range or doesn't exist, then it means that second node belongs to another rows tree, and so it cannot + be a successor of first node in the same rows tree + let either_is_in_path = b.or( + first.successor_info.is_in_path, + second.predecessor_info.is_in_path, + ); + + if !are_rows_tree_nodes { + // in case of index tree, we need to enforce that `either_is_in_path` must be true + are_consecutive = b.and(are_consecutive, either_is_in_path); + // furthermore, we also need to enforce that first_node_succ_in_range and second_node_pred_in_range + are both true; otherwise, the prover could provide the nodes at the boundary and prove them + to be consecutive, which is not ok in the index tree + are_consecutive = b.and(are_consecutive, first_node_succ_in_range); + are_consecutive = b.and(are_consecutive, second_node_pred_in_range); + } else { + // in case of rows tree nodes, we need to check that `first_row_succ_in_range == second_row_pred_in_range`, + which should always hold for consecutive rows since: + // - if the successor of first row is in range, then second row must be its successor + in the same rows tree, and so the predecessor of second row is the first row itself, + which is expected to be in range since we never need to prove nodes not in range + but with a successor in range + // - if the successor of first row is out of range, then second row is expected to + be a node in the "subsequent" rows tree (i.e., the rows tree stored in the index + tree node which is the successor of the index tree node storing first row); this node + can be either: + // - the first node of the "subsequent" rows tree with value >= min_secondary; + in this case, the predecessor of second row is < min_secondary, and so out of range + // - if no such node can be found in the "subsequent" rows tree, then second row will be + the last node in the "subsequent" rows tree with value < min_secondary; in + this case, also its predecessor will necessarily be < min_secondary, and so + out of range + // we first compute first_row_succ_in_range XOR second_row_pred_in_range: a XOR b = a + b - 2*a*b + let range_flags_sum = b.add( + first_node_succ_in_range.target, + second_node_pred_in_range.target, + ); + let minus_2 = F::NEG_ONE + F::NEG_ONE; + let range_flags_xor = b.arithmetic( + minus_2, + F::ONE, + first_node_succ_in_range.target, + second_node_pred_in_range.target, + range_flags_sum, + ); + // then, `are_consecutive = are_consecutive AND NOT(range_flags_xor) = are_consecutive - are_consecutive*range_flags_xor` + are_consecutive = BoolTarget::new_unsafe(b.arithmetic( + F::NEG_ONE, + F::ONE, + are_consecutive.target, + range_flags_xor, + are_consecutive.target, + )); + // in case of nodes in a rows tree, we need to enforce that second is the successor of first only + if the nodes are in the same rows tree, that is if `first_node_succ_in_range` is true + let new_are_consecutive = b.and(are_consecutive, either_is_in_path); + are_consecutive = BoolTarget::new_unsafe(b.select( + first_node_succ_in_range, + new_are_consecutive.target, + are_consecutive.target, + )); + } + + (are_consecutive, first_node_succ_in_range) +} + +/// This method checks whether two rows `first` and `second` are consecutive, according +/// to the definition found in the docs +/// (https://www.notion.so/lagrangelabs/Aggregating-Query-Results-with-Individual-Merkle-Paths-10628d1c65a880b1b151d4ac017fa445?pvs=4#10d28d1c65a8804fb11ed5d14fa70ea3) +#[allow(dead_code)] // only in this PR +pub(crate) fn are_consecutive_rows( + b: &mut CBuilder, + first: &BoundaryRowDataTarget, + second: &BoundaryRowDataTarget, + min_query_primary: &UInt256Target, + max_query_primary: &UInt256Target, + min_query_secondary: &UInt256Target, + max_query_secondary: &UInt256Target, +) -> BoolTarget { + let (are_consecutive, first_row_succ_in_range) = are_consecutive_nodes( + b, + &first.row_node_info, + &second.row_node_info, + min_query_secondary, + max_query_secondary, + true, + ); + // at this stage we checked that the rows tree nodes storing first and second row are consecutive; we also need + to check index tree consistency.
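The arithmetic in `are_consecutive_nodes` above leans on two boolean-as-field identities: `a XOR b = a + b - 2ab` and `a AND NOT b = a - ab`, each of which fits in a single `arithmetic` gate. The snippet below merely sanity-checks the algebra over plain integers; it is not circuit code.

```rust
/// Check the boolean-as-integer identities used by the gadget:
/// XOR(a, b) = a + b - 2ab and AND(a, NOT b) = a - ab.
fn main() {
    for a in [0i64, 1] {
        for b in [0i64, 1] {
            // XOR via arithmetic, compared against the bitwise operator
            assert_eq!(a + b - 2 * a * b, a ^ b);
            // AND NOT via arithmetic, compared against the bitwise operators
            assert_eq!(a - a * b, a & (1 - b));
        }
    }
}
```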
+ // if first_row_succ_in_range is true, then both the rows must be in the same rows tree; so, we simply + // check this and we are done + let is_same_rows_tree = b.hash_eq( + &first.index_node_info.end_node_hash, + &second.index_node_info.end_node_hash, + ); + + // otherwise, if the rows are in different rows trees, we need to check that they are stored in subsequent + // rows trees + let (are_index_nodes_consecutive, _) = are_consecutive_nodes( + b, + &first.index_node_info, + &second.index_node_info, + min_query_primary, + max_query_primary, + false, + ); + // compute the flag to be accumulated in `are_consecutive`, depending on whether the 2 rows are in the same + // rows tree or not (i.e., whether first_row_succ_in_range is true) + let index_tree_check = BoolTarget::new_unsafe(b.select( + first_row_succ_in_range, + is_same_rows_tree.target, + are_index_nodes_consecutive.target, + )); + b.and(are_consecutive, index_tree_check) +} + +#[cfg(test)] +mod tests { + use std::array; + + use alloy::primitives::U256; + use mp2_common::{ + types::HashOutput, + u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, + utils::TryIntoBool, + C, D, F, + }; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::gen_random_u256, + }; + use plonky2::{ + field::types::Sample, + hash::hash_types::HashOutTarget, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::circuit_builder::CircuitBuilder, + }; + use rand::thread_rng; + + use crate::query::{ + merkle_path::{ + tests::{build_node, generate_test_tree}, + MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs, + }, + utils::{ChildPosition, NodeInfo}, + }; + + use super::{ + are_consecutive_nodes, are_consecutive_rows, BoundaryRowDataTarget, + BoundaryRowNodeInfoTarget, + }; + + const ROW_TREE_MAX_DEPTH: usize = 10; + const INDEX_TREE_MAX_DEPTH: usize = 15; + + #[derive(Clone, Debug)] + struct TestConsecutiveNodes + where + [(); MAX_DEPTH - 1]:, + { + first_node_path: MerklePathWithNeighborsGadget, + first_node_info: NodeInfo, + second_node_path: MerklePathWithNeighborsGadget, + second_node_info: NodeInfo, + index_id: F, + min_query_bound: Option, + max_query_bound: Option, + } + + #[derive(Clone, Debug)] + struct TestConsecutiveNodesWires + where + [(); MAX_DEPTH - 1]:, + { + first_node_path: MerklePathWithNeighborsTargetInputs, + first_node_value: UInt256Target, + first_node_tree_hash: HashOutTarget, + second_node_path: MerklePathWithNeighborsTargetInputs, + second_node_value: UInt256Target, + second_node_tree_hash: HashOutTarget, + index_id: Target, + min_query_bound: UInt256Target, + max_query_bound: UInt256Target, + } + + impl TestConsecutiveNodesWires + where + [(); MAX_DEPTH - 1]:, + { + fn new( + c: &mut CircuitBuilder, + ) -> (Self, BoundaryRowNodeInfoTarget, BoundaryRowNodeInfoTarget) { + let [first_node_value, second_node_value, min_query_bound, max_query_bound] = + c.add_virtual_u256_arr_unsafe(); + let [first_node_tree_hash, second_node_tree_hash] = + array::from_fn(|_| c.add_virtual_hash()); + let index_id = c.add_virtual_target(); + let first_node_path = MerklePathWithNeighborsGadget::build( + c, + first_node_value.clone(), + first_node_tree_hash, + index_id, + ); + let second_node_path = MerklePathWithNeighborsGadget::build( + c, + second_node_value.clone(), + second_node_tree_hash, + index_id, + ); + + let first_node = BoundaryRowNodeInfoTarget { + end_node_hash: first_node_path.end_node_hash, + predecessor_info: first_node_path.predecessor_info, + successor_info: 
first_node_path.successor_info, + }; + let second_node = BoundaryRowNodeInfoTarget { + end_node_hash: second_node_path.end_node_hash, + predecessor_info: second_node_path.predecessor_info, + successor_info: second_node_path.successor_info, + }; + + ( + Self { + first_node_path: first_node_path.inputs, + first_node_value, + first_node_tree_hash, + second_node_path: second_node_path.inputs, + second_node_value, + second_node_tree_hash, + index_id, + min_query_bound, + max_query_bound, + }, + first_node, + second_node, + ) + } + } + + impl UserCircuit + for TestConsecutiveNodes + where + [(); MAX_DEPTH - 1]:, + { + type Wires = TestConsecutiveNodesWires; + + fn build(c: &mut CircuitBuilder) -> Self::Wires { + let (wires, first_node, second_node) = TestConsecutiveNodesWires::new(c); + + let (are_consecutive, _) = are_consecutive_nodes( + c, + &first_node, + &second_node, + &wires.min_query_bound, + &wires.max_query_bound, + ROWS_TREE_NODES, + ); + + c.register_public_input(are_consecutive.target); + + wires + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.first_node_path.assign(pw, &wires.first_node_path); + self.second_node_path.assign(pw, &wires.second_node_path); + [ + (self.first_node_info.value, &wires.first_node_value), + (self.second_node_info.value, &wires.second_node_value), + ( + self.min_query_bound.unwrap_or(U256::ZERO), + &wires.min_query_bound, + ), + ( + self.max_query_bound.unwrap_or(U256::MAX), + &wires.max_query_bound, + ), + ] + .into_iter() + .for_each(|(value, target)| pw.set_u256_target(target, value)); + [ + ( + self.first_node_info.embedded_tree_hash, + wires.first_node_tree_hash, + ), + ( + self.second_node_info.embedded_tree_hash, + wires.second_node_tree_hash, + ), + ] + .into_iter() + .for_each(|(value, target)| pw.set_hash_target(target, value)); + pw.set_target(wires.index_id, self.index_id); + } + } + + #[derive(Clone, Debug)] + struct TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes, + index_tree_nodes: TestConsecutiveNodes, + } + + #[derive(Clone, Debug)] + struct TestConsecutiveRowsWires { + row_tree_nodes: TestConsecutiveNodesWires, + index_tree_nodes: TestConsecutiveNodesWires, + } + + impl UserCircuit for TestConsecutiveRows { + type Wires = TestConsecutiveRowsWires; + + fn build(c: &mut CircuitBuilder) -> Self::Wires { + let (row_tree_nodes, first_row_node, second_row_node) = + TestConsecutiveNodesWires::new(c); + let (index_tree_nodes, first_index_node, second_index_node) = + TestConsecutiveNodesWires::new(c); + let first = BoundaryRowDataTarget { + row_node_info: first_row_node, + index_node_info: first_index_node, + }; + let second = BoundaryRowDataTarget { + row_node_info: second_row_node, + index_node_info: second_index_node, + }; + let are_consecutive = are_consecutive_rows( + c, + &first, + &second, + &index_tree_nodes.min_query_bound, + &index_tree_nodes.max_query_bound, + &row_tree_nodes.min_query_bound, + &row_tree_nodes.max_query_bound, + ); + + c.register_public_input(are_consecutive.target); + + TestConsecutiveRowsWires { + row_tree_nodes, + index_tree_nodes, + } + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.row_tree_nodes.prove(pw, &wires.row_tree_nodes); + self.index_tree_nodes.prove(pw, &wires.index_tree_nodes); + } + } + + #[test] + fn test_are_consecutive_nodes() { + let index_id = F::rand(); + // Build the following Merkle-tree + // A + // B C + // D G + // E F + let [node_a, node_b, node_c, node_d, node_e, node_f, node_g] = + generate_test_tree(index_id, None); + 
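The pairs proven consecutive in the assertions below are exactly the adjacent pairs in the in-order traversal of this test tree: E, D, F, B, A, C, G. Here is a small std-only model of the same 7-node shape that makes that traversal explicit (node labels only; `generate_test_tree` assigns its own values and hashes):

```rust
/// In-order model of the 7-node test tree used in these tests:
///         A
///       B   C
///     D       G
///   E   F
/// Two nodes are "consecutive" exactly when they are adjacent in this order.
struct Node {
    label: char,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

fn leaf(label: char) -> Node {
    Node { label, left: None, right: None }
}

fn inorder(node: &Node, out: &mut Vec<char>) {
    if let Some(l) = &node.left {
        inorder(l, out);
    }
    out.push(node.label);
    if let Some(r) = &node.right {
        inorder(r, out);
    }
}

fn main() {
    let tree = Node {
        label: 'A',
        left: Some(Box::new(Node {
            label: 'B',
            left: Some(Box::new(Node {
                label: 'D',
                left: Some(Box::new(leaf('E'))),
                right: Some(Box::new(leaf('F'))),
            })),
            right: None,
        })),
        right: Some(Box::new(Node {
            label: 'C',
            left: None,
            right: Some(Box::new(leaf('G'))),
        })),
    };
    let mut order = vec![];
    inorder(&tree, &mut order);
    // matches the pairs proven consecutive below: (D, F), (A, C), (F, B), ...
    assert_eq!(order, vec!['E', 'D', 'F', 'B', 'A', 'C', 'G']);
}
```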
+ // test that nodes F and D are consecutive + let path_f = vec![ + (node_d, ChildPosition::Right), // we start from the ancestor of the start node of the path + (node_b, ChildPosition::Left), + (node_a, ChildPosition::Left), + ]; + let node_e_hash = HashOutput::from(node_e.compute_node_hash(index_id)); + let node_c_hash = HashOutput::from(node_c.compute_node_hash(index_id)); + let siblings_f = vec![Some(node_e_hash), None, Some(node_c_hash)]; + let merkle_path_inputs_f = MerklePathWithNeighborsGadget::::new( + &path_f, + &siblings_f, + &node_f, + [None, None], // it's a leaf node + ) + .unwrap(); + let path_d = vec![(node_b, ChildPosition::Left), (node_a, ChildPosition::Left)]; + let siblings_d = vec![None, Some(node_c_hash)]; + let merkle_path_inputs_d = MerklePathWithNeighborsGadget::::new( + &path_d, + &siblings_d, + &node_d, + [Some(node_e), Some(node_f)], + ) + .unwrap(); + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_d, + first_node_info: node_d, + second_node_path: merkle_path_inputs_f, + second_node_info: node_f, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // test that nodes A and C are consecutive + let path_a = vec![]; + let siblings_a = vec![]; + let merkle_path_inputs_a = MerklePathWithNeighborsGadget::::new( + &path_a, + &siblings_a, + &node_a, + [Some(node_b), Some(node_c)], + ) + .unwrap(); + let path_c = vec![(node_a, ChildPosition::Right)]; + let node_b_hash = HashOutput::from(node_b.compute_node_hash(index_id)); + let siblings_c = vec![Some(node_b_hash)]; + let merkle_path_inputs_c = MerklePathWithNeighborsGadget::::new( + &path_c, + &siblings_c, + &node_c, + [None, Some(node_g)], + ) + .unwrap(); + + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_a, + first_node_info: node_a, + second_node_path: merkle_path_inputs_c, + second_node_info: node_c, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // test that nodes F and B are consecutive + let path_b = vec![(node_a, ChildPosition::Left)]; + let siblings_b = vec![Some(node_c_hash)]; + let merkle_path_inputs_b = MerklePathWithNeighborsGadget::::new( + &path_b, + &siblings_b, + &node_b, + [Some(node_d), None], + ) + .unwrap(); + + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_f, + first_node_info: node_f, + second_node_path: merkle_path_inputs_b, + second_node_info: node_b, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // negative test: E and F are not consecutive + let path_e = vec![ + (node_d, ChildPosition::Left), + (node_b, ChildPosition::Left), + (node_a, ChildPosition::Left), + ]; + let node_f_hash = HashOutput::from(node_f.compute_node_hash(index_id)); + let siblings_e = vec![Some(node_f_hash), None, Some(node_c_hash)]; + let merkle_path_inputs_e = MerklePathWithNeighborsGadget::::new( + &path_e, + &siblings_e, + &node_e, + [None, None], // it's a leaf node + ) + .unwrap(); + + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_e, + first_node_info: node_e, + second_node_path: merkle_path_inputs_f, + 
second_node_info: node_f, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // negative test: A and B are not consecutive (wrong order) + let path_a = vec![]; + let siblings_a = vec![]; + let merkle_path_inputs_a = MerklePathWithNeighborsGadget::::new( + &path_a, + &siblings_a, + &node_a, + [Some(node_b), Some(node_c)], + ) + .unwrap(); + + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_a, + first_node_info: node_a, + second_node_path: merkle_path_inputs_b, + second_node_info: node_b, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // but B and A are consecutive + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_b, + first_node_info: node_b, + second_node_path: merkle_path_inputs_a, + second_node_info: node_a, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // check that if we are proving nodes in a rows tree, then we can prove that C and D are consecutive + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_c, + first_node_info: node_c, + second_node_path: merkle_path_inputs_d, + second_node_info: node_d, + index_id, + min_query_bound: Some(node_d.value), + max_query_bound: Some(node_c.value), + }; + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // instead, this is not possible if we are proving nodes in the index tree + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_c, + first_node_info: node_c, + second_node_path: merkle_path_inputs_f, + second_node_info: node_f, + index_id, + min_query_bound: Some(node_f.value), + max_query_bound: Some(node_c.value), + }; + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // check that if we are proving nodes in a rows tree, then we can prove that G and E are consecutive + let path_g = vec![ + (node_c, ChildPosition::Right), + (node_a, ChildPosition::Right), + ]; + let siblings_g = vec![None, Some(node_b_hash)]; + let merkle_path_inputs_g = MerklePathWithNeighborsGadget::::new( + &path_g, + &siblings_g, + &node_g, + [None, None], + ) + .unwrap(); + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_g, + first_node_info: node_g, + second_node_path: merkle_path_inputs_e, + second_node_info: node_e, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // instead, this is not possible if we are proving nodes in the index tree + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_g, + first_node_info: node_g, + second_node_path: merkle_path_inputs_e, + second_node_info: node_e, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + 
assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // check that, even in a rows tree, we cannot prove nodes which are not at the boundaries of the query range + to be consecutive: test that C and E are not consecutive + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_c, + first_node_info: node_c, + second_node_path: merkle_path_inputs_e, + second_node_info: node_e, + index_id, + min_query_bound: None, + max_query_bound: None, + }; + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // but check that C and E can be consecutive if they are at the boundaries of the range + let circuit = TestConsecutiveNodes:: { + first_node_path: merkle_path_inputs_c, + first_node_info: node_c, + second_node_path: merkle_path_inputs_e, + second_node_info: node_e, + index_id, + min_query_bound: None, + max_query_bound: Some(node_c.value), + }; + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + } + + #[test] + fn test_are_consecutive_rows() { + // structure representing the nodes of a tree generated with `generate_test_tree` + struct RowsTree { + node_a: NodeInfo, + node_b: NodeInfo, + node_c: NodeInfo, + node_d: NodeInfo, + node_e: NodeInfo, + node_f: NodeInfo, + node_g: NodeInfo, + } + + impl From<[NodeInfo; 7]> for RowsTree { + fn from(value: [NodeInfo; 7]) -> Self { + Self { + node_a: value[0], + node_b: value[1], + node_c: value[2], + node_d: value[3], + node_e: value[4], + node_f: value[5], + node_g: value[6], + } + } + } + + // we build an index tree with the following nodes: + // 1 + // 0 3 + // 2 4 + // where each node stores a rows tree generated with `generate_test_tree` + // generate values to be stored in index tree nodes + let rng = &mut thread_rng(); + let mut values: [U256; 5] = array::from_fn(|_| gen_random_u256(rng)); + values.sort(); + let secondary_index_id = F::rand(); + let primary_index_id = F::rand(); + // generate rows trees with values in decreasing order.
This is a simple trick to ensure + // that min_secondary <= max_secondary when using custom query bounds in tests, as we will always + // take max_secondary from the set of values of rows_tree_i and min_secondary from the set of values + // of rows_tree_{i+j} + let rows_tree_0_value_range = (U256::MAX / U256::from(2), U256::MAX); + let rows_tree_1_value_range = (U256::MAX / U256::from(4), U256::MAX / U256::from(2)); + let rows_tree_2_value_range = (U256::MAX / U256::from(8), U256::MAX / U256::from(4)); + let rows_tree_3_value_range = (U256::MAX / U256::from(16), U256::MAX / U256::from(8)); + let rows_tree_4_value_range = (U256::ZERO, U256::MAX / U256::from(16)); + let rows_tree_0 = RowsTree::from(generate_test_tree( + secondary_index_id, + Some(rows_tree_0_value_range), + )); + let root = HashOutput::from(rows_tree_0.node_a.compute_node_hash(secondary_index_id)); + let node_0 = build_node(None, None, values[0], root, primary_index_id); + let rows_tree_2 = RowsTree::from(generate_test_tree( + secondary_index_id, + Some(rows_tree_2_value_range), + )); + let root = HashOutput::from(rows_tree_2.node_a.compute_node_hash(secondary_index_id)); + let node_2 = build_node(None, None, values[2], root, primary_index_id); + let rows_tree_4 = RowsTree::from(generate_test_tree( + secondary_index_id, + Some(rows_tree_4_value_range), + )); + let root = HashOutput::from(rows_tree_4.node_a.compute_node_hash(secondary_index_id)); + let node_4 = build_node(None, None, values[4], root, primary_index_id); + let rows_tree_3 = RowsTree::from(generate_test_tree( + secondary_index_id, + Some(rows_tree_3_value_range), + )); + let root = HashOutput::from(rows_tree_3.node_a.compute_node_hash(secondary_index_id)); + let node_3 = build_node( + Some(&node_2), + Some(&node_4), + values[3], + root, + primary_index_id, + ); + let rows_tree_1 = RowsTree::from(generate_test_tree( + secondary_index_id, + Some(rows_tree_1_value_range), + )); + let root = HashOutput::from(rows_tree_1.node_a.compute_node_hash(secondary_index_id)); + let node_1 = build_node( + Some(&node_0), + Some(&node_3), + values[1], + root, + primary_index_id, + ); + + // test consecutive rows in the same rows tree: check that node_C and node_G in rows_tree_1 are consecutive + let path_1c = vec![(rows_tree_1.node_a, ChildPosition::Right)]; + let node_1b_hash = + HashOutput::from(rows_tree_1.node_b.compute_node_hash(secondary_index_id)); + let siblings_1c = vec![Some(node_1b_hash)]; + let merkle_inputs_1c = MerklePathWithNeighborsGadget::::new( + &path_1c, + &siblings_1c, + &rows_tree_1.node_c, + [None, Some(rows_tree_1.node_g)], + ) + .unwrap(); + let path_1g = vec![ + (rows_tree_1.node_c, ChildPosition::Right), + (rows_tree_1.node_a, ChildPosition::Right), + ]; + let siblings_1g = vec![None, Some(node_1b_hash)]; + let merkle_inputs_1g = MerklePathWithNeighborsGadget::::new( + &path_1g, + &siblings_1g, + &rows_tree_1.node_g, + [None, None], + ) + .unwrap(); + let path_1 = vec![]; + let siblings_1 = vec![]; + let merkle_inputs_index_1 = MerklePathWithNeighborsGadget::::new( + &path_1, + &siblings_1, + &node_1, + [Some(node_0), Some(node_3)], + ) + .unwrap(); + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1c, + first_node_info: rows_tree_1.node_c, + second_node_path: merkle_inputs_1g, + second_node_info: rows_tree_1.node_g, + index_id: secondary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + 
first_node_info: node_1, + second_node_path: merkle_inputs_index_1, // they belong to the same node in the index tree + second_node_info: node_1, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // test consecutive rows in different rows trees: check that node_G of rows_tree_1 and node_E of rows_tree_2 + // are consecutive + let path_2e = vec![ + (rows_tree_2.node_d, ChildPosition::Left), + (rows_tree_2.node_b, ChildPosition::Left), + (rows_tree_2.node_a, ChildPosition::Left), + ]; + let node_2f_hash = + HashOutput::from(rows_tree_2.node_f.compute_node_hash(secondary_index_id)); + let node_2c_hash = + HashOutput::from(rows_tree_2.node_c.compute_node_hash(secondary_index_id)); + let siblings_2e = vec![Some(node_2f_hash), None, Some(node_2c_hash)]; + let merkle_inputs_2e = MerklePathWithNeighborsGadget::::new( + &path_2e, + &siblings_2e, + &rows_tree_2.node_e, + [None, None], // it's a leaf node + ) + .unwrap(); + let path_2 = vec![ + (node_3, ChildPosition::Left), + (node_1, ChildPosition::Right), + ]; + let node_0_hash = HashOutput::from(node_0.compute_node_hash(primary_index_id)); + let node_4_hash = HashOutput::from(node_4.compute_node_hash(primary_index_id)); + let siblings_2 = vec![Some(node_4_hash), Some(node_0_hash)]; + let merkle_inputs_index_2 = MerklePathWithNeighborsGadget::::new( + &path_2, + &siblings_2, + &node_2, + [None, None], + ) + .unwrap(); + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1g, + first_node_info: rows_tree_1.node_g, + second_node_path: merkle_inputs_2e, + second_node_info: rows_tree_2.node_e, + index_id: secondary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // negative test: check that node_G of rows_tree_1 and node_F of rows_tree_2 are not consecutive + let path_2f = vec![ + (rows_tree_2.node_d, ChildPosition::Right), + (rows_tree_2.node_b, ChildPosition::Left), + (rows_tree_2.node_a, ChildPosition::Left), + ]; + let node_2e_hash = + HashOutput::from(rows_tree_2.node_e.compute_node_hash(secondary_index_id)); + let siblings_2f = vec![Some(node_2e_hash), None, Some(node_2c_hash)]; + let merkle_inputs_2f = MerklePathWithNeighborsGadget::::new( + &path_2f, + &siblings_2f, + &rows_tree_2.node_f, + [None, None], // it's a leaf node + ) + .unwrap(); + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1g, + first_node_info: rows_tree_1.node_g, + second_node_path: merkle_inputs_2f, + second_node_info: rows_tree_2.node_f, + index_id: secondary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + let proof = run_circuit::(circuit); + // check that 
the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // negative test: check that node_C of rows_tree_1 and node_E of rows_tree_2 are not consecutive + let path_1c = vec![(rows_tree_1.node_a, ChildPosition::Right)]; + let siblings_1c = vec![Some(node_1b_hash)]; + let merkle_inputs_1c = MerklePathWithNeighborsGadget::::new( + &path_1c, + &siblings_1c, + &rows_tree_1.node_c, + [None, Some(rows_tree_1.node_g)], + ) + .unwrap(); + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1c, + first_node_info: rows_tree_1.node_c, + second_node_path: merkle_inputs_2e, + second_node_info: rows_tree_2.node_e, + index_id: secondary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // negative test: check that node_G of rows_tree_1 and node_E of rows_tree_3 are not consecutive + let path_3e = vec![ + (rows_tree_3.node_d, ChildPosition::Left), + (rows_tree_3.node_b, ChildPosition::Left), + (rows_tree_3.node_a, ChildPosition::Left), + ]; + let node_3f_hash = + HashOutput::from(rows_tree_3.node_f.compute_node_hash(secondary_index_id)); + let node_3c_hash = + HashOutput::from(rows_tree_3.node_c.compute_node_hash(secondary_index_id)); + let siblings_3e = vec![Some(node_3f_hash), None, Some(node_3c_hash)]; + let merkle_inputs_3e = MerklePathWithNeighborsGadget::::new( + &path_3e, + &siblings_3e, + &rows_tree_3.node_e, + [None, None], // it's a leaf node + ) + .unwrap(); + let path_3 = vec![(node_1, ChildPosition::Right)]; + let siblings_3 = vec![Some(node_0_hash)]; + let merkle_inputs_index_3 = MerklePathWithNeighborsGadget::::new( + &path_3, + &siblings_3, + &node_3, + [Some(node_2), Some(node_4)], + ) + .unwrap(); + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1g, + first_node_info: rows_tree_1.node_g, + second_node_path: merkle_inputs_3e, + second_node_info: rows_tree_3.node_e, + index_id: secondary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_3, + second_node_info: node_3, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // test nodes at range boundaries across different rows trees: check that node_A of rows_tree_1 and node_D + // of rows_tree_2 can be consecutive if the range on secondary index is [node_2D.value, node_1A.value] + let path_1a = vec![]; + let siblings_1a = vec![]; + let merkle_inputs_1a = MerklePathWithNeighborsGadget::::new( + &path_1a, + &siblings_1a, + &rows_tree_1.node_a, + [Some(rows_tree_1.node_b), Some(rows_tree_1.node_c)], + ) + .unwrap(); + let path_2d = vec![ + (rows_tree_2.node_b, ChildPosition::Left), + (rows_tree_2.node_a, ChildPosition::Left), + ]; + let siblings_2d = vec![None, Some(node_2c_hash)]; + let merkle_inputs_2d = 
MerklePathWithNeighborsGadget::::new( + &path_2d, + &siblings_2d, + &rows_tree_2.node_d, + [Some(rows_tree_2.node_e), Some(rows_tree_2.node_f)], + ) + .unwrap(); + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1a, + first_node_info: rows_tree_1.node_a, + second_node_path: merkle_inputs_2d, + second_node_info: rows_tree_2.node_d, + index_id: secondary_index_id, + min_query_bound: Some(rows_tree_2.node_d.value), + max_query_bound: Some(rows_tree_1.node_a.value), + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // negative test: check that node_A of rows_tree_1 and node_D of rows_tree_2 are not consecutive + // with a different range on secondary index + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1a, + first_node_info: rows_tree_1.node_a, + second_node_path: merkle_inputs_2d, + second_node_info: rows_tree_2.node_d, + index_id: secondary_index_id, + min_query_bound: Some(rows_tree_2.node_e.value), + max_query_bound: Some(rows_tree_1.node_a.value), + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // test rows tree without matching rows: check that node_A of rows_tree_1 is consecutive with node_D + // of rows_tree_2, if all the nodes in rows_tree_2 store values smaller than the query range + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1a, + first_node_info: rows_tree_1.node_a, + second_node_path: merkle_inputs_2d, + second_node_info: rows_tree_2.node_d, + index_id: secondary_index_id, + min_query_bound: Some(rows_tree_1.node_f.value), + max_query_bound: Some(rows_tree_1.node_a.value), + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // test rows tree without matching rows: check that node_A of rows_tree_1 is consecutive with node_D + // of rows_tree_2, if all the nodes in rows_tree_1 store values bigger than the query range + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1a, + first_node_info: rows_tree_1.node_a, + second_node_path: merkle_inputs_2d, + second_node_info: rows_tree_2.node_d, + index_id: secondary_index_id, + min_query_bound: Some(rows_tree_2.node_d.value), + max_query_bound: Some(rows_tree_2.node_a.value), + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info:
node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // test rows tree without matching rows: check that we can merge 2 rows in rows trees where all + // the values are smaller than the query range. Node_G of rows_tree_1 is consecutive with node_E of + // rows_tree_2, if the query range is defined over values of rows_tree_0 (which are all bigger than + // other rows trees by construction) + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1g, + first_node_info: rows_tree_1.node_g, + second_node_path: merkle_inputs_2e, + second_node_info: rows_tree_2.node_e, + index_id: secondary_index_id, + min_query_bound: Some(rows_tree_0.node_d.value), + max_query_bound: Some(rows_tree_0.node_a.value), + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // test rows tree without matching rows: check that we can merge 2 rows in rows trees where all + // the values are bigger than the query range. Node_G of rows_tree_1 is consecutive with node_E of + // rows_tree_2, if the query range is defined over values of rows_tree_4 (which are all smaller than + // the other rows trees by construction) + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1g, + first_node_info: rows_tree_1.node_g, + second_node_path: merkle_inputs_2e, + second_node_info: rows_tree_2.node_e, + index_id: secondary_index_id, + min_query_bound: Some(rows_tree_4.node_d.value), + max_query_bound: Some(rows_tree_4.node_a.value), + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + }; + + let proof = run_circuit::(circuit); + // check that the nodes are consecutive + assert!(proof.public_inputs[0].try_into_bool().unwrap()); + + // negative test: check that node_G of rows_tree_1 and node_E of rows_tree_2 are not consecutive + // if the index tree node storing rows_tree_2 is out of the query range over the primary index + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1g, + first_node_info: rows_tree_1.node_g, + second_node_path: merkle_inputs_2e, + second_node_info: rows_tree_2.node_e, + index_id: secondary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: None, + max_query_bound: Some(node_1.value), + }, + }; + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + + // negative test: check 
that node_G of rows_tree_1 and node_E of rows_tree_2 are not consecutive + // if the index tree node storing rows_tree_1 is out of the query range over the primary index + let circuit = TestConsecutiveRows { + row_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_1g, + first_node_info: rows_tree_1.node_g, + second_node_path: merkle_inputs_2e, + second_node_info: rows_tree_2.node_e, + index_id: secondary_index_id, + min_query_bound: None, + max_query_bound: None, + }, + index_tree_nodes: TestConsecutiveNodes { + first_node_path: merkle_inputs_index_1, + first_node_info: node_1, + second_node_path: merkle_inputs_index_2, + second_node_info: node_2, + index_id: primary_index_id, + min_query_bound: Some(node_2.value), + max_query_bound: None, + }, + }; + let proof = run_circuit::(circuit); + // check that the nodes are not consecutive + assert!(!proof.public_inputs[0].try_into_bool().unwrap()); + } +} diff --git a/verifiable-db/src/query/row_chunk_gadgets/mod.rs b/verifiable-db/src/query/row_chunk_gadgets/mod.rs new file mode 100644 index 000000000..08c23d9a0 --- /dev/null +++ b/verifiable-db/src/query/row_chunk_gadgets/mod.rs @@ -0,0 +1,524 @@ +//! This module contains data structures and gadgets employed to build and aggregate +//! row chunks. A row chunk is a set of rows that have already been aggregated +//! and that are all proven to be consecutive. The first and last rows in +//! the chunk are labelled as the `left_boundary_row` and the `right_boundary_row`, +//! respectively, and are the rows employed to aggregate 2 different chunks. + +use alloy::primitives::U256; +use mp2_common::{ + serialization::circuit_data_serialization::SerializableRichField, + utils::{FromFields, FromTargets, HashBuilder, SelectTarget, ToFields, ToTargets}, + F, +}; +use mp2_test::utils::gen_random_field_hash; +use plonky2::{ + hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + iop::target::{BoolTarget, Target}, + plonk::circuit_builder::CircuitBuilder, +}; +use rand::Rng; + +use crate::{ + query::{ + merkle_path::{MerklePathWithNeighborsTarget, NeighborInfoTarget}, + universal_circuit::universal_query_gadget::UniversalQueryOutputWires, + }, + test_utils::gen_values_in_range, +}; + +use super::{merkle_path::NeighborInfo, utils::QueryBounds}; + +/// This module contains gadgets to aggregate 2 different row chunks +pub(crate) mod aggregate_chunks; +/// This module contains gadgets to enforce whether 2 rows are consecutive +pub(crate) mod consecutive_rows; +/// This module contains a gadget to prove a single row of the DB +pub(crate) mod row_process_gadget; + +/// Data structure containing the wires representing the data related to the node of +/// the row/index tree containing a row that is on the boundary of a row chunk.
+#[derive(Clone, Debug)] +pub(crate) struct BoundaryRowNodeInfoTarget { + /// Hash of the node storing the row in the row/index tree + pub(crate) end_node_hash: HashOutTarget, + /// Data about the predecessor of end_node in the row/index tree + pub(crate) predecessor_info: NeighborInfoTarget, + /// Data about the successor of end_node in the row/index tree + pub(crate) successor_info: NeighborInfoTarget, +} + +impl<'a, const MAX_DEPTH: usize> From<&'a MerklePathWithNeighborsTarget<MAX_DEPTH>> + for BoundaryRowNodeInfoTarget +where + [(); MAX_DEPTH - 1]:, +{ + fn from(value: &'a MerklePathWithNeighborsTarget<MAX_DEPTH>) -> Self { + Self { + end_node_hash: value.end_node_hash, + predecessor_info: value.predecessor_info.clone(), + successor_info: value.successor_info.clone(), + } + } +} + +impl SelectTarget for BoundaryRowNodeInfoTarget { + fn select<F: SerializableRichField<D>, const D: usize>( + b: &mut CircuitBuilder<F, D>, + cond: &BoolTarget, + first: &Self, + second: &Self, + ) -> Self { + Self { + end_node_hash: b.select_hash(*cond, &first.end_node_hash, &second.end_node_hash), + predecessor_info: NeighborInfoTarget::select( + b, + cond, + &first.predecessor_info, + &second.predecessor_info, + ), + successor_info: NeighborInfoTarget::select( + b, + cond, + &first.successor_info, + &second.successor_info, + ), + } + } +} + +impl FromTargets for BoundaryRowNodeInfoTarget { + const NUM_TARGETS: usize = NUM_HASH_OUT_ELTS + 2 * NeighborInfoTarget::NUM_TARGETS; + + fn from_targets(t: &[Target]) -> Self { + assert!(t.len() >= Self::NUM_TARGETS); + Self { + end_node_hash: HashOutTarget::from_vec(t[..NUM_HASH_OUT_ELTS].to_vec()), + predecessor_info: NeighborInfoTarget::from_targets(&t[NUM_HASH_OUT_ELTS..]), + successor_info: NeighborInfoTarget::from_targets( + &t[NUM_HASH_OUT_ELTS + NeighborInfoTarget::NUM_TARGETS..], + ), + } + } +} + +impl ToTargets for BoundaryRowNodeInfoTarget { + fn to_targets(&self) -> Vec<Target> { + self.end_node_hash + .to_targets() + .into_iter() + .chain(self.predecessor_info.to_targets()) + .chain(self.successor_info.to_targets()) + .collect() + } +} + +/// Data structure containing the `BoundaryRowNodeInfoTarget` wires for the nodes +/// related to a given boundary row.
In particular, it contains the +/// `BoundaryRowNodeInfoTarget` related to the following nodes: +/// - `row_node`: the node of the rows tree containing the given boundary row +/// - `index_node`: the node of the index tree that stores the rows tree containing +/// `row_node` +#[derive(Clone, Debug)] +pub(crate) struct BoundaryRowDataTarget { + pub(crate) row_node_info: BoundaryRowNodeInfoTarget, + pub(crate) index_node_info: BoundaryRowNodeInfoTarget, +} + +impl FromTargets for BoundaryRowDataTarget { + const NUM_TARGETS: usize = 2 * BoundaryRowNodeInfoTarget::NUM_TARGETS; + fn from_targets(t: &[Target]) -> Self { + assert!(t.len() >= Self::NUM_TARGETS); + Self { + row_node_info: BoundaryRowNodeInfoTarget::from_targets(t), + index_node_info: BoundaryRowNodeInfoTarget::from_targets( + &t[BoundaryRowNodeInfoTarget::NUM_TARGETS..], + ), + } + } +} + +impl ToTargets for BoundaryRowDataTarget { + fn to_targets(&self) -> Vec<Target> { + self.row_node_info + .to_targets() + .into_iter() + .chain(self.index_node_info.to_targets()) + .collect() + } +} + +impl SelectTarget for BoundaryRowDataTarget { + fn select<F: SerializableRichField<D>, const D: usize>( + b: &mut CircuitBuilder<F, D>, + cond: &BoolTarget, + first: &Self, + second: &Self, + ) -> Self { + Self { + row_node_info: BoundaryRowNodeInfoTarget::select( + b, + cond, + &first.row_node_info, + &second.row_node_info, + ), + index_node_info: BoundaryRowNodeInfoTarget::select( + b, + cond, + &first.index_node_info, + &second.index_node_info, + ), + } + } +} + +/// Data structure containing the wires associated to a given row chunk +#[derive(Clone, Debug)] +pub(crate) struct RowChunkDataTarget<const MAX_NUM_RESULTS: usize> +where + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) left_boundary_row: BoundaryRowDataTarget, + pub(crate) right_boundary_row: BoundaryRowDataTarget, + pub(crate) chunk_outputs: UniversalQueryOutputWires<MAX_NUM_RESULTS>, +} + +impl<const MAX_NUM_RESULTS: usize> FromTargets for RowChunkDataTarget<MAX_NUM_RESULTS> +where + [(); MAX_NUM_RESULTS - 1]:, +{ + const NUM_TARGETS: usize = + 2 * BoundaryRowDataTarget::NUM_TARGETS + UniversalQueryOutputWires::<MAX_NUM_RESULTS>::NUM_TARGETS; + + fn from_targets(t: &[Target]) -> Self { + assert!(t.len() >= Self::NUM_TARGETS); + Self { + left_boundary_row: BoundaryRowDataTarget::from_targets(t), + right_boundary_row: BoundaryRowDataTarget::from_targets( + &t[BoundaryRowDataTarget::NUM_TARGETS..], + ), + chunk_outputs: UniversalQueryOutputWires::from_targets( + &t[2 * BoundaryRowDataTarget::NUM_TARGETS..], + ), + } + } +} + +impl<const MAX_NUM_RESULTS: usize> ToTargets for RowChunkDataTarget<MAX_NUM_RESULTS> +where + [(); MAX_NUM_RESULTS - 1]:, +{ + fn to_targets(&self) -> Vec<Target> { + self.left_boundary_row + .to_targets() + .into_iter() + .chain(self.right_boundary_row.to_targets()) + .chain(self.chunk_outputs.to_targets()) + .collect() + } +} + +#[derive(Clone, Debug)] +pub(crate) struct BoundaryRowNodeInfo { + pub(crate) end_node_hash: HashOut<F>, + pub(crate) predecessor_info: NeighborInfo, + pub(crate) successor_info: NeighborInfo, +} + +impl ToFields for BoundaryRowNodeInfo { + fn to_fields(&self) -> Vec<F> { + self.end_node_hash + .to_fields() + .into_iter() + .chain(self.predecessor_info.to_fields()) + .chain(self.successor_info.to_fields()) + .collect() + } +} + +impl FromFields for BoundaryRowNodeInfo { + fn from_fields(t: &[F]) -> Self { + assert!(t.len() >= BoundaryRowNodeInfoTarget::NUM_TARGETS); + let end_node_hash = HashOut::from_partial(&t[..NUM_HASH_OUT_ELTS]); + let predecessor_info = NeighborInfo::from_fields(&t[NUM_HASH_OUT_ELTS..]); + let successor_info = + NeighborInfo::from_fields(&t[NUM_HASH_OUT_ELTS + NeighborInfoTarget::NUM_TARGETS..]); + + Self { + end_node_hash,
predecessor_info, + successor_info, + } + } +} + +impl BoundaryRowNodeInfo { + /// Generate an instance of `Self` representing a random node, given the `query_bounds` + /// provided as input and a flag `is_index_tree` specifying whether the random node + /// should be part of an index tree or of a rows tree. It is used to generate test data + /// without the need to generate an actual tree + pub(crate) fn sample<R: Rng>( + rng: &mut R, + query_bounds: &QueryBounds, + is_index_tree: bool, + ) -> Self { + let (min_query_bound, max_query_bound) = if is_index_tree { + ( + query_bounds.min_query_primary(), + query_bounds.max_query_primary(), + ) + } else { + ( + *query_bounds.min_query_secondary().value(), + *query_bounds.max_query_secondary().value(), + ) + }; + let end_node_hash = gen_random_field_hash(); + let [predecessor_value] = gen_values_in_range( + rng, + if is_index_tree { + min_query_bound // predecessor in index tree must always be in range + } else { + U256::ZERO + }, + max_query_bound, // predecessor value must always be smaller than max_secondary in circuit + ); + let predecessor_info = NeighborInfo::sample( + rng, + predecessor_value, + if is_index_tree { + // in index tree, there must always be a predecessor for boundary rows + Some(true) + } else { + None + }, + ); + let [successor_value] = gen_values_in_range( + rng, + predecessor_value.max(min_query_bound), // successor value must + // always be greater than min_secondary in circuit, and it must also be + // greater than predecessor value since we are in a BST + if is_index_tree { + max_query_bound // successor in index tree must always be in range + } else { + U256::MAX + }, + ); + let successor_info = NeighborInfo::sample( + rng, + successor_value, + if is_index_tree { + // in index tree, there must always be a successor for boundary rows + Some(true) + } else { + None + }, + ); + + Self { + end_node_hash, + predecessor_info, + successor_info, + } + } + + /// Given a boundary node with info stored in `self`, this method generates at random the + /// information about a node that can be the successor of `self` in a BST.
This method + /// requires as additional inputs the `query_bounds` and a flag `is_index_tree`, which + /// specifies whether `self` and the generated node should be part of an index tree or + /// of a rows tree + pub(crate) fn sample_successor_in_tree<R: Rng>( + &self, + rng: &mut R, + query_bounds: &QueryBounds, + is_index_tree: bool, + ) -> Self { + let (min_query_bound, max_query_bound) = if is_index_tree { + ( + query_bounds.min_query_primary(), + query_bounds.max_query_primary(), + ) + } else { + ( + *query_bounds.min_query_secondary().value(), + *query_bounds.max_query_secondary().value(), + ) + }; + let end_node_hash = self.successor_info.hash; + // value of predecessor must be in query range and between the predecessor and successor value + // of `self` + let [predecessor_value] = gen_values_in_range( + rng, + min_query_bound.max(self.predecessor_info.value), + self.successor_info.value.min(max_query_bound), + ); + let predecessor_info = if self.successor_info.is_in_path { + NeighborInfo::new(predecessor_value, None) + } else { + NeighborInfo::new(predecessor_value, Some(self.end_node_hash)) + }; + let [successor_value] = gen_values_in_range( + rng, + predecessor_value.max(min_query_bound), + if is_index_tree { + max_query_bound // successor must always be in range in index tree + } else { + U256::MAX + }, + ); + let successor_info = NeighborInfo::sample( + rng, + successor_value, + if is_index_tree { + // in index tree, there must always be a successor for boundary rows + Some(true) + } else { + None + }, + ); + BoundaryRowNodeInfo { + end_node_hash, + predecessor_info, + successor_info, + } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct BoundaryRowData { + pub(crate) row_node_info: BoundaryRowNodeInfo, + pub(crate) index_node_info: BoundaryRowNodeInfo, +} + +impl ToFields for BoundaryRowData { + fn to_fields(&self) -> Vec<F> { + self.row_node_info + .to_fields() + .into_iter() + .chain(self.index_node_info.to_fields()) + .collect() + } +} + +impl FromFields for BoundaryRowData { + fn from_fields(t: &[F]) -> Self { + assert!(t.len() >= BoundaryRowDataTarget::NUM_TARGETS); + let row_node_info = BoundaryRowNodeInfo::from_fields(t); + let index_node_info = + BoundaryRowNodeInfo::from_fields(&t[BoundaryRowNodeInfoTarget::NUM_TARGETS..]); + + Self { + row_node_info, + index_node_info, + } + } +} + +impl BoundaryRowData { + /// Generate a random instance of `Self`, given the `query_bounds` provided as inputs. + /// It is employed to generate test data without the need to build an actual test tree + pub(crate) fn sample<R: Rng>(rng: &mut R, query_bounds: &QueryBounds) -> Self { + Self { + row_node_info: BoundaryRowNodeInfo::sample(rng, query_bounds, false), + index_node_info: BoundaryRowNodeInfo::sample(rng, query_bounds, true), + } + } + + /// Given the boundary row `self`, generates at random the data of the consecutive row of + /// `self`, given the `query_bounds` provided as input.
It is employed to generate test data + /// without the need to build an actual test tree + pub(crate) fn sample_consecutive_row<R: Rng>( + &self, + rng: &mut R, + query_bounds: &QueryBounds, + ) -> Self { + if self.row_node_info.successor_info.is_found + && self.row_node_info.successor_info.value + <= *query_bounds.max_query_secondary().value() + { + // the successor must be in the same rows tree + let row_node_info = + self.row_node_info + .sample_successor_in_tree(rng, query_bounds, false); + Self { + row_node_info, + index_node_info: self.index_node_info.clone(), + } + } else { + // the successor must be in a different rows tree + let end_node_hash = gen_random_field_hash(); + // predecessor value must be out of range in this case + let [predecessor_value] = gen_values_in_range( + rng, + U256::ZERO, + query_bounds + .min_query_secondary() + .value() + .checked_sub(U256::from(1)) + .unwrap_or(U256::ZERO), + ); + let predecessor_info = NeighborInfo::sample(rng, predecessor_value, None); + let [successor_value] = gen_values_in_range( + rng, + predecessor_value.max(*query_bounds.min_query_secondary().value()), // successor value must + // always be greater than min_secondary in circuit + U256::MAX, + ); + let successor_info = NeighborInfo::sample(rng, successor_value, None); + let row_node_info = BoundaryRowNodeInfo { + end_node_hash, + predecessor_info, + successor_info, + }; + // index tree node must be a successor of `self.index_node` + let index_node_info = + self.index_node_info + .sample_successor_in_tree(rng, query_bounds, true); + Self { + row_node_info, + index_node_info, + } + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use mp2_common::{utils::ToFields, F}; + use plonky2::{field::types::Field, hash::hash_types::HashOut}; + + use crate::query::universal_circuit::universal_query_gadget::OutputValues; + + use super::BoundaryRowData; + + #[derive(Clone, Debug)] + pub(crate) struct RowChunkData<const MAX_NUM_RESULTS: usize> + where + [(); MAX_NUM_RESULTS - 1]:, + { + pub(crate) left_boundary_row: BoundaryRowData, + pub(crate) right_boundary_row: BoundaryRowData, + pub(crate) chunk_tree_hash: HashOut<F>, + pub(crate) output_values: OutputValues<MAX_NUM_RESULTS>, + pub(crate) num_overflows: u64, + pub(crate) count: u64, + } + + impl<const MAX_NUM_RESULTS: usize> ToFields for RowChunkData<MAX_NUM_RESULTS> + where + [(); MAX_NUM_RESULTS - 1]:, + { + fn to_fields(&self) -> Vec<F> { + self.left_boundary_row + .to_fields() + .into_iter() + .chain(self.right_boundary_row.to_fields()) + .chain(self.chunk_tree_hash.to_fields()) + .chain(self.output_values.to_fields()) + .chain([ + F::from_canonical_u64(self.count), + F::from_canonical_u64(self.num_overflows), + ]) + .collect() + } + } +} diff --git a/verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs b/verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs new file mode 100644 index 000000000..30b9c84b6 --- /dev/null +++ b/verifiable-db/src/query/row_chunk_gadgets/row_process_gadget.rs @@ -0,0 +1,325 @@ +use anyhow::Result; +use std::array; + +use mp2_common::{types::CBuilder, u256::UInt256Target, F}; +use plonky2::iop::witness::PartialWitness; +use serde::{Deserialize, Serialize}; + +use crate::query::{ + api::RowInput, + merkle_path::{MerklePathWithNeighborsGadget, MerklePathWithNeighborsTargetInputs}, + universal_circuit::{ + universal_circuit_inputs::RowCells, + universal_query_gadget::{ + OutputComponent, UniversalQueryHashInputWires, UniversalQueryValueInputWires, + UniversalQueryValueInputs, UniversalQueryValueWires, + }, + }, +}; + +use super::{BoundaryRowDataTarget, BoundaryRowNodeInfoTarget, RowChunkDataTarget};
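// A minimal sketch, assuming a `QueryBounds` value built elsewhere (everything else is taken
// from the definitions above): `BoundaryRowData::sample` draws a random boundary row,
// `sample_consecutive_row` derives the row that follows it, and the assertion mirrors the
// invariant encoded in `sample_consecutive_row`: when the first row has an in-range successor
// in its rows tree, the consecutive row stays in the same rows tree, hence under the same
// index-tree node.
#[cfg(test)]
fn sample_consecutive_pair<R: rand::Rng>(
    rng: &mut R,
    bounds: &crate::query::utils::QueryBounds,
) -> (super::BoundaryRowData, super::BoundaryRowData) {
    let first = super::BoundaryRowData::sample(rng, bounds);
    let second = first.sample_consecutive_row(rng, bounds);
    if first.row_node_info.successor_info.is_found
        && first.row_node_info.successor_info.value <= *bounds.max_query_secondary().value()
    {
        // successor within the secondary-index range => same rows tree => same index-tree node
        assert_eq!(
            first.index_node_info.end_node_hash,
            second.index_node_info.end_node_hash
        );
    }
    (first, second)
}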
+ +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct RowProcessingGadgetInputWires< + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, +> where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + pub(crate) row_path: MerklePathWithNeighborsTargetInputs, + pub(crate) index_path: MerklePathWithNeighborsTargetInputs, + pub(crate) input_values: UniversalQueryValueInputWires, +} + +impl< + 'a, + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_RESULTS: usize, + > + From< + &'a RowProcessingGadgetWires< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_RESULTS, + >, + > for RowProcessingGadgetInputWires +where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, + [(); MAX_NUM_RESULTS - 1]:, +{ + fn from( + value: &'a RowProcessingGadgetWires< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_RESULTS, + >, + ) -> Self { + RowProcessingGadgetInputWires { + row_path: value.row_path.clone(), + index_path: value.index_path.clone(), + input_values: value.value_wires.input_wires.clone(), + } + } +} + +#[derive(Clone, Debug)] +#[allow(dead_code)] // only in this PR +pub(crate) struct RowProcessingGadgetWires< + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_RESULTS: usize, +> where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) row_path: MerklePathWithNeighborsTargetInputs, + pub(crate) row_node_data: BoundaryRowNodeInfoTarget, + pub(crate) index_path: MerklePathWithNeighborsTargetInputs, + pub(crate) index_node_data: BoundaryRowNodeInfoTarget, + pub(crate) value_wires: UniversalQueryValueWires, +} + +impl< + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_RESULTS: usize, + > + From< + RowProcessingGadgetWires< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_RESULTS, + >, + > for RowChunkDataTarget +where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, + [(); MAX_NUM_RESULTS - 1]:, +{ + fn from( + value: RowProcessingGadgetWires< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_RESULTS, + >, + ) -> Self { + RowChunkDataTarget { + left_boundary_row: BoundaryRowDataTarget { + row_node_info: value.row_node_data.clone(), + index_node_info: value.index_node_data.clone(), + }, + right_boundary_row: BoundaryRowDataTarget { + row_node_info: value.row_node_data, + index_node_info: value.index_node_data, + }, + chunk_outputs: value.value_wires.output_wires, + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct RowProcessingGadgetInputs< + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, +> where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, +{ + row_path: MerklePathWithNeighborsGadget, + index_path: MerklePathWithNeighborsGadget, + input_values: UniversalQueryValueInputs< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >, +} + +impl< + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + 
const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + > + RowProcessingGadgetInputs< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + > +where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) fn new( + row_path: MerklePathWithNeighborsGadget, + index_path: MerklePathWithNeighborsGadget, + row_cells: &RowCells, + ) -> Result { + Ok(Self { + row_path, + index_path, + input_values: UniversalQueryValueInputs::new(row_cells, false)?, + }) + } + + #[allow(dead_code)] // unused for now, but could be a useful method + pub(crate) fn new_dummy_row( + row_path: MerklePathWithNeighborsGadget, + index_path: MerklePathWithNeighborsGadget, + row_cells: &RowCells, + ) -> Result { + Ok(Self { + row_path, + index_path, + input_values: UniversalQueryValueInputs::new(row_cells, true)?, + }) + } + + pub(crate) fn clone_to_dummy_row(&self) -> Self { + let mut input_values = self.input_values.clone(); + input_values.is_dummy_row = true; + Self { + row_path: self.row_path, + index_path: self.index_path, + input_values, + } + } + + pub(crate) fn build>( + b: &mut CBuilder, + hash_input_wires: &UniversalQueryHashInputWires< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + min_query_secondary: &UInt256Target, + max_query_secondary: &UInt256Target, + ) -> RowProcessingGadgetWires< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_RESULTS, + > { + let zero = b.zero(); + let mut value_wires = UniversalQueryValueInputs::build( + b, + hash_input_wires, + min_query_secondary, + max_query_secondary, + &zero, + ); + let [primary_index_id, secondary_index_id] = + array::from_fn(|i| hash_input_wires.column_extraction_wires.column_ids[i]); + let [primary_index_value, secondary_index_value] = + array::from_fn(|i| value_wires.input_wires.column_values[i].clone()); + let row_path = MerklePathWithNeighborsGadget::build( + b, + secondary_index_value, + value_wires.output_wires.tree_hash, // hash of the cells tree stored + // in the row node must be the one computed by universal query gadget + secondary_index_id, + ); + let index_path = MerklePathWithNeighborsGadget::build( + b, + primary_index_value, + row_path.root, // computed root of row tree must be the same as the root of + // the subtree stored in `index_node` + primary_index_id, + ); + + // the tree hash in output values for the current row must correspond to the index tree hash + value_wires.output_wires.tree_hash = index_path.root; + + let row_node_data = BoundaryRowNodeInfoTarget::from(&row_path); + let index_node_data = BoundaryRowNodeInfoTarget::from(&index_path); + RowProcessingGadgetWires { + row_path: row_path.inputs, + row_node_data, + index_path: index_path.inputs, + index_node_data, + value_wires, + } + } + + #[allow(dead_code)] // only in this PR + pub(crate) fn assign( + &self, + pw: &mut PartialWitness, + wires: &RowProcessingGadgetInputWires< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + >, + ) { + self.row_path.assign(pw, &wires.row_path); + self.index_path.assign(pw, &wires.index_path); + self.input_values.assign(pw, &wires.input_values); + } +} + +impl< + const ROW_TREE_MAX_DEPTH: usize, + const INDEX_TREE_MAX_DEPTH: usize, + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + 
> TryFrom<&RowInput> + for RowProcessingGadgetInputs< + ROW_TREE_MAX_DEPTH, + INDEX_TREE_MAX_DEPTH, + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + > +where + [(); ROW_TREE_MAX_DEPTH - 1]:, + [(); INDEX_TREE_MAX_DEPTH - 1]:, + [(); MAX_NUM_RESULTS - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, +{ + fn try_from(value: &RowInput) -> Result { + let index_path = MerklePathWithNeighborsGadget::new( + &value.path.index_tree_path.path, + &value.path.index_tree_path.siblings, + &value.path.index_tree_path.node_info, + value.path.index_tree_path.children, + )?; + let row_path = MerklePathWithNeighborsGadget::new( + &value.path.row_tree_path.path, + &value.path.row_tree_path.siblings, + &value.path.row_tree_path.node_info, + value.path.row_tree_path.children, + )?; + + Self::new(row_path, index_path, &value.cells) + } + + type Error = anyhow::Error; +} diff --git a/verifiable-db/src/query/universal_circuit/basic_operation.rs b/verifiable-db/src/query/universal_circuit/basic_operation.rs index 6d85e4e20..793a062cd 100644 --- a/verifiable-db/src/query/universal_circuit/basic_operation.rs +++ b/verifiable-db/src/query/universal_circuit/basic_operation.rs @@ -1,3 +1,5 @@ +use std::iter::once; + use alloy::primitives::U256; use itertools::Itertools; use mp2_common::{ @@ -38,13 +40,21 @@ pub struct BasicOperationInputWires { /// by this instance of the component, among all the supported operations op_selector: Target, } +/// Output wires for the operation performed by the basic operation component +pub(crate) struct BasicOperationValueWires { + pub(crate) output_value: UInt256Target, + pub(crate) num_overflows: Target, +} +/// Input + output wires for the computational hash of basic operation component +pub(crate) struct BasicOperationHashWires { + pub(crate) input_wires: BasicOperationInputWires, + pub(crate) output_hash: ComputationalHashTarget, +} /// Input + output wires for basic operation component pub struct BasicOperationWires { - pub(crate) input_wires: BasicOperationInputWires, - pub(crate) output_value: UInt256Target, - pub(crate) output_hash: ComputationalHashTarget, - pub(crate) num_overflows: Target, + pub(crate) value_wires: BasicOperationValueWires, + pub(crate) hash_wires: BasicOperationHashWires, } /// Witness input values for basic operation component #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -81,34 +91,28 @@ impl BasicOperationInputs { num_inputs_values + 2 } - pub(crate) fn build( + pub(crate) fn build_values( b: &mut CircuitBuilder, input_values: &[UInt256Target], - input_hash: &[ComputationalHashTarget], + input_wires: &BasicOperationInputWires, num_overflows: Target, - ) -> BasicOperationWires { + ) -> BasicOperationValueWires { let zero = b.zero(); - let additional_operands = (0..3) - .map( - |_| b.add_virtual_u256_unsafe(), // should be ok to use `unsafe` here since these values are directly hashed in computational hash or in placeholder hash - ) - .collect_vec(); - let constant_operand = &additional_operands[0]; - let placeholder_values = &additional_operands[1..]; let possible_input_values = input_values .iter() - .chain(additional_operands.iter()) + .chain(once(&input_wires.constant_operand)) + .chain(input_wires.placeholder_values.iter()) .cloned() .collect_vec(); - let first_input_selector = b.add_virtual_target(); - let second_input_selector = b.add_virtual_target(); - let placeholder_ids = b.add_virtual_target_arr::<2>(); - let op_selector = b.add_virtual_target(); //TODO: these 2 random accesses could be done 
with a single operation, if we add an ad-hoc gate - let first_input = - b.random_access_u256(first_input_selector, possible_input_values.as_slice()); - let second_input = - b.random_access_u256(second_input_selector, possible_input_values.as_slice()); + let first_input = b.random_access_u256( + input_wires.first_input_selector, + possible_input_values.as_slice(), + ); + let second_input = b.random_access_u256( + input_wires.second_input_selector, + possible_input_values.as_slice(), + ); // compute results for all the operations @@ -122,8 +126,8 @@ impl BasicOperationInputs { // Given the `op_selector` for the actual operation, we compute // `prod = (op_selector-div_selector)*(op_selector-mod_selector)`. // Then, the operation is division or modulo iff `prod == 0`` - let div_diff = b.sub(op_selector, div_selector); - let mod_diff = b.sub(op_selector, mod_selector); + let div_diff = b.sub(input_wires.op_selector, div_selector); + let mod_diff = b.sub(input_wires.op_selector, mod_selector); let prod = b.mul(div_diff, mod_diff); b.is_equal(prod, zero) }; @@ -197,24 +201,37 @@ impl BasicOperationInputs { // choose the proper output values and overflows error occurred depending on the // operation to be performed in the current instance of basic operation component - let output_value = b.random_access_u256(op_selector, &possible_output_values); + let output_value = b.random_access_u256(input_wires.op_selector, &possible_output_values); assert!( possible_overflows_occurred.len() <= 64, "random access gadget works only for arrays with at most 64 elements" ); - let overflows_occurred = b.random_access(op_selector, possible_overflows_occurred); + let overflows_occurred = + b.random_access(input_wires.op_selector, possible_overflows_occurred); - // compute computational hash associated to the operation being computed - let output_hash = Operation::basic_operation_hash_circuit( - b, - input_hash, - constant_operand, - placeholder_ids, - first_input_selector, - second_input_selector, - op_selector, - ); + BasicOperationValueWires { + output_value, + num_overflows: b.add(num_overflows, overflows_occurred), + } + } + + pub(crate) fn build_hash( + b: &mut CircuitBuilder, + input_hash: &[ComputationalHashTarget], + ) -> BasicOperationHashWires { + let additional_operands = (0..3) + .map( + |_| b.add_virtual_u256_unsafe(), // should be ok to use `unsafe` here since these values are directly hashed in computational hash or in placeholder hash + ) + .collect_vec(); + let constant_operand = &additional_operands[0]; + let placeholder_values = &additional_operands[1..]; + + let first_input_selector = b.add_virtual_target(); + let second_input_selector = b.add_virtual_target(); + let placeholder_ids = b.add_virtual_target_arr::<2>(); + let op_selector = b.add_virtual_target(); let input_wires = BasicOperationInputWires { constant_operand: constant_operand.clone(), @@ -225,11 +242,37 @@ impl BasicOperationInputs { op_selector, }; - BasicOperationWires { + // compute computational hash associated to the operation being computed + let output_hash = Operation::basic_operation_hash_circuit( + b, + input_hash, + &input_wires.constant_operand, + input_wires.placeholder_ids, + input_wires.first_input_selector, + input_wires.second_input_selector, + input_wires.op_selector, + ); + + BasicOperationHashWires { input_wires, - output_value, output_hash, - num_overflows: b.add(num_overflows, overflows_occurred), + } + } + + pub(crate) fn build( + b: &mut CircuitBuilder, + input_values: &[UInt256Target], + input_hash: 
&[ComputationalHashTarget], + num_overflows: Target, + ) -> BasicOperationWires { + let hash_wires = Self::build_hash(b, input_hash); + + let value_wires = + Self::build_values(b, input_values, &hash_wires.input_wires, num_overflows); + + BasicOperationWires { + value_wires, + hash_wires, } } @@ -315,14 +358,14 @@ mod tests { let expected_hash = c.add_virtual_hash(); let num_errors = c.add_virtual_target(); - c.enforce_equal_u256(&expected_result, &wires.output_value); - c.connect_hashes(expected_hash, wires.output_hash); - c.connect(wires.num_overflows, num_errors); + c.enforce_equal_u256(&expected_result, &wires.value_wires.output_value); + c.connect_hashes(expected_hash, wires.hash_wires.output_hash); + c.connect(wires.value_wires.num_overflows, num_errors); Self::Wires { input_values, input_hash: input_hash.try_into().unwrap(), - component_wires: wires.input_wires, + component_wires: wires.hash_wires.input_wires, expected_result, expected_hash, num_errors, diff --git a/verifiable-db/src/query/universal_circuit/cells.rs b/verifiable-db/src/query/universal_circuit/cells.rs index 090cfe1e8..f57e5b04f 100644 --- a/verifiable-db/src/query/universal_circuit/cells.rs +++ b/verifiable-db/src/query/universal_circuit/cells.rs @@ -4,7 +4,7 @@ use mp2_common::{ poseidon::empty_poseidon_hash, types::CBuilder, u256::UInt256Target, - utils::{SelectHashBuilder, ToTargets}, + utils::{HashBuilder, ToTargets}, CHasher, }; use plonky2::{ diff --git a/verifiable-db/src/query/universal_circuit/column_extraction.rs b/verifiable-db/src/query/universal_circuit/column_extraction.rs index 439b2a3c3..9dbb433db 100644 --- a/verifiable-db/src/query/universal_circuit/column_extraction.rs +++ b/verifiable-db/src/query/universal_circuit/column_extraction.rs @@ -1,14 +1,13 @@ use super::{cells::build_cells_tree, ComputationalHashTarget, MembershipHashTarget}; use crate::query::computational_hash_ids::{Extraction, Identifiers}; -use alloy::primitives::U256; use mp2_common::{ poseidon::empty_poseidon_hash, serialization::{ deserialize_array, deserialize_long_array, serialize_array, serialize_long_array, }, types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::SelectHashBuilder, + u256::{CircuitBuilderU256, UInt256Target}, + utils::HashBuilder, F, }; use plonky2::iop::{ @@ -21,12 +20,6 @@ use std::array; /// Input wires for the column extraction component #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] pub struct ColumnExtractionInputWires { - /// values of the columns for the current row - #[serde( - serialize_with = "serialize_array", - deserialize_with = "deserialize_array" - )] - pub(crate) column_values: [UInt256Target; MAX_NUM_COLUMNS], /// integer identifier associated to each of the columns #[serde( serialize_with = "serialize_array", @@ -42,14 +35,24 @@ pub struct ColumnExtractionInputWires { )] is_real_column: [BoolTarget; MAX_NUM_COLUMNS], } -/// Input + output wires for the column extraction component -pub(crate) struct ColumnExtractionWires { - pub(crate) input_wires: ColumnExtractionInputWires, +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct ColumnExtractionValueWires { /// Hash of the cells tree pub(crate) tree_hash: MembershipHashTarget, +} + +pub(crate) struct ColumnExtractionHashWires { + pub(crate) input_wires: ColumnExtractionInputWires, /// Computational hash associated to the extraction of each of the `MAX_NUM_COLUMNS` columns pub(crate) column_hash: [ComputationalHashTarget; MAX_NUM_COLUMNS], } + +/// Input + output wires for the column 
extraction component +#[cfg(test)] // used only in test for now +pub(crate) struct ColumnExtractionWires { + pub(crate) value_wires: ColumnExtractionValueWires, + pub(crate) hash_wires: ColumnExtractionHashWires, +} /// Witness input values for column extraction component #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ColumnExtractionInputs { @@ -58,19 +61,31 @@ pub struct ColumnExtractionInputs { serialize_with = "serialize_long_array", deserialize_with = "deserialize_long_array" )] - pub(crate) column_values: [U256; MAX_NUM_COLUMNS], - #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" - )] pub(crate) column_ids: [F; MAX_NUM_COLUMNS], } impl ColumnExtractionInputs { - pub(crate) fn build(b: &mut CBuilder) -> ColumnExtractionWires { + pub(crate) fn build_column_values(b: &mut CBuilder) -> [UInt256Target; MAX_NUM_COLUMNS] { + b.add_virtual_u256_arr_unsafe() + } + + pub(crate) fn build_tree_hash( + b: &mut CBuilder, + column_values: &[UInt256Target; MAX_NUM_COLUMNS], + input_wires: &ColumnExtractionInputWires, + ) -> ColumnExtractionValueWires { + // Exclude the first 2 indexed columns to build the cells tree. + let input_values = &column_values[2..]; + let input_ids = &input_wires.column_ids[2..]; + let is_real_value = &input_wires.is_real_column[2..]; + let tree_hash = build_cells_tree(b, input_values, input_ids, is_real_value); + + ColumnExtractionValueWires { tree_hash } + } + + pub(crate) fn build_hash(b: &mut CBuilder) -> ColumnExtractionHashWires { // Initialize the input wires. let input_wires = ColumnExtractionInputWires { - column_values: [0; MAX_NUM_COLUMNS].map(|_| b.add_virtual_u256_unsafe()), // should be ok to use unsafe since these values are directly hashed to compute tree hash column_ids: b.add_virtual_target_arr(), is_real_column: [0; MAX_NUM_COLUMNS].map(|_| b.add_virtual_bool_target_safe()), }; @@ -78,16 +93,23 @@ impl ColumnExtractionInputs { // Build the column hashes by the input. let column_hash = build_column_hash(b, &input_wires); - // Exclude the first 2 indexed columns to build the cells tree. 
- let input_values = &input_wires.column_values[2..]; - let input_ids = &input_wires.column_ids[2..]; - let is_real_value = &input_wires.is_real_column[2..]; - let tree_hash = build_cells_tree(b, input_values, input_ids, is_real_value); + ColumnExtractionHashWires { + input_wires, + column_hash, + } + } + + #[cfg(test)] // used only in test for now + pub(crate) fn build( + b: &mut CBuilder, + column_values: &[UInt256Target; MAX_NUM_COLUMNS], + ) -> ColumnExtractionWires { + let hash_wires = Self::build_hash(b); + let value_wires = Self::build_tree_hash(b, column_values, &hash_wires.input_wires); ColumnExtractionWires { - tree_hash, - column_hash, - input_wires, + value_wires, + hash_wires, } } @@ -96,10 +118,6 @@ impl ColumnExtractionInputs { pw: &mut PartialWitness, wires: &ColumnExtractionInputWires, ) { - self.column_values - .iter() - .zip(wires.column_values.iter()) - .for_each(|(v, t)| pw.set_u256_target(t, *v)); pw.set_target_arr(wires.column_ids.as_slice(), self.column_ids.as_slice()); wires .is_real_column @@ -134,7 +152,8 @@ mod tests { use crate::query::universal_circuit::{ComputationalHash, MembershipHash}; use super::*; - use mp2_common::{C, D}; + use alloy::primitives::U256; + use mp2_common::{u256::WitnessWriteU256, C, D}; use mp2_test::{ cells_tree::{compute_cells_tree_hash, TestCell}, circuit::{run_circuit, UserCircuit}, @@ -144,6 +163,7 @@ mod tests { #[derive(Clone, Debug)] struct TestColumnExtractionCircuit { inputs: ColumnExtractionInputs, + column_values: [U256; MAX_NUM_COLUMNS], column_hash: [ComputationalHash; MAX_NUM_COLUMNS], tree_hash: MembershipHash, } @@ -152,40 +172,45 @@ mod tests { for TestColumnExtractionCircuit { // Column extraction wires + // + column values // + expected output column hash // + expected output tree hash type Wires = ( ColumnExtractionWires, + [UInt256Target; MAX_NUM_COLUMNS], [ComputationalHashTarget; MAX_NUM_COLUMNS], MembershipHashTarget, ); fn build(b: &mut CBuilder) -> Self::Wires { - let wires = ColumnExtractionInputs::build(b); + let column_values = ColumnExtractionInputs::build_column_values(b); + let wires = ColumnExtractionInputs::build(b, &column_values); let column_hash = array::from_fn(|_| b.add_virtual_hash()); let tree_hash = b.add_virtual_hash(); // Check the output column hash. wires + .hash_wires .column_hash .iter() .zip(column_hash) .for_each(|(l, r)| b.connect_hashes(*l, r)); // Check the output tree hash. 
- b.connect_hashes(wires.tree_hash, tree_hash); + b.connect_hashes(wires.value_wires.tree_hash, tree_hash); - (wires, column_hash, tree_hash) + (wires, column_values, column_hash, tree_hash) } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.inputs.assign(pw, &wires.0.input_wires); + pw.set_u256_target_arr(&wires.1, &self.column_values); + self.inputs.assign(pw, &wires.0.hash_wires.input_wires); wires - .1 + .2 .iter() .zip(self.column_hash) .for_each(|(t, v)| pw.set_hash_target(*t, v)); - pw.set_hash_target(wires.2, self.tree_hash); + pw.set_hash_target(wires.3, self.tree_hash); } } @@ -207,12 +232,12 @@ mod tests { let column_values = column_values.try_into().unwrap(); let inputs = ColumnExtractionInputs { real_num_columns, - column_values, column_ids, }; Self { inputs, + column_values, column_hash, tree_hash, } diff --git a/verifiable-db/src/query/universal_circuit/mod.rs b/verifiable-db/src/query/universal_circuit/mod.rs index e20622708..6ac6d1d68 100644 --- a/verifiable-db/src/query/universal_circuit/mod.rs +++ b/verifiable-db/src/query/universal_circuit/mod.rs @@ -19,6 +19,8 @@ pub(crate) mod output_with_aggregation; /// https://www.notion.so/lagrangelabs/Queries-Circuits-2695199166a54954bbc44ad9dc398825?pvs=4#5c0d5af8c40f4bf0ae7dd13b20a54dcc /// while the detailed specs can be found here https://www.notion.so/lagrangelabs/Queries-Circuits-2695199166a54954bbc44ad9dc398825?pvs=4#22fbb552e11e411e95d426264c94aa46 pub mod universal_query_circuit; +/// Gadget to process a single row in the DB according to a specific query +pub(crate) mod universal_query_gadget; /// Set of data structures to be provided as input to initialize a universal query circuit to prove /// the query computation for a single row. They basically allow to represent in a strucutred format diff --git a/verifiable-db/src/query/universal_circuit/output_no_aggregation.rs b/verifiable-db/src/query/universal_circuit/output_no_aggregation.rs index 7e51bef17..611b8c175 100644 --- a/verifiable-db/src/query/universal_circuit/output_no_aggregation.rs +++ b/verifiable-db/src/query/universal_circuit/output_no_aggregation.rs @@ -25,7 +25,9 @@ use std::{ use super::{ cells::build_cells_tree, - universal_query_circuit::{OutputComponent, OutputComponentWires}, + universal_query_gadget::{ + OutputComponent, OutputComponentHashWires, OutputComponentValueWires, + }, ComputationalHashTarget, }; @@ -56,20 +58,26 @@ pub struct InputWires { is_output_valid: [BoolTarget; MAX_NUM_RESULTS], } -/// Input + output wires for output component for queries without results aggregation -pub struct Wires { +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct HashWires { /// input wires of the component input_wires: InputWires, + + /// Computational hash representing all the computation done in the query circuit + output_hash: ComputationalHashTarget, + /// Identifiers of the aggregation operations to be returned as public inputs + ops_ids: [Target; MAX_NUM_RESULTS], +} + +#[derive(Clone, Debug)] +pub struct ValueWires { /// The first output value computed by this component; it is a `CurveTarget` since /// it corresponds to the accumulator of all the results of the query first_output_value: CurveTarget, /// Remaining output values; for this component, they are basically dummy values output_values: Vec, - /// Computational hash representing all the computation done in the query circuit - output_hash: ComputationalHashTarget, - /// Identifiers of the aggregation operations to be returned as public inputs - ops_ids: [Target; 
MAX_NUM_RESULTS], } + /// Witness input values for output component for queries without results aggregation #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Circuit { @@ -86,21 +94,23 @@ pub struct Circuit { ids: [F; MAX_NUM_RESULTS], } -impl OutputComponentWires for Wires { +impl OutputComponentValueWires for ValueWires { type FirstT = CurveTarget; - type InputWires = InputWires; - - fn ops_ids(&self) -> &[Target] { - self.ops_ids.as_slice() - } - fn first_output_value(&self) -> Self::FirstT { self.first_output_value } fn other_output_values(&self) -> &[UInt256Target] { - self.output_values.as_slice() + &self.output_values + } +} + +impl OutputComponentHashWires for HashWires { + type InputWires = InputWires; + + fn ops_ids(&self) -> &[Target] { + self.ops_ids.as_slice() } fn computational_hash(&self) -> ComputationalHashTarget { @@ -113,25 +123,56 @@ impl OutputComponentWires for Wires OutputComponent for Circuit { - type Wires = Wires; + type HashWires = HashWires; + type ValueWires = ValueWires; - fn build( + fn assign(&self, pw: &mut PartialWitness, wires: &InputWires) { + pw.set_target_arr(wires.selector.as_slice(), self.selector.as_slice()); + pw.set_target_arr(wires.ids.as_slice(), self.ids.as_slice()); + wires + .is_output_valid + .iter() + .enumerate() + .for_each(|(i, t)| pw.set_bool_target(*t, i < self.valid_num_outputs)); + } + + fn new(selector: &[F], ids: &[F], num_outputs: usize) -> anyhow::Result { + ensure!(selector.len() == num_outputs, + "Output component without aggregation: Number of selectors different from number of actual outputs"); + ensure!(ids.len() == num_outputs, + "Output component without aggregation: Number of output ids different from number of actual outputs"); + let selectors = selector + .iter() + .chain(repeat(&F::default())) + .take(MAX_NUM_RESULTS) + .cloned() + .collect_vec(); + let output_ids = ids + .iter() + .chain(repeat(&F::default())) + .take(MAX_NUM_RESULTS) + .cloned() + .collect_vec(); + Ok(Self { + valid_num_outputs: num_outputs, + selector: selectors.try_into().unwrap(), + ids: output_ids.try_into().unwrap(), + }) + } + + fn output_variant() -> Output { + Output::NoAggregation + } + + fn build_values( b: &mut CBuilder, possible_output_values: [UInt256Target; NUM_OUTPUT_VALUES], - possible_output_hash: [ComputationalHashTarget; NUM_OUTPUT_VALUES], predicate_value: &BoolTarget, - predicate_hash: &ComputationalHashTarget, - ) -> Self::Wires { + input_wires: &::InputWires, + ) -> Self::ValueWires { let u256_zero = b.zero_u256(); let curve_zero = b.curve_zero(); - // Initialize the input wires. - let input_wires = InputWires { - selector: b.add_virtual_target_arr(), - ids: b.add_virtual_target_arr(), - is_output_valid: [0; MAX_NUM_RESULTS].map(|_| b.add_virtual_bool_target_safe()), - }; - // Build the output items to be returned. let output_items: [_; MAX_NUM_RESULTS] = array::from_fn(|i| { b.random_access_u256(input_wires.selector[i], &possible_output_values) @@ -165,6 +206,24 @@ impl OutputComponent for Circuit< // Set the remaining outputs to dummy values. let output_values = vec![u256_zero; MAX_NUM_RESULTS - 1]; + ValueWires { + first_output_value, + output_values, + } + } + + fn build_hash( + b: &mut CBuilder, + possible_output_hash: [ComputationalHashTarget; NUM_OUTPUT_VALUES], + predicate_hash: &ComputationalHashTarget, + ) -> Self::HashWires { + // Initialize the input wires. 
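+ // Note: the input wires are allocated once here, in the hash half of the component, + // so that `build_values` can take them as an argument and re-use them instead of + // allocating its own copies.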
+ let input_wires = InputWires { + selector: b.add_virtual_target_arr(), + ids: b.add_virtual_target_arr(), + is_output_valid: [0; MAX_NUM_RESULTS].map(|_| b.add_virtual_bool_target_safe()), + }; + // Compute the computational hash representing the accumulation of the items. let output_hash = Self::output_variant().output_hash_circuit( b, @@ -186,52 +245,12 @@ impl OutputComponent for Circuit< .collect(); let ops_ids = ops_ids.try_into().unwrap(); - Self::Wires { + HashWires { input_wires, - first_output_value, - output_values, output_hash, ops_ids, } } - - fn assign(&self, pw: &mut PartialWitness, wires: &InputWires) { - pw.set_target_arr(wires.selector.as_slice(), self.selector.as_slice()); - pw.set_target_arr(wires.ids.as_slice(), self.ids.as_slice()); - wires - .is_output_valid - .iter() - .enumerate() - .for_each(|(i, t)| pw.set_bool_target(*t, i < self.valid_num_outputs)); - } - - fn new(selector: &[F], ids: &[F], num_outputs: usize) -> anyhow::Result { - ensure!(selector.len() == num_outputs, - "Output component without aggregation: Number of selectors different from number of actual outputs"); - ensure!(ids.len() == num_outputs, - "Output component without aggregation: Number of output ids different from number of actual outputs"); - let selectors = selector - .iter() - .chain(repeat(&F::default())) - .take(MAX_NUM_RESULTS) - .cloned() - .collect_vec(); - let output_ids = ids - .iter() - .chain(repeat(&F::default())) - .take(MAX_NUM_RESULTS) - .cloned() - .collect_vec(); - Ok(Self { - valid_num_outputs: num_outputs, - selector: selectors.try_into().unwrap(), - ids: output_ids.try_into().unwrap(), - }) - } - - fn output_variant() -> Output { - Output::NoAggregation - } } #[cfg(test)] @@ -497,7 +516,7 @@ mod tests { { // Circuit wires + output wires + expected wires type Wires = ( - Wires, + InputWires, TestOutputWires, TestExpectedWires, ); @@ -528,11 +547,15 @@ mod tests { ); // Check the first output value and the output hash as expected. - b.connect_curve_points(wires.first_output_value, expected.first_output_value); - b.connect_hashes(wires.output_hash, expected.output_hash); + b.connect_curve_points( + wires.value_wires.first_output_value, + expected.first_output_value, + ); + b.connect_hashes(wires.hash_wires.output_hash, expected.output_hash); // Check the remaining output values must be all zeros. wires + .value_wires .output_values .iter() .for_each(|t| b.enforce_equal_u256(t, &u256_zero)); @@ -540,16 +563,16 @@ mod tests { // Check the first OP is ID, and the remainings are SUM. let [op_id, op_sum] = [AggregationOperation::IdOp, AggregationOperation::SumOp] .map(|op| b.constant(Identifiers::AggregationOperations(op).to_field())); - b.connect(wires.ops_ids[0], op_id); - wires.ops_ids[1..] + b.connect(wires.hash_wires.ops_ids[0], op_id); + wires.hash_wires.ops_ids[1..] 
.iter() .for_each(|t| b.connect(*t, op_sum)); - (wires, output, expected) + (wires.hash_wires.input_wires, output, expected) } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign(pw, &wires.0.input_wires); + self.c.assign(pw, &wires.0); self.output.assign(pw, &wires.1); self.expected.assign(pw, &wires.2); } diff --git a/verifiable-db/src/query/universal_circuit/output_with_aggregation.rs b/verifiable-db/src/query/universal_circuit/output_with_aggregation.rs index c1705f142..cee481838 100644 --- a/verifiable-db/src/query/universal_circuit/output_with_aggregation.rs +++ b/verifiable-db/src/query/universal_circuit/output_with_aggregation.rs @@ -21,7 +21,9 @@ use serde::{Deserialize, Serialize}; use crate::query::computational_hash_ids::{AggregationOperation, Output}; use super::{ - universal_query_circuit::{OutputComponent, OutputComponentWires}, + universal_query_gadget::{ + OutputComponent, OutputComponentHashWires, OutputComponentValueWires, + }, ComputationalHashTarget, }; @@ -52,16 +54,19 @@ pub struct InputWires { )] is_output_valid: [BoolTarget; MAX_NUM_RESULTS], } - #[derive(Clone, Debug, Eq, PartialEq)] -/// Input + output wires for output component for queries with result aggregation -pub struct Wires { +pub struct HashWires { input_wires: InputWires, - /// Output values computed by this component - output_values: [UInt256Target; MAX_NUM_RESULTS], /// Computational hash representing all the computation done in the query circuit output_hash: ComputationalHashTarget, } + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ValueWires { + /// Output values computed by this component + output_values: [UInt256Target; MAX_NUM_RESULTS], +} + /// Input witness values for output component for queries with result aggregation #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] pub struct Circuit { @@ -78,15 +83,9 @@ pub struct Circuit { num_valid_outputs: usize, } -impl OutputComponentWires for Wires { +impl OutputComponentValueWires for ValueWires { type FirstT = UInt256Target; - type InputWires = InputWires; - - fn ops_ids(&self) -> &[Target] { - self.input_wires.agg_ops.as_slice() - } - fn first_output_value(&self) -> Self::FirstT { self.output_values[0].clone() } @@ -94,6 +93,14 @@ impl OutputComponentWires for Wires &[UInt256Target] { &self.output_values[1..] 
} +} + +impl OutputComponentHashWires for HashWires { + type InputWires = InputWires; + + fn ops_ids(&self) -> &[Target] { + self.input_wires.agg_ops.as_slice() + } fn computational_hash(&self) -> ComputationalHashTarget { self.output_hash @@ -105,58 +112,8 @@ impl OutputComponentWires for Wires OutputComponent for Circuit { - type Wires = Wires; - - fn build( - b: &mut CBuilder, - possible_output_values: [UInt256Target; NUM_OUTPUT_VALUES], - possible_output_hash: [ComputationalHashTarget; NUM_OUTPUT_VALUES], - predicate_value: &BoolTarget, - predicate_hash: &ComputationalHashTarget, - ) -> Self::Wires { - let selector = b.add_virtual_target_arr::(); - let agg_ops = b.add_virtual_target_arr::(); - let is_output_valid = array::from_fn(|_| b.add_virtual_bool_target_safe()); - let u256_max = b.constant_u256(U256::MAX); - let zero = b.zero_u256(); - let min_op_identifier = b.constant(AggregationOperation::MinOp.to_field()); - - let mut output_values = vec![]; - - for i in 0..MAX_NUM_RESULTS { - // TODO: random accesses over different iterations can be done with a single operation if we introduce an ad-hoc gate - let output_value = b.random_access_u256(selector[i], &possible_output_values); - - // If `predicate_value` is true, then expose the value to be aggregated; - // Otherwise use the identity for the aggregation operation. - // The identity is 0 except for "MIN", where the identity is the biggest - // possible value in the domain, i.e. 2^256-1. - let is_agg_ops_min = b.is_equal(agg_ops[i], min_op_identifier); - let identity_value = b.select_u256(is_agg_ops_min, &u256_max, &zero); - let actual_output_value = - b.select_u256(*predicate_value, &output_value, &identity_value); - output_values.push(actual_output_value); - } - - let output_hash = Self::output_variant().output_hash_circuit( - b, - predicate_hash, - &possible_output_hash, - &selector, - &agg_ops, - &is_output_valid, - ); - - Wires { - input_wires: InputWires { - selector, - agg_ops, - is_output_valid, - }, - output_values: output_values.try_into().unwrap(), - output_hash, - } - } + type HashWires = HashWires; + type ValueWires = ValueWires; fn assign(&self, pw: &mut PartialWitness, wires: &InputWires) { pw.set_target_arr(wires.selector.as_slice(), self.selector.as_slice()); @@ -195,6 +152,68 @@ impl OutputComponent for Circuit< fn output_variant() -> Output { Output::Aggregation } + + fn build_values( + b: &mut CBuilder, + possible_output_values: [UInt256Target; NUM_OUTPUT_VALUES], + predicate_value: &BoolTarget, + input_wires: &::InputWires, + ) -> Self::ValueWires { + let u256_max = b.constant_u256(U256::MAX); + let zero = b.zero_u256(); + let min_op_identifier = b.constant(AggregationOperation::MinOp.to_field()); + + let mut output_values = vec![]; + + for i in 0..MAX_NUM_RESULTS { + // TODO: random accesses over different iterations can be done with a single operation if we introduce an ad-hoc gate + let output_value = + b.random_access_u256(input_wires.selector[i], &possible_output_values); + + // If `predicate_value` is true, then expose the value to be aggregated; + // Otherwise use the identity for the aggregation operation. + // The identity is 0 except for "MIN", where the identity is the biggest + // possible value in the domain, i.e. 2^256-1. 
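// [editor's sketch, not part of the patch] The identity selection implemented
// just below, restated outside the circuit with u128 standing in for U256 and
// hypothetical names; feeding the identity for rows that fail the predicate
// leaves the running aggregate unchanged.
#[derive(Clone, Copy, PartialEq)]
pub enum SketchAggOp {
    Sum,
    Min,
    Max,
}

pub fn sketch_identity(op: SketchAggOp) -> u128 {
    match op {
        // MIN is the only operation whose identity is the top of the domain
        // (2^256 - 1 in the circuit, u128::MAX in this sketch).
        SketchAggOp::Min => u128::MAX,
        // 0 is the identity for SUM and, over unsigned values, for MAX as well.
        SketchAggOp::Sum | SketchAggOp::Max => 0,
    }
}

pub fn sketch_aggregated_input(predicate: bool, value: u128, op: SketchAggOp) -> u128 {
    if predicate {
        value
    } else {
        sketch_identity(op)
    }
}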
+ let is_agg_ops_min = b.is_equal(input_wires.agg_ops[i], min_op_identifier); + let identity_value = b.select_u256(is_agg_ops_min, &u256_max, &zero); + let actual_output_value = + b.select_u256(*predicate_value, &output_value, &identity_value); + output_values.push(actual_output_value); + } + + ValueWires { + output_values: output_values.try_into().unwrap(), + } + } + + fn build_hash( + b: &mut CBuilder, + possible_output_hash: [ComputationalHashTarget; NUM_OUTPUT_VALUES], + predicate_hash: &ComputationalHashTarget, + ) -> Self::HashWires { + let selector = b.add_virtual_target_arr::<MAX_NUM_RESULTS>(); + let agg_ops = b.add_virtual_target_arr::<MAX_NUM_RESULTS>(); + let is_output_valid = array::from_fn(|_| b.add_virtual_bool_target_safe()); + let input_wires = InputWires { + selector, + agg_ops, + is_output_valid, + }; + + let output_hash = Self::output_variant().output_hash_circuit( + b, + predicate_hash, + &possible_output_hash, + &input_wires.selector, + &input_wires.agg_ops, + &input_wires.is_output_valid, + ); + + HashWires { + input_wires, + output_hash, + } + } } #[cfg(test)] @@ -226,7 +245,7 @@ mod tests { computational_hash_ids::{AggregationOperation, ComputationalHashCache}, universal_circuit::{ universal_circuit_inputs::OutputItem, - universal_query_circuit::{OutputComponent, OutputComponentWires}, + universal_query_gadget::{OutputComponent, OutputComponentHashWires}, ComputationalHash, ComputationalHashTarget, }, }; @@ -311,14 +330,14 @@ mod tests { expected_output_values .iter() - .zip(wires.output_values.iter()) + .zip(wires.value_wires.output_values.iter()) .for_each(|(expected, actual)| c.enforce_equal_u256(expected, actual)); expected_ops_ids .iter() - .zip(wires.ops_ids().iter()) + .zip(wires.hash_wires.ops_ids().iter()) .for_each(|(expected, actual)| c.connect(*expected, *actual)); - c.connect_hashes(expected_output_hash, wires.output_hash); + c.connect_hashes(expected_output_hash, wires.hash_wires.output_hash); Self::Wires { column_values, @@ -327,7 +346,7 @@ mod tests { item_hash: item_hash.try_into().unwrap(), predicate_value, predicate_hash, - component: wires.input_wires, + component: wires.hash_wires.input_wires, expected_output_values, expected_ops_ids, expected_output_hash, diff --git a/verifiable-db/src/query/universal_circuit/universal_circuit_inputs.rs b/verifiable-db/src/query/universal_circuit/universal_circuit_inputs.rs index 76d56a471..317fef383 100644 --- a/verifiable-db/src/query/universal_circuit/universal_circuit_inputs.rs +++ b/verifiable-db/src/query/universal_circuit/universal_circuit_inputs.rs @@ -5,11 +5,14 @@ use std::collections::{btree_set, BTreeSet, HashMap}; use alloy::primitives::U256; use itertools::Itertools; use mp2_common::{ + array::ToField, utils::{Fieldable, TryIntoBool}, F, }; -use crate::query::computational_hash_ids::{Operation, Output, PlaceholderIdentifier}; +use crate::query::computational_hash_ids::{ + AggregationOperation, ColumnIDs, Identifiers, Operation, Output, PlaceholderIdentifier, +}; use super::universal_query_circuit::dummy_placeholder_id; @@ -321,6 +324,18 @@ impl BasicOperation { Ok((results, arithmetic_error)) } + + // Utility function to locate operation `op` in the set of `previous_ops` + #[cfg(test)] // used only in tests for now + pub(crate) fn locate_previous_operation(previous_ops: &[Self], op: &Self) -> Result<usize> { + previous_ops + .iter() + .find_position(|current_op| *current_op == op) + .map(|(pos, _)| pos) + .ok_or(anyhow::Error::msg( + "operation not found in set of previous ops", + )) + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, 
Serialize, Deserialize)] @@ -411,6 +426,15 @@ impl ResultStructure { }) } + pub fn aggregation_operations(&self) -> Vec { + match self.query_variant() { + Output::Aggregation => self.output_ids.clone(), + Output::NoAggregation => { + vec![Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field()] + } + } + } + pub fn query_variant(&self) -> Output { self.output_variant } @@ -483,4 +507,12 @@ impl RowCells { .cloned() .collect_vec() } + + pub fn column_ids(&self) -> ColumnIDs { + ColumnIDs { + primary: self.primary.id, + secondary: self.secondary.id, + rest: self.rest.iter().map(|cell| cell.id).collect_vec(), + } + } } diff --git a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs index 1e3ec0919..3a8522ee9 100644 --- a/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs +++ b/verifiable-db/src/query/universal_circuit/universal_query_circuit.rs @@ -1,20 +1,22 @@ -use std::{ - fmt::Debug, - iter::{once, repeat}, -}; +use std::iter::once; -use alloy::primitives::U256; -use anyhow::{bail, ensure, Result}; +use crate::query::{ + computational_hash_ids::{Output, PlaceholderIdentifier}, + pi_len, + public_inputs::PublicInputsUniversalCircuit, + row_chunk_gadgets::BoundaryRowDataTarget, + utils::QueryBounds, +}; +use anyhow::Result; use itertools::Itertools; use mp2_common::{ array::ToField, - poseidon::{empty_poseidon_hash, H}, + poseidon::{empty_poseidon_hash, HashPermutation}, public_inputs::PublicInputCommon, - serialization::{deserialize, deserialize_long_array, serialize, serialize_long_array}, + serialization::{deserialize, serialize}, types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target}, - utils::{FromFields, SelectHashBuilder, ToFields, ToTargets}, - CHasher, D, F, + utils::{FromTargets, HashBuilder, ToFields, ToTargets}, + CHasher, C, D, F, }; use plonky2::{ field::types::Field, @@ -25,360 +27,26 @@ use plonky2::{ }, plonk::{ circuit_builder::CircuitBuilder, - config::{GenericHashOut, Hasher}, - proof::ProofWithPublicInputsTarget, + circuit_data::{CircuitConfig, CircuitData}, + proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}, }, }; use recursion_framework::circuit_builder::CircuitLogicWires; -use serde::{Deserialize, Serialize}; - -use crate::query::{ - aggregation::{QueryBoundSecondary, QueryBoundSource, QueryBounds}, - computational_hash_ids::{ - ComputationalHashCache, HashPermutation, Operation, Output, PlaceholderIdentifier, - }, - pi_len, - public_inputs::PublicInputs, - universal_circuit::{ - basic_operation::BasicOperationInputs, universal_circuit_inputs::OutputItem, - }, -}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use super::{ - basic_operation::{BasicOperationInputWires, BasicOperationWires}, - column_extraction::{ColumnExtractionInputWires, ColumnExtractionInputs}, output_no_aggregation::Circuit as NoAggOutputCircuit, output_with_aggregation::Circuit as AggOutputCircuit, universal_circuit_inputs::{ - BasicOperation, InputOperand, Placeholder, PlaceholderId, Placeholders, ResultStructure, - RowCells, + BasicOperation, PlaceholderId, Placeholders, ResultStructure, RowCells, + }, + universal_query_gadget::{ + OutputComponent, QueryBound, UniversalQueryHashInputWires, UniversalQueryHashInputs, + UniversalQueryValueInputWires, UniversalQueryValueInputs, }, - ComputationalHash, ComputationalHashTarget, PlaceholderHash, PlaceholderHashTarget, + PlaceholderHash, }; -/// Wires representing a query bound in the universal 
circuit -pub(crate) type QueryBoundTarget = BasicOperationWires; - -/// Input wires for `QueryBoundTarget` (i.e., the wires that need to be assigned) -pub(crate) type QueryBoundTargetInputs = BasicOperationInputWires; - -impl From for QueryBoundTargetInputs { - fn from(value: QueryBoundTarget) -> Self { - value.input_wires - } -} - -impl QueryBoundTarget { - pub(crate) fn new(b: &mut CBuilder) -> Self { - let zero_u256 = b.zero_u256(); - let zero = b.zero(); - let empty_hash = b.constant_hash(*empty_poseidon_hash()); - // The 0 constant provided as input value is used as a dummy operand in case the query bound - // is taken from a constant in the query: in this case, the query bound in the circuit is - // computed with the operation `InputOperand::Constant(query_bound) + input_values[0]`, which - // yields `query_bound` as output since `input_values[0] = 0`. The constant input values 0 is - // associated to the empty hash in the computational hash, which is provided as `input_hash[0]` - BasicOperationInputs::build(b, &[zero_u256], &[empty_hash], zero) - } - - /// Get the actual value of this query bound computed in the circuit - pub(crate) fn get_bound_value(&self) -> &UInt256Target { - &self.output_value - } - - // Compute the number of overflows occurred during operations to compute query bounds - pub(crate) fn num_overflows_for_query_bound_operations( - b: &mut CBuilder, - min_query: &Self, - max_query: &Self, - ) -> Target { - b.add(min_query.num_overflows, max_query.num_overflows) - } - - pub(crate) fn add_query_bounds_to_placeholder_hash( - b: &mut CBuilder, - min_query_bound: &Self, - max_query_bound: &Self, - placeholder_hash: &PlaceholderHashTarget, - ) -> PlaceholderHashTarget { - b.hash_n_to_hash_no_pad::( - placeholder_hash - .elements - .iter() - .chain(once(&min_query_bound.input_wires.placeholder_ids[0])) - .chain(&min_query_bound.input_wires.placeholder_values[0].to_targets()) - .chain(once(&min_query_bound.input_wires.placeholder_ids[1])) - .chain(&min_query_bound.input_wires.placeholder_values[1].to_targets()) - .chain(once(&max_query_bound.input_wires.placeholder_ids[0])) - .chain(&max_query_bound.input_wires.placeholder_values[0].to_targets()) - .chain(once(&max_query_bound.input_wires.placeholder_ids[1])) - .chain(&max_query_bound.input_wires.placeholder_values[1].to_targets()) - .cloned() - .collect(), - ) - } - - pub(crate) fn add_query_bounds_to_computational_hash( - b: &mut CBuilder, - min_query_bound: &Self, - max_query_bound: &Self, - computational_hash: &ComputationalHashTarget, - ) -> ComputationalHashTarget { - b.hash_n_to_hash_no_pad::( - computational_hash - .to_targets() - .into_iter() - .chain(min_query_bound.output_hash.to_targets()) - .chain(max_query_bound.output_hash.to_targets()) - .collect_vec(), - ) - } -} - -impl QueryBoundTargetInputs { - pub(crate) fn assign(&self, pw: &mut PartialWitness, bound: &QueryBound) { - bound.operation.assign(pw, self); - } -} -#[derive(Clone, Debug, Serialize, Deserialize)] -pub(crate) struct QueryBound { - pub(crate) operation: BasicOperationInputs, -} - -impl QueryBound { - /// Number of input values provided to the basic operation component computing the query bounds - /// in the circuit; currently it is 1 since the constant input value 0 is provided as a dummy - /// input value (see QueryBoundTarget::new()). 
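// [editor's sketch, not part of the patch] The dummy-operand trick described in
// the doc comment above, in plain integers with hypothetical names: the gadget
// only evaluates generic binary operations, so a constant bound is encoded as
// `constant + input[0]` with input[0] hardwired to zero, and a placeholder
// bound as `$id + 0`, letting one operation component cover both sources.
pub fn sketch_bound_from_constant(constant: u128) -> u128 {
    let input_values = [0u128]; // always zero, mirroring QueryBoundTarget::new()
    constant + input_values[0]
}

pub fn sketch_bound_from_placeholder(placeholder_value: u128) -> u128 {
    const ZERO_OPERAND: u128 = 0; // the constant operand is forced to zero here
    placeholder_value + ZERO_OPERAND
}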
- const NUM_INPUT_VALUES: usize = 1; - - /// Initialize a query bound for the primary index, from the set of `placeholders` employed in the query, - /// which include also the primary index bounds by construction. The flag `is_min_bound` - /// must be true iff the bound to be initialized is a lower bound in the range specified in the query - pub(crate) fn new_primary_index_bound( - placeholders: &Placeholders, - is_min_bound: bool, - ) -> Result { - let source = QueryBoundSource::Placeholder(if is_min_bound { - PlaceholderIdentifier::MinQueryOnIdx1 - } else { - PlaceholderIdentifier::MaxQueryOnIdx1 - }); - Self::new_bound(placeholders, &source) - } - - /// Initialize a query bound for the secondary index, from the set of placeholders employed in the query - /// and from the provided `bound`, which specifies how the query bound should be computed in the circuit - pub(crate) fn new_secondary_index_bound( - placeholders: &Placeholders, - bound: &QueryBoundSecondary, - ) -> Result { - let source = bound.into(); - Self::new_bound(placeholders, &source) - } - - /// Internal function employed to instantiate a new query bound - fn new_bound(placeholders: &Placeholders, source: &QueryBoundSource) -> Result { - let dummy_placeholder = dummy_placeholder(placeholders); - let op_inputs = match source { - QueryBoundSource::Constant(value) => - // if the query bound is computed from a constant `value`, we instantiate the operation - // `value + input_values[0]` in the circuit, as in `QueryBoundTarget` construction we - // always set `input_values[0] = 0`. This trick allows to get the same constant `value` - // as output of the basic operation employed in the circuit to compute the query bound - { - BasicOperationInputs { - constant_operand: *value, - placeholder_values: [dummy_placeholder.value, dummy_placeholder.value], - placeholder_ids: [ - dummy_placeholder.id.to_field(), - dummy_placeholder.id.to_field(), - ], - first_input_selector: BasicOperationInputs::constant_operand_offset( - Self::NUM_INPUT_VALUES, - ) - .to_field(), - second_input_selector: BasicOperationInputs::input_value_offset(0).to_field(), - op_selector: Operation::AddOp.to_field(), - } - } - QueryBoundSource::Placeholder(id) => - // if the query bound is computed from a placeholder with id `id`, we instantiate - // the operation `$id + 0` in the circuit, which will yield the value of placeholder - // $id (which should correspond to the query bound) as output - { - BasicOperationInputs { - constant_operand: U256::ZERO, - placeholder_values: [placeholders.get(id)?, dummy_placeholder.value], - placeholder_ids: [id.to_field(), dummy_placeholder.id.to_field()], - first_input_selector: BasicOperationInputs::first_placeholder_offset( - Self::NUM_INPUT_VALUES, - ) - .to_field(), - second_input_selector: BasicOperationInputs::constant_operand_offset( - Self::NUM_INPUT_VALUES, - ) - .to_field(), - op_selector: Operation::AddOp.to_field(), - } - } - QueryBoundSource::Operation(op) => { - // In this case we instantiate the basic operation `op`, checking that the operation - // satisfies the requirements for query bound operations (i.e., it involves only - // constant values and placeholders) - let mut constant_operand = U256::ZERO; - let mut process_input_op = |operand: &InputOperand| { - Ok(match operand { - InputOperand::Placeholder(id) => - ( - *id, - None, - ), - InputOperand::Constant(value) => { - constant_operand = *value; - ( - dummy_placeholder.id, - Some(BasicOperationInputs::constant_operand_offset(Self::NUM_INPUT_VALUES)) - ) - }, - _ 
=> bail!("Invalid operand for query bound operation: must be either a placeholder or a constant"), - }) - }; - - let (first_placeholder_id, first_selector) = process_input_op(&op.first_operand)?; - let (second_placeholder_id, second_selector) = process_input_op( - &op.second_operand.unwrap_or_default(), // Unary operation, so use a dummy operand - )?; - BasicOperationInputs { - constant_operand, - placeholder_values: [ - placeholders.get(&first_placeholder_id)?, - placeholders.get(&second_placeholder_id)?, - ], - placeholder_ids: [ - first_placeholder_id.to_field(), - second_placeholder_id.to_field(), - ], - first_input_selector: first_selector - .unwrap_or(BasicOperationInputs::first_placeholder_offset( - Self::NUM_INPUT_VALUES, - )) - .to_field(), - second_input_selector: second_selector - .unwrap_or(BasicOperationInputs::second_placeholder_offset( - Self::NUM_INPUT_VALUES, - )) - .to_field(), - op_selector: op.op.to_field(), - } - } - }; - Ok(Self { - operation: op_inputs, - }) - } - - /// This method computes the value of a query bound - pub(crate) fn compute_bound_value( - placeholders: &Placeholders, - source: &QueryBoundSource, - ) -> Result<(U256, bool)> { - Ok(match source { - QueryBoundSource::Constant(value) => (*value, false), - QueryBoundSource::Placeholder(id) => (placeholders.get(id)?, false), - QueryBoundSource::Operation(op) => { - let (values, overflow) = - BasicOperation::compute_operations(&[*op], &[], placeholders)?; - (values[0], overflow) - } - }) - } - - /// This method returns the basic operation employed in the circuit for the query bound which is - /// taken fromthe query as specify by the input `source`. It basically returns the same operations - /// that are instantiated in the circuit by the `new_bound` internal method - pub(crate) fn get_basic_operation(source: &QueryBoundSource) -> Result { - Ok(match source { - QueryBoundSource::Constant(value) => - // convert to operation `value + input_value[0]`, which yield value as `input_value[0] = 0` in the circuit - { - BasicOperation { - first_operand: InputOperand::Constant(*value), - second_operand: Some(InputOperand::Column(0)), - op: Operation::AddOp, - } - } - QueryBoundSource::Placeholder(id) => - // convert to operation $id + 0 - { - BasicOperation { - first_operand: InputOperand::Placeholder(*id), - second_operand: Some(InputOperand::Constant(U256::ZERO)), - op: Operation::AddOp, - } - } - QueryBoundSource::Operation(op) => { - // validate operation for query bound - match op.first_operand { - InputOperand::Constant(_) | InputOperand::Placeholder(_) => (), - _ => bail!("Invalid operand for query bound operation: must be either a placeholder or a constant") - } - if let Some(operand) = op.second_operand { - match operand { - InputOperand::Constant(_) | InputOperand::Placeholder(_) => (), - _ => bail!("Invalid operand for query bound operation: must be either a placeholder or a constant") - } - } - *op - } - }) - } - - pub(crate) fn add_secondary_query_bounds_to_placeholder_hash( - min_query: &Self, - max_query: &Self, - placeholder_hash: &PlaceholderHash, - ) -> PlaceholderHash { - hash_n_to_hash_no_pad::<_, HashPermutation>( - &placeholder_hash - .to_vec() - .into_iter() - .chain(once(min_query.operation.placeholder_ids[0])) - .chain(min_query.operation.placeholder_values[0].to_fields()) - .chain(once(min_query.operation.placeholder_ids[1])) - .chain(min_query.operation.placeholder_values[1].to_fields()) - .chain(once(max_query.operation.placeholder_ids[0])) - 
.chain(max_query.operation.placeholder_values[0].to_fields()) - .chain(once(max_query.operation.placeholder_ids[1])) - .chain(max_query.operation.placeholder_values[1].to_fields()) - .collect_vec(), - ) - } - - pub(crate) fn add_secondary_query_bounds_to_computational_hash( - min_query: &QueryBoundSource, - max_query: &QueryBoundSource, - computational_hash: &ComputationalHash, - ) -> Result { - let min_query_op = Self::get_basic_operation(min_query)?; - let max_query_op = Self::get_basic_operation(max_query)?; - // initialize computational hash cache with the empty hash associated to the only input value (hardcoded to 0 - // in the circuit) of the basic operation components employed for query bounds - let mut cache = ComputationalHashCache::new_from_column_hash( - Self::NUM_INPUT_VALUES, - &[*empty_poseidon_hash()], - )?; - let min_query_hash = Operation::basic_operation_hash(&mut cache, &[], &min_query_op)?; - let max_query_hash = Operation::basic_operation_hash(&mut cache, &[], &max_query_op)?; - let inputs = computational_hash - .to_vec() - .into_iter() - .chain(min_query_hash.to_fields()) - .chain(max_query_hash.to_fields()) - .collect_vec(); - Ok(H::hash_no_pad(&inputs)) - } -} - #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] /// Input wires for the universal query circuit pub struct UniversalQueryCircuitWires< @@ -388,83 +56,19 @@ pub struct UniversalQueryCircuitWires< const MAX_NUM_RESULTS: usize, T: OutputComponent, > { - /// Input wires for column extraction component - pub(crate) column_extraction_wires: ColumnExtractionInputWires, /// flag specifying whether the given row is stored in a leaf node of a rows tree or not #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] is_leaf: BoolTarget, - /// Lower bound of the range for the secondary index specified in the query - min_query: QueryBoundTargetInputs, - /// Upper bound of the range for the secondary index specified in the query - max_query: QueryBoundTargetInputs, - /// Input wires for the `MAX_NUM_PREDICATE_OPS` basic operation components necessary - /// to evaluate the filtering predicate - #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" - )] - filtering_predicate_ops: [BasicOperationInputWires; MAX_NUM_PREDICATE_OPS], - /// Input wires for the `MAX_NUM_RESULT_OPS` basic operation components necessary - /// to compute the results for the current row - #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" - )] - result_value_ops: [BasicOperationInputWires; MAX_NUM_RESULT_OPS], - /// Input wires for the `MAX_NUM_RESULTS` output components that computes the - /// output values for the current row - output_component_wires: ::InputWires, -} - -/// Trait for the 2 different variants of output components we currently support -/// in query circuits -pub trait OutputComponent: Clone { - type Wires: OutputComponentWires; - - fn new(selector: &[F], ids: &[F], num_outputs: usize) -> Result; - - fn build( - b: &mut CBuilder, - possible_output_values: [UInt256Target; NUM_OUTPUT_VALUES], - possible_output_hash: [ComputationalHashTarget; NUM_OUTPUT_VALUES], - predicate_value: &BoolTarget, - predicate_hash: &ComputationalHashTarget, - ) -> Self::Wires; - - fn assign( - &self, - pw: &mut PartialWitness, - wires: &::InputWires, - ); - - /// Return the type of output component, specified as an instance of `Output` enum - fn output_variant() -> Output; + hash_wires: UniversalQueryHashInputWires< + MAX_NUM_COLUMNS, + 
MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + value_wires: UniversalQueryValueInputWires, } -/// Trait representing the wires that need to be exposed by an `OutputComponent` -/// employed in query circuits -pub trait OutputComponentWires { - /// Associated type specifying the type of the first output value computed by this output - /// component; this type varies depending on the particular component: - /// - It is a `CurveTarget` in the output component for queries without aggregation operations - /// - It is a `UInt256Target` in the output for queries with aggregation operations - type FirstT: ToTargets; - /// Input wires of the output component - type InputWires: Serialize + for<'a> Deserialize<'a> + Clone + Debug + Eq + PartialEq; - /// Get the identifiers of the aggregation operations specified in the query to aggregate the - /// results (e.g., `SUM`, `AVG`) - fn ops_ids(&self) -> &[Target]; - /// Get the first output value returned by the output component; this is accessed by an ad-hoc - /// method since such output value could be a `UInt256Target` or a `CurveTarget`, depending - /// on the output component instance - fn first_output_value(&self) -> Self::FirstT; - /// Get the subsequent output values returned by the output component - fn other_output_values(&self) -> &[UInt256Target]; - /// Get the computational hash returned by the output component - fn computational_hash(&self) -> ComputationalHashTarget; - /// Get the input wires for the output component - fn input_wires(&self) -> Self::InputWires; -} /// Witness input values for the universal query circuit #[derive(Clone, Debug, Serialize, Deserialize)] pub struct UniversalQueryCircuitInputs< @@ -474,21 +78,20 @@ pub struct UniversalQueryCircuitInputs< const MAX_NUM_RESULTS: usize, T: OutputComponent, > { - column_extraction_inputs: ColumnExtractionInputs, is_leaf: bool, - min_query: QueryBound, - max_query: QueryBound, - #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" - )] - filtering_predicate_inputs: [BasicOperationInputs; MAX_NUM_PREDICATE_OPS], - #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" - )] - result_values_inputs: [BasicOperationInputs; MAX_NUM_RESULT_OPS], - output_component_inputs: T, + hash_gadget_inputs: UniversalQueryHashInputs< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + value_gadget_inputs: UniversalQueryValueInputs< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >, } impl< @@ -521,89 +124,22 @@ where is_leaf: bool, query_bounds: &QueryBounds, results: &ResultStructure, + is_dummy_row: bool, ) -> Result { - let num_columns = row_cells.num_columns(); - ensure!( - num_columns <= MAX_NUM_COLUMNS, - "number of columns is higher than the maximum value allowed" - ); - let column_cells = row_cells.to_cells(); - let padded_column_values = column_cells - .iter() - .map(|cell| cell.value) - .chain(repeat(U256::ZERO)) - .take(MAX_NUM_COLUMNS) - .collect_vec(); - let padded_column_ids = column_cells - .iter() - .map(|cell| cell.id) - .chain(repeat(F::NEG_ONE)) - .take(MAX_NUM_COLUMNS) - .collect_vec(); - let column_extraction_inputs = ColumnExtractionInputs:: { - real_num_columns: num_columns, - column_values: padded_column_values.try_into().unwrap(), - column_ids: padded_column_ids.try_into().unwrap(), - }; - let num_predicate_ops = predicate_operations.len(); - ensure!(num_predicate_ops <= 
MAX_NUM_PREDICATE_OPS, - "Number of operations to compute filtering predicate is higher than the maximum number allowed"); - let num_result_ops = results.result_operations.len(); - ensure!( - num_result_ops <= MAX_NUM_RESULT_OPS, - "Number of operations to compute results is higher than the maximum number allowed" - ); - let predicate_ops_inputs = Self::compute_operation_inputs::( + let hash_gadget_inputs = UniversalQueryHashInputs::new( + &row_cells.column_ids(), predicate_operations, placeholders, + query_bounds, + results, )?; - let result_ops_inputs = Self::compute_operation_inputs::( - &results.result_operations, - placeholders, - )?; - let selectors = results.output_items.iter().enumerate().map(|(i, item)| { - Ok( - match item { - OutputItem::Column(index) => { - ensure!(*index < MAX_NUM_COLUMNS, - "Column index provided as {}-th output value is higher than the maximum number of columns", i); - F::from_canonical_usize(*index) - }, - OutputItem::ComputedValue(index) => { - ensure!(*index < num_result_ops, - "an operation computing an output results not found in set of result operations"); - // the output will be placed in the `num_result_ops - index` last slot in the set of - // `possible_output_values` provided as input in the circuit to the output component, - // i.e., the input array found in `OutputComponent::build` method. - // Therefore, since the `possible_output_values` array in the circuit has - // `MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS` entries, the selector for such output value - // can be computed as the length of `possible_output_values.len() - (num_result_ops - index)`, - // which correspond to the `num_result_ops - index`-th entry from the end of the array - F::from_canonical_usize(MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS - (num_result_ops - *index)) - }, - }) - }).collect::>>()?; - let output_component_inputs = - T::new(&selectors, &results.output_ids, results.output_ids.len())?; - let min_query = QueryBound::new_secondary_index_bound( - placeholders, - query_bounds.min_query_secondary(), - )?; - - let max_query = QueryBound::new_secondary_index_bound( - placeholders, - query_bounds.max_query_secondary(), - )?; + let value_gadget_inputs = UniversalQueryValueInputs::new(row_cells, is_dummy_row)?; Ok(Self { - column_extraction_inputs, is_leaf, - min_query, - max_query, - filtering_predicate_inputs: predicate_ops_inputs, - result_values_inputs: result_ops_inputs, - output_component_inputs, + hash_gadget_inputs, + value_gadget_inputs, }) } @@ -616,20 +152,20 @@ where MAX_NUM_RESULTS, T, > { - let column_extraction_wires = ColumnExtractionInputs::::build(b); + let hash_wires = UniversalQueryHashInputs::build(b); + let value_wires = UniversalQueryValueInputs::build( + b, + &hash_wires.input_wires, + &hash_wires.min_secondary, + &hash_wires.max_secondary, + &hash_wires.num_bound_overflows, + ); let is_leaf = b.add_virtual_bool_target_safe(); let _true = b._true(); let zero = b.zero(); // min and max for secondary indexed column - let node_min = &column_extraction_wires.input_wires.column_values[1]; + let node_min = &value_wires.input_wires.column_values[1]; let node_max = node_min; - // column ids for primary and seconday indexed columns - let (primary_index_id, second_index_id) = ( - &column_extraction_wires.input_wires.column_ids[0], - &column_extraction_wires.input_wires.column_ids[1], - ); - // value of the primary indexed column for the current row - let index_value = &column_extraction_wires.input_wires.column_values[0]; // compute hash of the node in case the current row 
is stored in a leaf of the rows tree let empty_hash = b.constant_hash(*empty_poseidon_hash()); let leaf_hash_inputs = empty_hash @@ -638,167 +174,48 @@ where .chain(empty_hash.elements.iter()) .chain(node_min.to_targets().iter()) .chain(node_max.to_targets().iter()) - .chain(once(second_index_id)) + .chain(once( + &hash_wires.input_wires.column_extraction_wires.column_ids[1], + )) .chain(node_min.to_targets().iter()) - .chain(column_extraction_wires.tree_hash.elements.iter()) + .chain(value_wires.output_wires.tree_hash.elements.iter()) .cloned() .collect(); let leaf_hash = b.hash_n_to_hash_no_pad::(leaf_hash_inputs); - let tree_hash = b.select_hash(is_leaf, &leaf_hash, &column_extraction_wires.tree_hash); - // ensure that the value of second indexed column for the current record is in - // the range specified by the query - let min_query = QueryBoundTarget::new(b); - let max_query = QueryBoundTarget::new(b); - let min_query_value = min_query.get_bound_value(); - let max_query_value = max_query.get_bound_value(); - let less_than_max = b.is_less_or_equal_than_u256(node_max, max_query_value); - let greater_than_min = b.is_less_or_equal_than_u256(min_query_value, node_min); - b.connect(less_than_max.target, _true.target); - b.connect(greater_than_min.target, _true.target); - // initialize input_values and input_hash input vectors for basic operation components employed to - // evaluate the filtering predicate - let mut input_values = column_extraction_wires.input_wires.column_values.to_vec(); - let mut input_hash = column_extraction_wires.column_hash.to_vec(); - // Set of input wires for each of the `MAX_NUM_PREDICATE_OPS` basic operation components employed to - // evaluate the filtering predicate - let mut filtering_predicate_wires = Vec::with_capacity(MAX_NUM_PREDICATE_OPS); - // Payload to compute the placeholder hash public input - let mut placeholder_hash_payload = vec![]; - // initialize counter of overflows to number of overflows occurred during query bound operations - let mut num_overflows = - QueryBoundTarget::num_overflows_for_query_bound_operations(b, &min_query, &max_query); + let tree_hash = b.select_hash(is_leaf, &leaf_hash, &value_wires.output_wires.tree_hash); - for _ in 0..MAX_NUM_PREDICATE_OPS { - let BasicOperationWires { - input_wires, - output_value, - output_hash, - num_overflows: new_num_overflows, - } = BasicOperationInputs::build(b, &input_values, &input_hash, num_overflows); - // add the output_value computed by the last basic operation component to the input values - // for the next basic operation components employed to evaluate the filtering predicate - input_values.push(output_value); - // and the corresponding output_hash to the input hash as well - input_hash.push(output_hash); - // update the counter of overflows detected - num_overflows = new_num_overflows; - // add placeholder data to payload for placeholder hash - placeholder_hash_payload.push(input_wires.placeholder_ids[0]); - placeholder_hash_payload - .extend_from_slice(&input_wires.placeholder_values[0].to_targets()); - placeholder_hash_payload.push(input_wires.placeholder_ids[1]); - placeholder_hash_payload - .extend_from_slice(&input_wires.placeholder_values[1].to_targets()); - filtering_predicate_wires.push(input_wires); - } - // Place the evaluation of the filtering predicate, and the corresponding computational hash, in - // two variables; the evaluation and the corresponding hash are expected to be the output of the - // last basic operation component among the `MAX_NUM_PREDICATE_OPS` ones 
employed to evaluate - // the filtering predicate. This placement is done in order to have a fixed slot where we can - // find the predicate value, without the need for a further random_access operation just to extract - // this value from the set of predicate operations - let predicate_value = input_values.last().unwrap().to_bool_target(); - let predicate_hash = input_hash.last().unwrap(); - // initialize input_values and input_hash input vectors for basic operation components employed to - // compute the results to be returned for the current row - let mut input_values = column_extraction_wires.input_wires.column_values.to_vec(); - let mut input_hash = column_extraction_wires.column_hash.to_vec(); - // Set of input wires for each of the `MAX_NUM_RESULT_OPS` basic operation components employed to - // compute the results to be returned for the current row - let mut result_value_wires = Vec::with_capacity(MAX_NUM_RESULT_OPS); - for _ in 0..MAX_NUM_RESULT_OPS { - let BasicOperationWires { - input_wires, - output_value, - output_hash, - num_overflows: new_num_overflows, - } = BasicOperationInputs::build(b, &input_values, &input_hash, num_overflows); - // add the output_value computed by the last basic operation component to the input values - // for the next basic operation components employed to compute results for current row - input_values.push(output_value); - // and the corresponding output_hash to the input hash as well - input_hash.push(output_hash); - // update the counter of overflows detected - num_overflows = new_num_overflows; - // add placeholder data to payload for placeholder hash - placeholder_hash_payload.push(input_wires.placeholder_ids[0]); - placeholder_hash_payload - .extend_from_slice(&input_wires.placeholder_values[0].to_targets()); - placeholder_hash_payload.push(input_wires.placeholder_ids[1]); - placeholder_hash_payload - .extend_from_slice(&input_wires.placeholder_values[1].to_targets()); - result_value_wires.push(input_wires); - } - // `possible_output_values` to be provided to output component are the set of `MAX_NUM_COLUMNS` - // and the `MAX_NUM_RESULT_OPS` results of results operations, which are all already accumulated - // in the `input_values` vector - let possible_output_values: [UInt256Target; MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS] = - input_values.try_into().unwrap(); - // same for `possible_output_hash`, all the hashes are already accumulated in the `input_hash` vector - let possible_output_hash: [ComputationalHashTarget; MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS] = - input_hash.try_into().unwrap(); - let output_component_wires = T::build( - b, - possible_output_values, - possible_output_hash, - &predicate_value, - predicate_hash, - ); // compute overflow flag - let not_overflow = b.is_equal(num_overflows, zero); - let overflow = b.not(not_overflow); - let placeholder_hash = b.hash_n_to_hash_no_pad::(placeholder_hash_payload); - let placeholder_hash = QueryBoundTarget::add_query_bounds_to_placeholder_hash( - b, - &min_query, - &max_query, - &placeholder_hash, - ); - // compute output_values to be exposed; we call `pad_slice_to_curve_len` to ensure that the - // first output value is always padded to the size of a `CurveTarget` - let mut output_values = PublicInputs::<_, MAX_NUM_RESULTS>::pad_slice_to_curve_len( - &output_component_wires.first_output_value().to_targets(), - ); - // Append the other `MAX_NUM_RESULTS-1` output values - output_values.extend_from_slice( - &output_component_wires - .other_output_values() - .iter() - .flat_map(|t| t.to_targets()) - 
.collect_vec(), - ); - // add query bounds to computational hash - let computational_hash = QueryBoundTarget::add_query_bounds_to_computational_hash( - b, - &min_query, - &max_query, - &output_component_wires.computational_hash(), - ); - PublicInputs::::new( + let overflow = b.is_not_equal(value_wires.output_wires.num_overflows, zero); + + let output_values_targets = value_wires.output_wires.values.to_targets(); + + // compute dummy left boundary and right boundary rows to be exposed as public inputs; + // they are ignored by the circuits processing this proof, so it's ok to use dummy + // values + let dummy_boundary_row_targets = + b.constants(&vec![F::ZERO; BoundaryRowDataTarget::NUM_TARGETS]); + let primary_index_value = &value_wires.input_wires.column_values[0]; + PublicInputsUniversalCircuit::::new( &tree_hash.to_targets(), - output_values.as_slice(), - &[predicate_value.target], - output_component_wires.ops_ids(), - &index_value.to_targets(), + &output_values_targets, + &[value_wires.output_wires.count], + hash_wires.agg_ops_ids.as_slice(), + &dummy_boundary_row_targets, + &dummy_boundary_row_targets, + &hash_wires.input_wires.min_query_primary.to_targets(), + &hash_wires.input_wires.max_query_primary.to_targets(), &node_min.to_targets(), - &node_max.to_targets(), - &[*primary_index_id, *second_index_id], - &min_query_value.to_targets(), - &max_query_value.to_targets(), + &primary_index_value.to_targets(), &[overflow.target], - &computational_hash.to_targets(), - &placeholder_hash.to_targets(), + &hash_wires.computational_hash.to_targets(), + &hash_wires.placeholder_hash.to_targets(), ) .register(b); UniversalQueryCircuitWires { - column_extraction_wires: column_extraction_wires.input_wires, is_leaf, - min_query: min_query.into(), - max_query: max_query.into(), - filtering_predicate_ops: filtering_predicate_wires.try_into().unwrap(), - result_value_ops: result_value_wires.try_into().unwrap(), - output_component_wires: output_component_wires.input_wires(), + hash_wires: hash_wires.input_wires, + value_wires: value_wires.input_wires, } } @@ -813,157 +230,9 @@ where T, >, ) { - self.column_extraction_inputs - .assign(pw, &wires.column_extraction_wires); pw.set_bool_target(wires.is_leaf, self.is_leaf); - wires.min_query.assign(pw, &self.min_query); - wires.max_query.assign(pw, &self.max_query); - self.filtering_predicate_inputs - .iter() - .zip(wires.filtering_predicate_ops.iter()) - .for_each(|(inputs, wires)| inputs.assign(pw, wires)); - self.result_values_inputs - .iter() - .zip(wires.result_value_ops.iter()) - .for_each(|(inputs, wires)| inputs.assign(pw, wires)); - self.output_component_inputs - .assign(pw, &wires.output_component_wires); - } - - /// This method returns the ids of the placeholders employed to compute the placeholder hash, - /// in the same order, so that those ids can be provided as input to other circuits that need - /// to recompute this hash - pub(crate) fn ids_for_placeholder_hash(&self) -> Vec { - self.filtering_predicate_inputs - .iter() - .flat_map(|op_inputs| vec![op_inputs.placeholder_ids[0], op_inputs.placeholder_ids[1]]) - .chain(self.result_values_inputs.iter().flat_map(|op_inputs| { - vec![op_inputs.placeholder_ids[0], op_inputs.placeholder_ids[1]] - })) - .map(|id| PlaceholderIdentifier::from_fields(&[id])) - .collect_vec() - } - - /// Utility function to compute the `BasicOperationInputs` corresponding to the set of `operations` specified - /// as input. 
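// [editor's sketch, not part of the patch] The padding strategy described in
// the surrounding doc comment, as a free-standing function over plain types:
// the real operations must stay at the END of the fixed-size array because the
// circuit reads the final predicate/result value from the last slot, so the
// dummies go in front.
pub fn sketch_pad_front<T: Clone>(ops: Vec<T>, dummy: T, max_num_ops: usize) -> Vec<T> {
    assert!(ops.len() <= max_num_ops, "more operations than the circuit supports");
    std::iter::repeat(dummy)
        .take(max_num_ops - ops.len())
        .chain(ops)
        .collect()
}
// e.g. sketch_pad_front(vec!["a", "b"], "dummy", 4) yields ["dummy", "dummy", "a", "b"],
// so the last entry is always the operation producing the actual result.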
The set of `BasicOperationInputs` is padded to `MAX_NUM_OPS` with dummy operations, which is - /// the expected number of operations expected as input by the circuit. - fn compute_operation_inputs( - operations: &[BasicOperation], - placeholders: &Placeholders, - ) -> Result<[BasicOperationInputs; MAX_NUM_OPS]> { - let dummy_placeholder = dummy_placeholder(placeholders); - // starting offset in the input values provided to basic operation component where the output values - // of `operations` will be found. It is computed as follows since these operations will be placed - // at the end of these functions in the last slots among the `MAX_NUM_OPS` available, as expected - // by the circuit - let start_actual_ops = MAX_NUM_COLUMNS + MAX_NUM_OPS - operations.len(); - let ops_wires = operations.iter().enumerate().map(|(i, op)| { - let mut constant_operand = U256::ZERO; - // the number of input values provided to the basic operation component - // computing the current predicate operation - let num_inputs = start_actual_ops + i; - let mut compute_op_inputs = |is_first_op: bool| { - let operand = if is_first_op { - op.first_operand - } else { - op.second_operand.unwrap_or_default() - }; - Ok( - match operand { - InputOperand::Placeholder(p) => { - let placeholder_value = placeholders.get(&p)?; - ( - Some(placeholder_value), - Some(p), - if is_first_op { - BasicOperationInputs::first_placeholder_offset(num_inputs) - } else { - BasicOperationInputs::second_placeholder_offset(num_inputs) - }, - ) - }, - InputOperand::Constant(val) => { - constant_operand = val; - ( - None, - None, - BasicOperationInputs::constant_operand_offset(num_inputs), - ) - }, - InputOperand::Column(index) => { - ensure!(index < MAX_NUM_COLUMNS, - "column index specified as input for {}-th predicate operation is higher than number of columns", i); - ( - None, - None, - BasicOperationInputs::input_value_offset(index), - ) - }, - InputOperand::PreviousValue(index) => { - ensure!(index < i, - "previous value index specified as input for {}-th predicate operation is higher than the number of values already computed by previous operations", i); - ( - None, - None, - BasicOperationInputs::input_value_offset(start_actual_ops+index), - ) - }, - } - )}; - let (first_placeholder_value, first_placeholder_id, first_selector) = compute_op_inputs( - true - )?; - let (second_placeholder_value, second_placeholder_id, second_selector) = compute_op_inputs( - false - )?; - let placeholder_values = [ - first_placeholder_value.unwrap_or(dummy_placeholder.value), - second_placeholder_value.unwrap_or(dummy_placeholder.value) - ]; - let placeholder_ids = [ - first_placeholder_id.unwrap_or(dummy_placeholder.id).to_field(), - second_placeholder_id.unwrap_or(dummy_placeholder.id).to_field(), - ]; - Ok(BasicOperationInputs { - constant_operand, - placeholder_values, - placeholder_ids, - first_input_selector: F::from_canonical_usize(first_selector), - second_input_selector: F::from_canonical_usize(second_selector), - op_selector: op.op.to_field(), - }) - }).collect::>>()?; - // we pad ops_wires up to `MAX_NUM_OPS` with dummy operations; we pad at - // the beginning of the array since the circuits expects to find the operation computing - // the actual result values as the last of the `MAX_NUM_OPS` operations - Ok(repeat( - // dummy operation - BasicOperationInputs { - constant_operand: U256::ZERO, - placeholder_values: [dummy_placeholder.value, dummy_placeholder.value], - placeholder_ids: [ - dummy_placeholder.id.to_field(), - 
dummy_placeholder.id.to_field(), - ], - first_input_selector: F::ZERO, - second_input_selector: F::ZERO, - op_selector: Operation::EqOp.to_field(), - }, - ) - .take(MAX_NUM_OPS - operations.len()) - .chain(ops_wires) - .collect_vec() - .try_into() - .unwrap()) - } -} - -/// Placeholder to be employed in the universal circuit as a dummy placeholder -/// in the circuit -fn dummy_placeholder(placeholders: &Placeholders) -> Placeholder { - Placeholder { - value: placeholders.get(&dummy_placeholder_id()).unwrap(), // cannot fail since default placeholder is always associated to a value - id: dummy_placeholder_id(), + self.hash_gadget_inputs.assign(pw, &wires.hash_wires); + self.value_gadget_inputs.assign(pw, &wires.value_wires); } } @@ -1015,7 +284,7 @@ impl< const MAX_NUM_PREDICATE_OPS: usize, const MAX_NUM_RESULT_OPS: usize, const MAX_NUM_RESULTS: usize, - T: OutputComponent, + T: OutputComponent + Serialize + DeserializeOwned, > CircuitLogicWires for UniversalQueryCircuitWires< MAX_NUM_COLUMNS, @@ -1054,6 +323,66 @@ where } } +#[derive(Debug, Serialize, Deserialize)] +pub struct UniversalQueryCircuitParams< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent + Serialize, +> { + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + pub(crate) data: CircuitData, + wires: UniversalQueryCircuitWires< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, +} + +impl< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent + Serialize + DeserializeOwned, + > + UniversalQueryCircuitParams< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > +where + [(); MAX_NUM_RESULTS - 1]:, + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, +{ + pub(crate) fn build(config: CircuitConfig) -> Self { + let mut builder = CBuilder::new(config); + let wires = UniversalQueryCircuitInputs::build(&mut builder); + let data = builder.build(); + Self { data, wires } + } + + pub(crate) fn generate_proof( + &self, + input: &UniversalQueryCircuitInputs< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + ) -> Result> { + let mut pw = PartialWitness::::new(); + input.assign(&mut pw, &self.wires); + self.data.prove(pw) + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] /// Inputs for the 2 variant of universal query circuit pub enum UniversalCircuitInput< @@ -1098,8 +427,8 @@ where [(); MAX_NUM_RESULTS - 1]:, [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, { - /// Provide input values for universal circuit variant for queries with aggregation operations - pub(crate) fn new_query_with_agg( + /// Provide input values for universal circuit variant for queries without aggregation operations + pub(crate) fn new_query_no_agg( column_cells: &RowCells, predicate_operations: &[BasicOperation], placeholders: &Placeholders, @@ -1107,7 +436,7 @@ where query_bounds: &QueryBounds, results: &ResultStructure, ) -> Result { - Ok(UniversalCircuitInput::QueryWithAgg( + Ok(UniversalCircuitInput::QueryNoAgg( UniversalQueryCircuitInputs::new( column_cells, predicate_operations, @@ -1115,28 +444,39 @@ where is_leaf, query_bounds, results, + false, )?, )) } - /// Provide input values for universal circuit variant for queries without aggregation operations - pub(crate) fn new_query_no_agg( - column_cells: 
&RowCells, + + pub(crate) fn ids_for_placeholder_hash( predicate_operations: &[BasicOperation], + results: &ResultStructure, placeholders: &Placeholders, - is_leaf: bool, query_bounds: &QueryBounds, - results: &ResultStructure, - ) -> Result { - Ok(UniversalCircuitInput::QueryNoAgg( - UniversalQueryCircuitInputs::new( - column_cells, - predicate_operations, - placeholders, - is_leaf, - query_bounds, - results, - )?, - )) + ) -> Result<[PlaceholderId; 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]> { + Ok(match results.output_variant { + Output::Aggregation => UniversalQueryHashInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + AggOutputCircuit, + >::ids_for_placeholder_hash( + predicate_operations, results, placeholders, query_bounds + ), + Output::NoAggregation => UniversalQueryHashInputs::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + NoAggOutputCircuit, + >::ids_for_placeholder_hash( + predicate_operations, results, placeholders, query_bounds + ), + }? + .try_into() + .unwrap()) } } @@ -1148,9 +488,9 @@ mod tests { use itertools::Itertools; use mp2_common::{ array::ToField, + default_config, group_hashing::map_to_curve_point, poseidon::empty_poseidon_hash, - proof::ProofWithVK, utils::{FromFields, ToFields, TryIntoBool}, C, D, F, }; @@ -1170,29 +510,27 @@ mod tests { use rand::{thread_rng, Rng}; use crate::query::{ - aggregation::{QueryBoundSource, QueryBounds}, - api::{CircuitInput, Parameters}, computational_hash_ids::{ AggregationOperation, ColumnIDs, HashPermutation, Identifiers, Operation, PlaceholderIdentifier, }, - public_inputs::PublicInputs, + public_inputs::PublicInputsUniversalCircuit, universal_circuit::{ + output_no_aggregation::Circuit as OutputNoAggCircuit, + output_with_aggregation::Circuit as OutputAggCircuit, universal_circuit_inputs::{ BasicOperation, ColumnCell, InputOperand, OutputItem, PlaceholderId, Placeholders, ResultStructure, RowCells, }, - universal_query_circuit::placeholder_hash, + universal_query_circuit::{ + placeholder_hash, UniversalCircuitInput, UniversalQueryCircuitParams, + }, ComputationalHash, }, + utils::{QueryBoundSource, QueryBounds}, }; - use anyhow::{Error, Result}; - - use super::{ - OutputComponent, UniversalCircuitInput, UniversalQueryCircuitInputs, - UniversalQueryCircuitWires, - }; + use super::{OutputComponent, UniversalQueryCircuitInputs, UniversalQueryCircuitWires}; impl< const MAX_NUM_COLUMNS: usize, @@ -1229,20 +567,8 @@ mod tests { } } - // utility function to locate operation `op` in the set of `previous_ops` - fn locate_previous_operation( - previous_ops: &[BasicOperation], - op: &BasicOperation, - ) -> Result { - previous_ops - .iter() - .find_position(|current_op| *current_op == op) - .map(|(pos, _)| pos) - .ok_or(Error::msg("operation {} not found in set of previous ops")) - } - // test the following query: - // SELECT AVG(C1+C2/(C2*C3)), SUM(C1+C2), MIN(C1+$1), MAX(C4-2), AVG(C5) FROM T WHERE (C5 > 5 AND C1*C3 <= C4+C5 OR C3 == $2) AND C2 >= 75 AND C2 < $3 + // SELECT AVG(C1+C2/(C2*C3)), SUM(C1+C2), MIN(C1+$1), MAX(C4-2), AVG(C5) FROM T WHERE (C5 > 5 AND C1*C3 <= C4+C5 OR C3 == $2) AND C2 >= 75 AND C2 < $3 AND C1 >= 42 AND C1 < 56 async fn query_with_aggregation(build_parameters: bool) { init_logging(); const NUM_ACTUAL_COLUMNS: usize = 5; @@ -1251,17 +577,32 @@ mod tests { const MAX_NUM_RESULT_OPS: usize = 30; const MAX_NUM_RESULTS: usize = 10; let rng = &mut thread_rng(); - let min_query = U256::from(75); - let max_query = 
U256::from(98); + let min_query_primary = U256::from(42); + let max_query_primary = U256::from(55); + let min_query_secondary = U256::from(75); + let max_query_secondary = U256::from(98); let column_values = (0..NUM_ACTUAL_COLUMNS) .map(|i| { - if i == 1 { - // ensure that second column value is in the range specified by the query: - // we sample a random u256 in range [0, max_query - min_query) and then we - // add min_query - gen_random_u256(rng).div_rem(max_query - min_query).1 + min_query - } else { - gen_random_u256(rng) + match i { + 0 => { + // ensure that primary index column value is in the range specified by the query: + // we sample a random u256 in range [0, max_query - min_query) and then we + // add min_query + gen_random_u256(rng) + .div_rem(max_query_primary - min_query_primary + U256::from(1)) + .1 + + min_query_primary + } + 1 => { + // ensure that second column value is in the range specified by the query: + // we sample a random u256 in range [0, max_query - min_query) and then we + // add min_query + gen_random_u256(rng) + .div_rem(max_query_secondary - min_query_secondary + U256::from(1)) + .1 + + min_query_secondary + } + _ => gen_random_u256(rng), } }) .collect_vec(); @@ -1279,16 +620,13 @@ mod tests { // define placeholders let first_placeholder_id = PlaceholderId::Generic(0); let second_placeholder_id = PlaceholderIdentifier::Generic(1); - let mut placeholders = Placeholders::new_empty( - U256::default(), - U256::default(), // dummy values - ); + let mut placeholders = Placeholders::new_empty(min_query_primary, max_query_primary); [first_placeholder_id, second_placeholder_id] .iter() .for_each(|id| placeholders.insert(*id, gen_random_u256(rng))); // 3-rd placeholder is the max query bound let third_placeholder_id = PlaceholderId::Generic(2); - placeholders.insert(third_placeholder_id, max_query); + placeholders.insert(third_placeholder_id, max_query_secondary); // build predicate operations let mut predicate_operations = vec![]; @@ -1316,10 +654,12 @@ mod tests { // C1*C3 <= C4 + C5 let expr_comparison = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &column_prod).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &column_prod) + .unwrap(), ), second_operand: Some(InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &column_add).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &column_add) + .unwrap(), )), op: Operation::LessThanOrEqOp, }; @@ -1334,10 +674,12 @@ mod tests { // c5_comparison AND expr_comparison let and_comparisons = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &c5_comparison).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &c5_comparison) + .unwrap(), ), second_operand: Some(InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &expr_comparison).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &expr_comparison) + .unwrap(), )), op: Operation::AndOp, }; @@ -1345,10 +687,12 @@ mod tests { // final filtering predicate: and_comparisons OR placeholder_eq let predicate = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &and_comparisons).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &and_comparisons) + .unwrap(), ), second_operand: Some(InputOperand::PreviousValue( - 
locate_previous_operation(&predicate_operations, &placeholder_eq).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &placeholder_eq) + .unwrap(), )), op: Operation::OrOp, }; @@ -1372,10 +716,11 @@ mod tests { // C1 + C2/(C2*C3) let div = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&result_operations, &column_add).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_add).unwrap(), ), second_operand: Some(InputOperand::PreviousValue( - locate_previous_operation(&result_operations, &column_prod).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_prod) + .unwrap(), )), op: Operation::DivOp, }; @@ -1399,15 +744,19 @@ mod tests { // output items are all computed values in this query, expect for the last item // which is a column let output_items = vec![ - OutputItem::ComputedValue(locate_previous_operation(&result_operations, &div).unwrap()), OutputItem::ComputedValue( - locate_previous_operation(&result_operations, &column_add).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &div).unwrap(), + ), + OutputItem::ComputedValue( + BasicOperation::locate_previous_operation(&result_operations, &column_add).unwrap(), ), OutputItem::ComputedValue( - locate_previous_operation(&result_operations, &column_placeholder).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_placeholder) + .unwrap(), ), OutputItem::ComputedValue( - locate_previous_operation(&result_operations, &column_sub_const).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_sub_const) + .unwrap(), ), OutputItem::Column(4), ]; @@ -1431,7 +780,7 @@ mod tests { let query_bounds = QueryBounds::new( &placeholders, - Some(QueryBoundSource::Constant(min_query)), + Some(QueryBoundSource::Constant(min_query_secondary)), Some( QueryBoundSource::Operation(BasicOperation { first_operand: InputOperand::Placeholder(third_placeholder_id), @@ -1443,21 +792,21 @@ mod tests { ), ) .unwrap(); - let min_query_value = query_bounds.min_query_secondary().value; - let max_query_value = query_bounds.max_query_secondary().value; - let input = CircuitInput::< + let circuit = UniversalQueryCircuitInputs::< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, - >::new_universal_circuit( + OutputAggCircuit, + >::new( &row_cells, &predicate_operations, - &results, &placeholders, is_leaf, &query_bounds, + &results, + false, ) .unwrap(); @@ -1519,15 +868,18 @@ mod tests { }) .collect_vec(); - let circuit = if let CircuitInput::UniversalCircuit(UniversalCircuitInput::QueryWithAgg( - c, - )) = &input - { - c - } else { - unreachable!() - }; - let placeholder_hash_ids = circuit.ids_for_placeholder_hash(); + let placeholder_hash_ids = UniversalCircuitInput::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >::ids_for_placeholder_hash( + &predicate_operations, + &results, + &placeholders, + &query_bounds, + ) + .unwrap(); let placeholder_hash = placeholder_hash(&placeholder_hash_ids, &placeholders, &query_bounds).unwrap(); let computational_hash = ComputationalHash::from_bytes( @@ -1549,17 +901,14 @@ mod tests { .into(), ); let proof = if build_parameters { - let params = Parameters::build(); - params - .generate_proof(input) - .and_then(|p| ProofWithVK::deserialize(&p)) - .map(|p| p.proof().clone()) - .unwrap() + let params = UniversalQueryCircuitParams::build(default_config()); + 
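// [editor's note, not part of the patch] Building UniversalQueryCircuitParams
// compiles the circuit once (the expensive step) and keeps both the
// CircuitData and the allocated wires, so generate_proof below only assigns a
// witness and proves; the run_circuit branch instead goes through circuit
// construction on every call, which is only convenient for tests.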
params.generate_proof(&circuit).unwrap() } else { run_circuit::(circuit.clone()) }; - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); + let pi = + PublicInputsUniversalCircuit::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); assert_eq!(tree_hash, pi.tree_hash()); assert_eq!(output_values[0], pi.first_value_as_u256()); assert_eq!(output_values[1..], pi.values()[..output_values.len() - 1]); @@ -1568,12 +917,10 @@ mod tests { predicate_value, pi.num_matching_rows().try_into_bool().unwrap() ); - assert_eq!(column_values[0], pi.index_value()); - assert_eq!(column_values[1], pi.min_value()); - assert_eq!(column_values[1], pi.max_value()); - assert_eq!([column_ids[0], column_ids[1]], pi.index_ids()); - assert_eq!(min_query_value, pi.min_query_value()); - assert_eq!(max_query_value, pi.max_query_value()); + assert_eq!(min_query_primary, pi.min_primary()); + assert_eq!(max_query_primary, pi.max_primary()); + assert_eq!(column_cells[1].value, pi.secondary_index_value()); + assert_eq!(column_cells[0].value, pi.primary_index_value()); assert_eq!(placeholder_hash, pi.placeholder_hash()); assert_eq!(computational_hash, pi.computational_hash()); assert_eq!(predicate_err || result_err, pi.overflow_flag()); @@ -1590,7 +937,7 @@ mod tests { } // test the following query: - // SELECT C1 < C2/45, C3*C4, C7, (C5-C6)%C1, C3*C4 - $1 FROM T WHERE ((NOT C5 != 42) OR C1*C7 <= C4/C6+C5 XOR C3 < $2) AND C2 >= $3 AND C2 < 44 + // SELECT C1 < C2/45, C3*C4, C7, (C5-C6)%C1, C3*C4 - $1 FROM T WHERE ((NOT C5 != 42) OR C1*C7 <= C4/C6+C5 XOR C3 < $2) AND C2 >= $3 AND C2 < 44 AND C1 > 13 AND C1 <= 17 async fn query_without_aggregation(single_result: bool, build_parameters: bool) { init_logging(); const NUM_ACTUAL_COLUMNS: usize = 7; @@ -1599,20 +946,32 @@ mod tests { const MAX_NUM_RESULT_OPS: usize = 30; const MAX_NUM_RESULTS: usize = 10; let rng = &mut thread_rng(); - let min_query = U256::from(43); - let max_query = U256::from(43); + let min_query_primary = U256::from(14); + let max_query_primary = U256::from(17); + let min_query_secondary = U256::from(43); + let max_query_secondary = U256::from(43); let column_values = (0..NUM_ACTUAL_COLUMNS) .map(|i| { - if i == 1 { - // ensure that second column value is in the range specified by the query: - // we sample a random u256 in range [0, max_query - min_query + 1) and then we - // add min_query - gen_random_u256(rng) - .div_rem(max_query - min_query + U256::from(1)) - .1 - + min_query - } else { - gen_random_u256(rng) + match i { + 0 => { + // ensure that primary index column value is in the range specified by the query: + // we sample a random u256 in range [0, max_query - min_query) and then we + // add min_query + gen_random_u256(rng) + .div_rem(max_query_primary - min_query_primary + U256::from(1)) + .1 + + min_query_primary + } + 1 => { + // ensure that second column value is in the range specified by the query: + // we sample a random u256 in range [0, max_query - min_query) and then we + // add min_query + gen_random_u256(rng) + .div_rem(max_query_secondary - min_query_secondary + U256::from(1)) + .1 + + min_query_secondary + } + _ => gen_random_u256(rng), } }) .collect_vec(); @@ -1630,16 +989,13 @@ mod tests { // define placeholders let first_placeholder_id = PlaceholderId::Generic(0); let second_placeholder_id = PlaceholderIdentifier::Generic(1); - let mut placeholders = Placeholders::new_empty( - U256::default(), - U256::default(), // dummy values - ); + let mut placeholders = Placeholders::new_empty(min_query_primary, 
max_query_primary); [first_placeholder_id, second_placeholder_id] .iter() .for_each(|id| placeholders.insert(*id, gen_random_u256(rng))); // 3-rd placeholder is the min query bound let third_placeholder_id = PlaceholderId::Generic(2); - placeholders.insert(third_placeholder_id, min_query); + placeholders.insert(third_placeholder_id, min_query_secondary); // build predicate operations let mut predicate_operations = vec![]; @@ -1667,7 +1023,8 @@ mod tests { // C4/C6 + C5 let expr_add = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &column_div).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &column_div) + .unwrap(), ), second_operand: Some(InputOperand::Column(4)), op: Operation::AddOp, @@ -1676,10 +1033,12 @@ mod tests { // C1*C7 <= C4/C6 + C5 let expr_comparison = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &column_prod).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &column_prod) + .unwrap(), ), second_operand: Some(InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &expr_add).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &expr_add) + .unwrap(), )), op: Operation::LessThanOrEqOp, }; @@ -1694,7 +1053,8 @@ mod tests { predicate_operations.push(placeholder_cmp); let not_c5 = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &c5_comparison).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &c5_comparison) + .unwrap(), ), second_operand: None, op: Operation::NotOp, @@ -1703,10 +1063,11 @@ mod tests { // NOT c5_comparison OR expr_comparison let or_comparisons = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, ¬_c5).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, ¬_c5).unwrap(), ), second_operand: Some(InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &expr_comparison).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &expr_comparison) + .unwrap(), )), op: Operation::OrOp, }; @@ -1714,10 +1075,12 @@ mod tests { // final filtering predicate: or_comparisons XOR placeholder_cmp let predicate = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &or_comparisons).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &or_comparisons) + .unwrap(), ), second_operand: Some(InputOperand::PreviousValue( - locate_previous_operation(&predicate_operations, &placeholder_cmp).unwrap(), + BasicOperation::locate_previous_operation(&predicate_operations, &placeholder_cmp) + .unwrap(), )), op: Operation::XorOp, }; @@ -1735,7 +1098,7 @@ mod tests { let column_cmp = BasicOperation { first_operand: InputOperand::Column(0), second_operand: Some(InputOperand::PreviousValue( - locate_previous_operation(&result_operations, &div_const).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &div_const).unwrap(), )), op: Operation::LessThanOp, }; @@ -1757,7 +1120,7 @@ mod tests { // (C5 - C6) % C1 let column_mod = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&result_operations, &column_sub).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_sub).unwrap(), ), 
second_operand: Some(InputOperand::Column(0)), op: Operation::AddOp, @@ -1766,7 +1129,8 @@ mod tests { // C3*C4 - $1 let sub_placeholder = BasicOperation { first_operand: InputOperand::PreviousValue( - locate_previous_operation(&result_operations, &column_prod).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_prod) + .unwrap(), ), second_operand: Some(InputOperand::Placeholder(first_placeholder_id)), op: Operation::SubOp, @@ -1778,22 +1142,26 @@ mod tests { // which is a column let output_items = if single_result { vec![OutputItem::ComputedValue( - locate_previous_operation(&result_operations, &column_cmp).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_cmp).unwrap(), )] } else { vec![ OutputItem::ComputedValue( - locate_previous_operation(&result_operations, &column_cmp).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_cmp) + .unwrap(), ), OutputItem::ComputedValue( - locate_previous_operation(&result_operations, &column_prod).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_prod) + .unwrap(), ), OutputItem::Column(6), OutputItem::ComputedValue( - locate_previous_operation(&result_operations, &column_mod).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &column_mod) + .unwrap(), ), OutputItem::ComputedValue( - locate_previous_operation(&result_operations, &sub_placeholder).unwrap(), + BasicOperation::locate_previous_operation(&result_operations, &sub_placeholder) + .unwrap(), ), ] }; @@ -1811,21 +1179,23 @@ mod tests { let query_bounds = QueryBounds::new( &placeholders, Some(QueryBoundSource::Placeholder(third_placeholder_id)), - Some(QueryBoundSource::Constant(max_query)), + Some(QueryBoundSource::Constant(max_query_secondary)), ) .unwrap(); - let input = CircuitInput::< + let circuit = UniversalQueryCircuitInputs::< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_RESULTS, - >::new_universal_circuit( + OutputNoAggCircuit, + >::new( &row_cells, &predicate_operations, - &results, &placeholders, is_leaf, &query_bounds, + &results, + false, ) .unwrap(); @@ -1902,13 +1272,18 @@ mod tests { Point::NEUTRAL }; - let circuit = - if let CircuitInput::UniversalCircuit(UniversalCircuitInput::QueryNoAgg(c)) = &input { - c - } else { - unreachable!() - }; - let placeholder_hash_ids = circuit.ids_for_placeholder_hash(); + let placeholder_hash_ids = UniversalCircuitInput::< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + >::ids_for_placeholder_hash( + &predicate_operations, + &results, + &placeholders, + &query_bounds, + ) + .unwrap(); let placeholder_hash = placeholder_hash(&placeholder_hash_ids, &placeholders, &query_bounds).unwrap(); let computational_hash = ComputationalHash::from_bytes( @@ -1931,17 +1306,14 @@ mod tests { ); let proof = if build_parameters { - let params = Parameters::build(); - params - .generate_proof(input) - .and_then(|p| ProofWithVK::deserialize(&p)) - .map(|p| p.proof().clone()) - .unwrap() + let params = UniversalQueryCircuitParams::build(default_config()); + params.generate_proof(&circuit).unwrap() } else { run_circuit::(circuit.clone()) }; - let pi = PublicInputs::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); + let pi = + PublicInputsUniversalCircuit::<_, MAX_NUM_RESULTS>::from_slice(&proof.public_inputs); assert_eq!(tree_hash, pi.tree_hash()); assert_eq!(output_acc.to_weierstrass(), pi.first_value_as_curve_point()); // The other MAX_NUM_RESULTS -1 
output values are dummy ones, as in queries @@ -1964,12 +1336,10 @@ mod tests { predicate_value, pi.num_matching_rows().try_into_bool().unwrap() ); - assert_eq!(column_values[0], pi.index_value()); - assert_eq!(column_values[1], pi.min_value()); - assert_eq!(column_values[1], pi.max_value()); - assert_eq!([column_ids[0], column_ids[1]], pi.index_ids()); - assert_eq!(min_query, pi.min_query_value()); - assert_eq!(max_query, pi.max_query_value()); + assert_eq!(min_query_primary, pi.min_primary()); + assert_eq!(max_query_primary, pi.max_primary()); + assert_eq!(column_cells[1].value, pi.secondary_index_value()); + assert_eq!(column_cells[0].value, pi.primary_index_value()); assert_eq!(placeholder_hash, pi.placeholder_hash()); assert_eq!(computational_hash, pi.computational_hash()); assert_eq!(predicate_err || result_err, pi.overflow_flag()); diff --git a/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs new file mode 100644 index 000000000..268c2adbe --- /dev/null +++ b/verifiable-db/src/query/universal_circuit/universal_query_gadget.rs @@ -0,0 +1,1469 @@ +use std::{ + fmt::Debug, + iter::{once, repeat}, +}; + +use alloy::primitives::U256; +use anyhow::{bail, ensure, Result}; +use itertools::Itertools; +use mp2_common::{ + array::ToField, + poseidon::{empty_poseidon_hash, H}, + serialization::{deserialize, deserialize_long_array, serialize, serialize_long_array}, + types::{CBuilder, CURVE_TARGET_LEN}, + u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256, NUM_LIMBS}, + utils::{FromFields, FromTargets, ToFields, ToTargets}, + CHasher, F, +}; +use plonky2::{ + field::types::Field, + hash::{hash_types::NUM_HASH_OUT_ELTS, hashing::hash_n_to_hash_no_pad}, + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::config::{GenericHashOut, Hasher}, +}; +use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; +use serde::{Deserialize, Serialize}; + +use crate::query::{ + computational_hash_ids::{ + ColumnIDs, ComputationalHashCache, HashPermutation, Operation, Output, + PlaceholderIdentifier, + }, + universal_circuit::{ + basic_operation::BasicOperationInputs, column_extraction::ColumnExtractionValueWires, + universal_circuit_inputs::OutputItem, + }, + utils::{QueryBoundSecondary, QueryBoundSource, QueryBounds}, +}; + +use super::{ + basic_operation::{ + BasicOperationHashWires, BasicOperationInputWires, BasicOperationValueWires, + BasicOperationWires, + }, + column_extraction::{ColumnExtractionInputWires, ColumnExtractionInputs}, + universal_circuit_inputs::{ + BasicOperation, InputOperand, Placeholder, PlaceholderId, Placeholders, ResultStructure, + RowCells, + }, + universal_query_circuit::dummy_placeholder_id, + ComputationalHash, ComputationalHashTarget, MembershipHashTarget, PlaceholderHash, + PlaceholderHashTarget, +}; + +/// Wires representing a query bound in the universal circuit +pub(crate) type QueryBoundTarget = BasicOperationWires; + +/// Input wires for `QueryBoundTarget` (i.e., the wires that need to be assigned) +pub(crate) type QueryBoundTargetInputs = BasicOperationInputWires; + +impl From for QueryBoundTargetInputs { + fn from(value: QueryBoundTarget) -> Self { + value.hash_wires.input_wires + } +} + +impl QueryBoundTarget { + pub(crate) fn new(b: &mut CBuilder) -> Self { + let zero_u256 = b.zero_u256(); + let zero = b.zero(); + let empty_hash = b.constant_hash(*empty_poseidon_hash()); + // The 0 constant provided as input 
value is used as a dummy operand in case the query bound + // is taken from a constant in the query: in this case, the query bound in the circuit is + // computed with the operation `InputOperand::Constant(query_bound) + input_values[0]`, which + // yields `query_bound` as output since `input_values[0] = 0`. The constant input value 0 is + // associated with the empty hash in the computational hash, which is provided as `input_hash[0]` + BasicOperationInputs::build(b, &[zero_u256], &[empty_hash], zero) + } + + /// Get the actual value of this query bound computed in the circuit + pub(crate) fn get_bound_value(&self) -> &UInt256Target { + &self.value_wires.output_value + } + + // Compute the number of overflows that occurred during the operations computing the query bounds + pub(crate) fn num_overflows_for_query_bound_operations( + b: &mut CBuilder, + min_query: &Self, + max_query: &Self, + ) -> Target { + b.add( + min_query.value_wires.num_overflows, + max_query.value_wires.num_overflows, + ) + } + + pub(crate) fn add_query_bounds_to_placeholder_hash( + b: &mut CBuilder, + min_query_bound: &Self, + max_query_bound: &Self, + placeholder_hash: &PlaceholderHashTarget, + ) -> PlaceholderHashTarget { + b.hash_n_to_hash_no_pad::( + placeholder_hash + .elements + .iter() + .chain(once( + &min_query_bound.hash_wires.input_wires.placeholder_ids[0], + )) + .chain(&min_query_bound.hash_wires.input_wires.placeholder_values[0].to_targets()) + .chain(once( + &min_query_bound.hash_wires.input_wires.placeholder_ids[1], + )) + .chain(&min_query_bound.hash_wires.input_wires.placeholder_values[1].to_targets()) + .chain(once( + &max_query_bound.hash_wires.input_wires.placeholder_ids[0], + )) + .chain(&max_query_bound.hash_wires.input_wires.placeholder_values[0].to_targets()) + .chain(once( + &max_query_bound.hash_wires.input_wires.placeholder_ids[1], + )) + .chain(&max_query_bound.hash_wires.input_wires.placeholder_values[1].to_targets()) + .cloned() + .collect(), + ) + } + + pub(crate) fn add_query_bounds_to_computational_hash( + b: &mut CBuilder, + min_query_bound: &Self, + max_query_bound: &Self, + computational_hash: &ComputationalHashTarget, + ) -> ComputationalHashTarget { + b.hash_n_to_hash_no_pad::( + computational_hash + .to_targets() + .into_iter() + .chain(min_query_bound.hash_wires.output_hash.to_targets()) + .chain(max_query_bound.hash_wires.output_hash.to_targets()) + .collect_vec(), + ) + } +} + +impl QueryBoundTargetInputs { + pub(crate) fn assign(&self, pw: &mut PartialWitness, bound: &QueryBound) { + bound.operation.assign(pw, self); + } +} +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct QueryBound { + pub(crate) operation: BasicOperationInputs, +} + +impl QueryBound { + /// Number of input values provided to the basic operation component computing the query bounds + /// in the circuit; currently it is 1 since the constant input value 0 is provided as a dummy + /// input value (see QueryBoundTarget::new()). + const NUM_INPUT_VALUES: usize = 1; + + /// Initialize a query bound for the primary index, from the set of `placeholders` employed in the query, + /// which also includes the primary index bounds by construction.
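The bound value is read from the `MinQueryOnIdx1` or `MaxQueryOnIdx1` placeholder accordingly.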
The flag `is_min_bound` + /// must be true iff the bound to be initialized is a lower bound in the range specified in the query + #[allow(dead_code)] // unused for now, but it could be useful to keep it + pub(crate) fn new_primary_index_bound( + placeholders: &Placeholders, + is_min_bound: bool, + ) -> Result { + let source = QueryBoundSource::Placeholder(if is_min_bound { + PlaceholderIdentifier::MinQueryOnIdx1 + } else { + PlaceholderIdentifier::MaxQueryOnIdx1 + }); + Self::new_bound(placeholders, &source) + } + + /// Initialize a query bound for the secondary index, from the set of placeholders employed in the query + /// and from the provided `bound`, which specifies how the query bound should be computed in the circuit + pub(crate) fn new_secondary_index_bound( + placeholders: &Placeholders, + bound: &QueryBoundSecondary, + ) -> Result { + let source = bound.into(); + Self::new_bound(placeholders, &source) + } + + /// Internal function employed to instantiate a new query bound + fn new_bound(placeholders: &Placeholders, source: &QueryBoundSource) -> Result { + let dummy_placeholder = dummy_placeholder(placeholders); + let op_inputs = match source { + QueryBoundSource::Constant(value) => + // if the query bound is computed from a constant `value`, we instantiate the operation + // `value + input_values[0]` in the circuit, as in `QueryBoundTarget` construction we + // always set `input_values[0] = 0`. This trick allows to get the same constant `value` + // as output of the basic operation employed in the circuit to compute the query bound + { + BasicOperationInputs { + constant_operand: *value, + placeholder_values: [dummy_placeholder.value, dummy_placeholder.value], + placeholder_ids: [ + dummy_placeholder.id.to_field(), + dummy_placeholder.id.to_field(), + ], + first_input_selector: BasicOperationInputs::constant_operand_offset( + Self::NUM_INPUT_VALUES, + ) + .to_field(), + second_input_selector: BasicOperationInputs::input_value_offset(0).to_field(), + op_selector: Operation::AddOp.to_field(), + } + } + QueryBoundSource::Placeholder(id) => + // if the query bound is computed from a placeholder with id `id`, we instantiate + // the operation `$id + 0` in the circuit, which will yield the value of placeholder + // $id (which should correspond to the query bound) as output + { + BasicOperationInputs { + constant_operand: U256::ZERO, + placeholder_values: [placeholders.get(id)?, dummy_placeholder.value], + placeholder_ids: [id.to_field(), dummy_placeholder.id.to_field()], + first_input_selector: BasicOperationInputs::first_placeholder_offset( + Self::NUM_INPUT_VALUES, + ) + .to_field(), + second_input_selector: BasicOperationInputs::constant_operand_offset( + Self::NUM_INPUT_VALUES, + ) + .to_field(), + op_selector: Operation::AddOp.to_field(), + } + } + QueryBoundSource::Operation(op) => { + // In this case we instantiate the basic operation `op`, checking that the operation + // satisfies the requirements for query bound operations (i.e., it involves only + // constant values and placeholders) + let mut constant_operand = U256::ZERO; + let mut process_input_op = |operand: &InputOperand| { + Ok(match operand { + InputOperand::Placeholder(id) => + ( + *id, + None, + ), + InputOperand::Constant(value) => { + constant_operand = *value; + ( + dummy_placeholder.id, + Some(BasicOperationInputs::constant_operand_offset(Self::NUM_INPUT_VALUES)) + ) + }, + _ => bail!("Invalid operand for query bound operation: must be either a placeholder or a constant"), + }) + }; + + let 
(first_placeholder_id, first_selector) = process_input_op(&op.first_operand)?; + let (second_placeholder_id, second_selector) = process_input_op( + &op.second_operand.unwrap_or_default(), // Unary operation, so use a dummy operand + )?; + BasicOperationInputs { + constant_operand, + placeholder_values: [ + placeholders.get(&first_placeholder_id)?, + placeholders.get(&second_placeholder_id)?, + ], + placeholder_ids: [ + first_placeholder_id.to_field(), + second_placeholder_id.to_field(), + ], + first_input_selector: first_selector + .unwrap_or(BasicOperationInputs::first_placeholder_offset( + Self::NUM_INPUT_VALUES, + )) + .to_field(), + second_input_selector: second_selector + .unwrap_or(BasicOperationInputs::second_placeholder_offset( + Self::NUM_INPUT_VALUES, + )) + .to_field(), + op_selector: op.op.to_field(), + } + } + }; + Ok(Self { + operation: op_inputs, + }) + } + + /// This method computes the value of a query bound + pub(crate) fn compute_bound_value( + placeholders: &Placeholders, + source: &QueryBoundSource, + ) -> Result<(U256, bool)> { + Ok(match source { + QueryBoundSource::Constant(value) => (*value, false), + QueryBoundSource::Placeholder(id) => (placeholders.get(id)?, false), + QueryBoundSource::Operation(op) => { + let (values, overflow) = + BasicOperation::compute_operations(&[*op], &[], placeholders)?; + (values[0], overflow) + } + }) + } + + /// This method returns the basic operation employed in the circuit for the query bound which is + /// taken from the query, as specified by the input `source`. It returns the same operation + /// that is instantiated in the circuit by the `new_bound` internal method + pub(crate) fn get_basic_operation(source: &QueryBoundSource) -> Result { + Ok(match source { + QueryBoundSource::Constant(value) => + // convert to operation `value + input_value[0]`, which yields `value` since `input_value[0] = 0` in the circuit + { + BasicOperation { + first_operand: InputOperand::Constant(*value), + second_operand: Some(InputOperand::Column(0)), + op: Operation::AddOp, + } + } + QueryBoundSource::Placeholder(id) => + // convert to operation $id + 0 + { + BasicOperation { + first_operand: InputOperand::Placeholder(*id), + second_operand: Some(InputOperand::Constant(U256::ZERO)), + op: Operation::AddOp, + } + } + QueryBoundSource::Operation(op) => { + // validate operation for query bound + match op.first_operand { + InputOperand::Constant(_) | InputOperand::Placeholder(_) => (), + _ => bail!("Invalid operand for query bound operation: must be either a placeholder or a constant") + } + if let Some(operand) = op.second_operand { + match operand { + InputOperand::Constant(_) | InputOperand::Placeholder(_) => (), + _ => bail!("Invalid operand for query bound operation: must be either a placeholder or a constant") + } + } + *op + } + }) + } + + pub(crate) fn add_secondary_query_bounds_to_placeholder_hash( + min_query: &Self, + max_query: &Self, + placeholder_hash: &PlaceholderHash, + ) -> PlaceholderHash { + hash_n_to_hash_no_pad::<_, HashPermutation>( + &placeholder_hash + .to_vec() + .into_iter() + .chain(once(min_query.operation.placeholder_ids[0])) + .chain(min_query.operation.placeholder_values[0].to_fields()) + .chain(once(min_query.operation.placeholder_ids[1])) + .chain(min_query.operation.placeholder_values[1].to_fields()) + .chain(once(max_query.operation.placeholder_ids[0])) + .chain(max_query.operation.placeholder_values[0].to_fields()) + .chain(once(max_query.operation.placeholder_ids[1])) +
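// note: this id/value interleaving matches the in-circuit order in `QueryBoundTarget::add_query_bounds_to_placeholder_hash` +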
.chain(max_query.operation.placeholder_values[1].to_fields()) + .collect_vec(), + ) + } + + pub(crate) fn add_secondary_query_bounds_to_computational_hash( + min_query: &QueryBoundSource, + max_query: &QueryBoundSource, + computational_hash: &ComputationalHash, + ) -> Result { + let min_query_op = Self::get_basic_operation(min_query)?; + let max_query_op = Self::get_basic_operation(max_query)?; + // initialize computational hash cache with the empty hash associated to the only input value (hardcoded to 0 + // in the circuit) of the basic operation components employed for query bounds + let mut cache = ComputationalHashCache::new_from_column_hash( + Self::NUM_INPUT_VALUES, + &[*empty_poseidon_hash()], + )?; + let min_query_hash = Operation::basic_operation_hash(&mut cache, &[], &min_query_op)?; + let max_query_hash = Operation::basic_operation_hash(&mut cache, &[], &max_query_op)?; + let inputs = computational_hash + .to_vec() + .into_iter() + .chain(min_query_hash.to_fields()) + .chain(max_query_hash.to_fields()) + .collect_vec(); + Ok(H::hash_no_pad(&inputs)) + } +} + +/// Trait for the 2 different variants of output components we currently support +/// in query circuits +pub trait OutputComponent: Clone { + type ValueWires: OutputComponentValueWires; + type HashWires: OutputComponentHashWires; + + fn new(selector: &[F], ids: &[F], num_outputs: usize) -> Result; + + #[cfg(test)] // used only in test for now + fn build( + b: &mut CBuilder, + possible_output_values: [UInt256Target; NUM_OUTPUT_VALUES], + possible_output_hash: [ComputationalHashTarget; NUM_OUTPUT_VALUES], + predicate_value: &BoolTarget, + predicate_hash: &ComputationalHashTarget, + ) -> OutputComponentWires { + let hash_wires: >::HashWires = + Self::build_hash(b, possible_output_hash, predicate_hash); + let value_wires = Self::build_values( + b, + possible_output_values, + predicate_value, + &hash_wires.input_wires(), + ); + + OutputComponentWires { + value_wires, + hash_wires, + } + } + + fn build_values( + b: &mut CBuilder, + possible_output_values: [UInt256Target; NUM_OUTPUT_VALUES], + predicate_value: &BoolTarget, + input_wires: &::InputWires, + ) -> Self::ValueWires; + + fn build_hash( + b: &mut CBuilder, + possible_output_hash: [ComputationalHashTarget; NUM_OUTPUT_VALUES], + predicate_hash: &ComputationalHashTarget, + ) -> Self::HashWires; + + fn assign( + &self, + pw: &mut PartialWitness, + wires: &::InputWires, + ); + + /// Return the type of output component, specified as an instance of `Output` enum + fn output_variant() -> Output; +} +/// Trait representing the wires related to the output values computed +/// by an output component implementation +pub trait OutputComponentValueWires: Clone + Debug { + /// Associated type specifying the type of the first output value computed by this output + /// component; this type varies depending on the particular component: + /// - It is a `CurveTarget` in the output component for queries without aggregation operations + /// - It is a `UInt256Target` in the output for queries with aggregation operations + type FirstT: ToTargets; + + /// Get the first output value returned by the output component; this is accessed by an ad-hoc + /// method since such output value could be a `UInt256Target` or a `CurveTarget`, depending + /// on the output component instance + fn first_output_value(&self) -> Self::FirstT; + /// Get the subsequent output values returned by the output component + fn other_output_values(&self) -> &[UInt256Target]; +} + +/// Trait representing the input/output wires 
related to the computational hash +/// computed by an output component implementation +pub trait OutputComponentHashWires: Clone + Debug + Eq + PartialEq { + /// Input wires of the output component + type InputWires: Serialize + for<'a> Deserialize<'a> + Clone + Debug + Eq + PartialEq; + + /// Get the identifiers of the aggregation operations specified in the query to aggregate the + /// results (e.g., `SUM`, `AVG`) + fn ops_ids(&self) -> &[Target]; + /// Get the computational hash returned by the output component + fn computational_hash(&self) -> ComputationalHashTarget; + /// Get the input wires for the output component + fn input_wires(&self) -> Self::InputWires; +} + +/// Wires representing an output component +#[cfg(test)] // used only in test for now +pub struct OutputComponentWires< + ValueWires: OutputComponentValueWires, + HashWires: OutputComponentHashWires, +> { + pub(crate) value_wires: ValueWires, + pub(crate) hash_wires: HashWires, +} +/// Wires for the universal query hash gadget +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct UniversalQueryHashInputWires< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent, +> { + /// Input wires for column extraction component + pub(crate) column_extraction_wires: ColumnExtractionInputWires, + /// Lower bound of the range for the primary index specified in the query + pub(crate) min_query_primary: UInt256Target, + /// Upper bound of the range for the primary index specified in the query + pub(crate) max_query_primary: UInt256Target, + /// Lower bound of the range for the secondary index specified in the query + min_query_secondary: QueryBoundTargetInputs, + /// Upper bound of the range for the secondary index specified in the query + max_query_secondary: QueryBoundTargetInputs, + /// Input wires for the `MAX_NUM_PREDICATE_OPS` basic operation components necessary + /// to evaluate the filtering predicate + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + filtering_predicate_ops: [BasicOperationInputWires; MAX_NUM_PREDICATE_OPS], + /// Input wires for the `MAX_NUM_RESULT_OPS` basic operation components necessary + /// to compute the results for the current row + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + result_value_ops: [BasicOperationInputWires; MAX_NUM_RESULT_OPS], + /// Input wires for the `MAX_NUM_RESULTS` output components that computes the + /// output values for the current row + output_component_wires: ::InputWires, +} + +#[derive(Clone, Debug)] +pub(crate) struct UniversalQueryHashWires< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent, +> { + pub(crate) input_wires: UniversalQueryHashInputWires< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + pub(crate) computational_hash: ComputationalHashTarget, + pub(crate) placeholder_hash: PlaceholderHashTarget, + pub(crate) min_secondary: UInt256Target, + pub(crate) max_secondary: UInt256Target, + pub(crate) num_bound_overflows: Target, + pub(crate) agg_ops_ids: [Target; MAX_NUM_RESULTS], +} +/// Input values for the universal query hash gadget +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct UniversalQueryHashInputs< + const MAX_NUM_COLUMNS: usize, + const 
MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent, +> { + column_extraction_inputs: ColumnExtractionInputs, + min_query_primary: U256, + max_query_primary: U256, + min_query_secondary: QueryBound, + max_query_secondary: QueryBound, + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + filtering_predicate_inputs: [BasicOperationInputs; MAX_NUM_PREDICATE_OPS], + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + result_values_inputs: [BasicOperationInputs; MAX_NUM_RESULT_OPS], + output_component_inputs: T, +} + +impl< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + T: OutputComponent, + > + UniversalQueryHashInputs< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > +where + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, +{ + /// Instantiate `Self` from the necessary inputs. Note that the following assumption is expected on the + /// structure of the inputs: + /// the output of the last operation in `predicate_operations` is taken as the filtering predicate evaluation; + /// this assumption is exploited in the circuit for efficiency, and it is simple for + /// the caller of this method to satisfy + pub(crate) fn new( + column_ids: &ColumnIDs, + predicate_operations: &[BasicOperation], + placeholders: &Placeholders, + query_bounds: &QueryBounds, + results: &ResultStructure, + ) -> Result { + let num_columns = column_ids.num_columns(); + ensure!( + num_columns <= MAX_NUM_COLUMNS, + "number of columns is higher than the maximum value allowed" + ); + let padded_column_ids = column_ids + .to_vec() + .into_iter() + .chain(repeat(F::NEG_ONE)) + .take(MAX_NUM_COLUMNS) + .collect_vec(); + let column_extraction_inputs = ColumnExtractionInputs:: { + real_num_columns: num_columns, + column_ids: padded_column_ids.try_into().unwrap(), + }; + + let num_predicate_ops = predicate_operations.len(); + ensure!(num_predicate_ops <= MAX_NUM_PREDICATE_OPS, + "Number of operations to compute filtering predicate is higher than the maximum number allowed"); + let num_result_ops = results.result_operations.len(); + ensure!( + num_result_ops <= MAX_NUM_RESULT_OPS, + "Number of operations to compute results is higher than the maximum number allowed" + ); + let predicate_ops_inputs = Self::compute_operation_inputs::( + predicate_operations, + placeholders, + )?; + let result_ops_inputs = Self::compute_operation_inputs::( + &results.result_operations, + placeholders, + )?; + let selectors = results.output_items.iter().enumerate().map(|(i, item)| { + Ok( + match item { + OutputItem::Column(index) => { + ensure!(*index < MAX_NUM_COLUMNS, + "Column index provided as {}-th output value is higher than the maximum number of columns", i); + F::from_canonical_usize(*index) + }, + OutputItem::ComputedValue(index) => { + ensure!(*index < num_result_ops, + "operation computing an output result not found in the set of result operations"); + // the output will be placed in the `num_result_ops - index`-th slot from the end of the set of + // `possible_output_values` provided as input in the circuit to the output component, + // i.e., the input array found in `OutputComponent::build` method.
+ // Therefore, since the `possible_output_values` array in the circuit has + // `MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS` entries, the selector for such an output value + // can be computed as `possible_output_values.len() - (num_result_ops - index)`, + // which corresponds to the `num_result_ops - index`-th entry from the end of the array + F::from_canonical_usize(MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS - (num_result_ops - *index)) + }, + }) + }).collect::>>()?; + let output_component_inputs = + T::new(&selectors, &results.output_ids, results.output_ids.len())?; + + let min_query = QueryBound::new_secondary_index_bound( + placeholders, + query_bounds.min_query_secondary(), + )?; + + let max_query = QueryBound::new_secondary_index_bound( + placeholders, + query_bounds.max_query_secondary(), + )?; + + Ok(Self { + column_extraction_inputs, + min_query_primary: query_bounds.min_query_primary(), + max_query_primary: query_bounds.max_query_primary(), + min_query_secondary: min_query, + max_query_secondary: max_query, + filtering_predicate_inputs: predicate_ops_inputs, + result_values_inputs: result_ops_inputs, + output_component_inputs, + }) + } + + pub(crate) fn build( + b: &mut CBuilder, + ) -> UniversalQueryHashWires< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + > { + let column_extraction_wires = ColumnExtractionInputs::build_hash(b); + let [min_query_primary, max_query_primary] = b.add_virtual_u256_arr_unsafe(); + let min_query_secondary = QueryBoundTarget::new(b); + let max_query_secondary = QueryBoundTarget::new(b); + let mut input_hash = column_extraction_wires.column_hash.to_vec(); + // Payload to compute the placeholder hash public input + let mut placeholder_hash_payload = vec![]; + // Set of input wires for each of the `MAX_NUM_PREDICATE_OPS` basic operation components employed to + // evaluate the filtering predicate + let mut filtering_predicate_wires = Vec::with_capacity(MAX_NUM_PREDICATE_OPS); + for _ in 0..MAX_NUM_PREDICATE_OPS { + let BasicOperationHashWires { + input_wires, + output_hash, + } = BasicOperationInputs::build_hash(b, &input_hash); + // add the output_hash computed by the last basic operation component to the input hashes + // for the next basic operation components employed to evaluate the filtering predicate + input_hash.push(output_hash); + // add placeholder data to payload for placeholder hash + placeholder_hash_payload.push(input_wires.placeholder_ids[0]); + placeholder_hash_payload + .extend_from_slice(&input_wires.placeholder_values[0].to_targets()); + placeholder_hash_payload.push(input_wires.placeholder_ids[1]); + placeholder_hash_payload + .extend_from_slice(&input_wires.placeholder_values[1].to_targets()); + filtering_predicate_wires.push(input_wires); + } + // Place the computational hash of the evaluation of the filtering predicate in the `predicate_hash` + // variable; the evaluation and the corresponding hash are expected to be the output of the + // last basic operation component among the `MAX_NUM_PREDICATE_OPS` ones employed to evaluate + // the filtering predicate.
This placement is done in order to have a fixed slot where we can + // find the predicate hash, without the need for a further random_access operation just to extract + // this hash from the set of predicate operations + let predicate_hash = input_hash.last().unwrap(); + let mut input_hash = column_extraction_wires.column_hash.to_vec(); + // Set of input wires for each of the `MAX_NUM_RESULT_OPS` basic operation components employed to + // compute the result values for the current row + let mut result_value_wires = Vec::with_capacity(MAX_NUM_RESULT_OPS); + for _ in 0..MAX_NUM_RESULT_OPS { + let BasicOperationHashWires { + input_wires, + output_hash, + } = BasicOperationInputs::build_hash(b, &input_hash); + // add the output_hash computed by the last basic operation component to the input hashes + // for the next basic operation components employed to compute result values for the current row + input_hash.push(output_hash); + // add placeholder data to payload for placeholder hash + placeholder_hash_payload.push(input_wires.placeholder_ids[0]); + placeholder_hash_payload + .extend_from_slice(&input_wires.placeholder_values[0].to_targets()); + placeholder_hash_payload.push(input_wires.placeholder_ids[1]); + placeholder_hash_payload + .extend_from_slice(&input_wires.placeholder_values[1].to_targets()); + result_value_wires.push(input_wires); + } + // `possible_output_hash` to be provided to output component are the set of `MAX_NUM_COLUMNS` + // and the `MAX_NUM_RESULT_OPS` computational hash of results operations, which are all already + // accumulated in the `input_hash` vector + let possible_output_hash: [ComputationalHashTarget; MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS] = + input_hash.try_into().unwrap(); + + let output_component_wires = T::build_hash(b, possible_output_hash, predicate_hash); + let placeholder_hash = b.hash_n_to_hash_no_pad::(placeholder_hash_payload); + let placeholder_hash = QueryBoundTarget::add_query_bounds_to_placeholder_hash( + b, + &min_query_secondary, + &max_query_secondary, + &placeholder_hash, + ); + // add query bounds to computational hash + let computational_hash = QueryBoundTarget::add_query_bounds_to_computational_hash( + b, + &min_query_secondary, + &max_query_secondary, + &output_component_wires.computational_hash(), + ); + + let min_secondary = min_query_secondary.get_bound_value().clone(); + let max_secondary = max_query_secondary.get_bound_value().clone(); + let num_bound_overflows = QueryBoundTarget::num_overflows_for_query_bound_operations( + b, + &min_query_secondary, + &max_query_secondary, + ); + UniversalQueryHashWires { + input_wires: UniversalQueryHashInputWires { + column_extraction_wires: column_extraction_wires.input_wires, + min_query_primary, + max_query_primary, + min_query_secondary: min_query_secondary.into(), + max_query_secondary: max_query_secondary.into(), + filtering_predicate_ops: filtering_predicate_wires.try_into().unwrap(), + result_value_ops: result_value_wires.try_into().unwrap(), + output_component_wires: output_component_wires.input_wires(), + }, + computational_hash, + placeholder_hash, + min_secondary, + max_secondary, + num_bound_overflows, + agg_ops_ids: output_component_wires + .ops_ids() + .to_vec() + .try_into() + .unwrap(), + } + } + + pub(crate) fn assign( + &self, + pw: &mut PartialWitness, + wires: &UniversalQueryHashInputWires< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + ) { + self.column_extraction_inputs + .assign(pw, &wires.column_extraction_wires); + 
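// primary index bounds are assigned as plain u256 values, while the secondary index bounds are assigned as full basic operations +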
pw.set_u256_target(&wires.min_query_primary, self.min_query_primary); + pw.set_u256_target(&wires.max_query_primary, self.max_query_primary); + wires + .min_query_secondary + .assign(pw, &self.min_query_secondary); + wires + .max_query_secondary + .assign(pw, &self.max_query_secondary); + self.filtering_predicate_inputs + .iter() + .chain(self.result_values_inputs.iter()) + .zip( + wires + .filtering_predicate_ops + .iter() + .chain(wires.result_value_ops.iter()), + ) + .for_each(|(value, target)| value.assign(pw, target)); + self.output_component_inputs + .assign(pw, &wires.output_component_wires); + } + + /// This method returns the ids of the placeholders employed to compute the placeholder hash, + /// in the same order, so that those ids can be provided as input to other circuits that need + /// to recompute this hash + pub(crate) fn ids_for_placeholder_hash( + predicate_operations: &[BasicOperation], + results: &ResultStructure, + placeholders: &Placeholders, + query_bounds: &QueryBounds, + ) -> Result> { + let hash_input_gadget = Self::new( + &ColumnIDs::default(), + predicate_operations, + placeholders, + query_bounds, + results, + )?; + Ok(hash_input_gadget + .filtering_predicate_inputs + .iter() + .flat_map(|op_inputs| vec![op_inputs.placeholder_ids[0], op_inputs.placeholder_ids[1]]) + .chain( + hash_input_gadget + .result_values_inputs + .iter() + .flat_map(|op_inputs| { + vec![op_inputs.placeholder_ids[0], op_inputs.placeholder_ids[1]] + }), + ) + .map(|id| PlaceholderIdentifier::from_fields(&[id])) + .collect_vec()) + } + + /// Utility function to compute the `BasicOperationInputs` corresponding to the set of `operations` specified + /// as input. The set of `BasicOperationInputs` is padded with dummy operations up to `MAX_NUM_OPS`, which is + /// the number of operations expected as input by the circuit. + pub(crate) fn compute_operation_inputs( + operations: &[BasicOperation], + placeholders: &Placeholders, + ) -> Result<[BasicOperationInputs; MAX_NUM_OPS]> { + let dummy_placeholder = dummy_placeholder(placeholders); + // starting offset in the input values provided to each basic operation component where the output values + // of `operations` will be found.
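Column values occupy the first `MAX_NUM_COLUMNS` input slots, followed by the outputs of previously computed operations.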
These operations are placed + // in the last slots among the `MAX_NUM_OPS` available, as expected + // by the circuit + let start_actual_ops = MAX_NUM_COLUMNS + MAX_NUM_OPS - operations.len(); + let ops_wires = operations.iter().enumerate().map(|(i, op)| { + let mut constant_operand = U256::ZERO; + // the number of input values provided to the basic operation component + // computing the current operation + let num_inputs = start_actual_ops + i; + let mut compute_op_inputs = |is_first_op: bool| { + let operand = if is_first_op { + op.first_operand + } else { + op.second_operand.unwrap_or_default() + }; + Ok( + match operand { + InputOperand::Placeholder(p) => { + let placeholder_value = placeholders.get(&p)?; + ( + Some(placeholder_value), + Some(p), + if is_first_op { + BasicOperationInputs::first_placeholder_offset(num_inputs) + } else { + BasicOperationInputs::second_placeholder_offset(num_inputs) + }, + ) + }, + InputOperand::Constant(val) => { + constant_operand = val; + ( + None, + None, + BasicOperationInputs::constant_operand_offset(num_inputs), + ) + }, + InputOperand::Column(index) => { + ensure!(index < MAX_NUM_COLUMNS, + "column index specified as input for {}-th operation is higher than the number of columns", i); + ( + None, + None, + BasicOperationInputs::input_value_offset(index), + ) + }, + InputOperand::PreviousValue(index) => { + ensure!(index < i, + "previous value index specified as input for {}-th operation is higher than the number of values already computed by previous operations", i); + ( + None, + None, + BasicOperationInputs::input_value_offset(start_actual_ops+index), + ) + }, + } + )}; + let (first_placeholder_value, first_placeholder_id, first_selector) = compute_op_inputs( + true + )?; + let (second_placeholder_value, second_placeholder_id, second_selector) = compute_op_inputs( + false + )?; + let placeholder_values = [ + first_placeholder_value.unwrap_or(dummy_placeholder.value), + second_placeholder_value.unwrap_or(dummy_placeholder.value) + ]; + let placeholder_ids = [ + first_placeholder_id.unwrap_or(dummy_placeholder.id).to_field(), + second_placeholder_id.unwrap_or(dummy_placeholder.id).to_field(), + ]; + Ok(BasicOperationInputs { + constant_operand, + placeholder_values, + placeholder_ids, + first_input_selector: F::from_canonical_usize(first_selector), + second_input_selector: F::from_canonical_usize(second_selector), + op_selector: op.op.to_field(), + }) + }).collect::>>()?; + // we pad ops_wires up to `MAX_NUM_OPS` with dummy operations; we pad at + // the beginning of the array since the circuit expects to find the operation computing + // the actual result values as the last of the `MAX_NUM_OPS` operations + Ok(repeat( + // dummy operation + BasicOperationInputs { + constant_operand: U256::ZERO, + placeholder_values: [dummy_placeholder.value, dummy_placeholder.value], + placeholder_ids: [ + dummy_placeholder.id.to_field(), + dummy_placeholder.id.to_field(), + ], + first_input_selector: F::ZERO, + second_input_selector: F::ZERO, + op_selector: Operation::EqOp.to_field(), + }, + ) + .take(MAX_NUM_OPS - operations.len()) + .chain(ops_wires) + .collect_vec() + .try_into() + .unwrap()) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct CurveOrU256([T; CURVE_TARGET_LEN]); + +impl CurveOrU256 { + pub(crate) fn from_slice(t: &[T]) -> Self { + Self( + t.iter() + .cloned() + .chain(repeat(t[0].clone())) + .take(CURVE_TARGET_LEN) + .collect_vec() + .try_into()
.unwrap(), + ) + } + + pub(crate) fn to_u256_raw(&self) -> &[T] { + &self.0[..NUM_LIMBS] + } + + pub(crate) fn to_vec(&self) -> Vec { + self.0.to_vec() + } +} + +pub(crate) type CurveOrU256Target = CurveOrU256; + +impl CurveOrU256Target { + pub(crate) fn as_curve_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.0.as_slice()) + } + + pub(crate) fn as_u256_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_u256_raw()) + } +} + +impl FromTargets for CurveOrU256Target { + const NUM_TARGETS: usize = CurveTarget::NUM_TARGETS; + + fn from_targets(t: &[Target]) -> Self { + Self::from_slice(t) + } +} + +impl ToTargets for CurveOrU256Target { + fn to_targets(&self) -> Vec { + self.0.to_vec() + } +} + +#[derive(Clone, Debug)] +pub(crate) struct OutputValuesTarget +where + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) first_output: CurveOrU256Target, + pub(crate) other_outputs: [UInt256Target; MAX_NUM_RESULTS - 1], +} + +impl OutputValuesTarget +where + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) fn value_target_at_index(&self, i: usize) -> UInt256Target { + if i == 0 { + self.first_output.as_u256_target() + } else { + self.other_outputs[i - 1].clone() + } + } + + #[cfg(test)] // used only in test for now + pub(crate) fn build(b: &mut CBuilder) -> Self { + let first_output = CurveOrU256(b.add_virtual_target_arr()); + let other_outputs = b.add_virtual_u256_arr(); + + Self { + first_output, + other_outputs, + } + } + + #[cfg(test)] // used only in test for now + pub(crate) fn set_target( + &self, + pw: &mut PartialWitness, + inputs: &OutputValues, + ) { + pw.set_target_arr(&self.first_output.0, &inputs.first_output.0); + pw.set_u256_target_arr(&self.other_outputs, &inputs.other_outputs); + } +} + +impl ToTargets for OutputValuesTarget +where + [(); MAX_NUM_RESULTS - 1]:, +{ + fn to_targets(&self) -> Vec { + self.first_output + .to_targets() + .into_iter() + .chain(self.other_outputs.iter().flat_map(|out| out.to_targets())) + .collect() + } +} + +impl FromTargets for OutputValuesTarget +where + [(); MAX_NUM_RESULTS - 1]:, +{ + const NUM_TARGETS: usize = + CurveTarget::NUM_TARGETS + (MAX_NUM_RESULTS - 1) * UInt256Target::NUM_TARGETS; + + fn from_targets(t: &[Target]) -> Self { + assert!(t.len() >= Self::NUM_TARGETS); + let first_output = CurveOrU256Target::from_targets(&t[..CurveTarget::NUM_TARGETS]); + let other_outputs = t[CurveTarget::NUM_TARGETS..] + .chunks(UInt256Target::NUM_TARGETS) + .map(UInt256Target::from_targets) + .take(MAX_NUM_RESULTS - 1) + .collect_vec() + .try_into() + .unwrap(); + + Self { + first_output, + other_outputs, + } + } +} +#[derive(Clone, Debug)] +pub(crate) struct OutputValues +where + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) first_output: CurveOrU256, + pub(crate) other_outputs: [U256; MAX_NUM_RESULTS - 1], +} + +impl OutputValues +where + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) fn new_aggregation_outputs(values: &[U256]) -> Self { + let first_output = CurveOrU256::::from_slice(&values[0].to_fields()); + let other_outputs = values[1..] 
+ .iter() + .copied() + .chain(repeat(U256::ZERO)) + .take(MAX_NUM_RESULTS - 1) + .collect_vec(); + + Self { + first_output, + other_outputs: other_outputs.try_into().unwrap(), + } + } + + pub(crate) fn new_outputs_no_aggregation(point: &plonky2_ecgfp5::curve::curve::Point) -> Self { + let first_output = CurveOrU256::::from_slice(&point.to_fields()); + Self { + first_output, + other_outputs: [U256::ZERO; MAX_NUM_RESULTS - 1], + } + } + + pub(crate) fn first_value_as_curve_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(&self.first_output.0) + } + + pub(crate) fn first_value_as_u256(&self) -> U256 { + let fields = self.first_output.to_u256_raw(); + U256::from_fields(fields) + } + + /// Return the value as a UInt256 at the specified index + pub(crate) fn value_at_index(&self, i: usize) -> U256 { + if i == 0 { + self.first_value_as_u256() + } else { + self.other_outputs[i - 1] + } + } +} + +impl FromFields for OutputValues +where + [(); MAX_NUM_RESULTS - 1]:, +{ + fn from_fields(t: &[F]) -> Self { + let first_output = CurveOrU256::from_slice(&t[..CURVE_TARGET_LEN]); + let other_outputs = t[CURVE_TARGET_LEN..] + .chunks(NUM_LIMBS) + .map(U256::from_fields) + .take(MAX_NUM_RESULTS - 1) + .collect_vec() + .try_into() + .unwrap(); + + Self { + first_output, + other_outputs, + } + } +} + +impl ToFields for OutputValues +where + [(); MAX_NUM_RESULTS - 1]:, +{ + fn to_fields(&self) -> Vec { + self.first_output + .to_vec() + .into_iter() + .chain(self.other_outputs.iter().flat_map(|out| out.to_fields())) + .collect() + } +} +/// Input wires for the universal query value gadget +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct UniversalQueryValueInputWires { + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + pub(crate) column_values: [UInt256Target; MAX_NUM_COLUMNS], + // flag specifying whether this is a non-dummy row + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + pub(crate) is_non_dummy_row: BoolTarget, +} + +#[derive(Clone, Debug)] +pub(crate) struct UniversalQueryOutputWires +where + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) tree_hash: MembershipHashTarget, + pub(crate) values: OutputValuesTarget, + pub(crate) count: Target, + pub(crate) num_overflows: Target, +} + +impl FromTargets for UniversalQueryOutputWires +where + [(); MAX_NUM_RESULTS - 1]:, +{ + const NUM_TARGETS: usize = NUM_HASH_OUT_ELTS + 2 + OutputValuesTarget::NUM_TARGETS; + fn from_targets(t: &[Target]) -> Self { + assert!(t.len() >= Self::NUM_TARGETS); + Self { + tree_hash: MembershipHashTarget::from_vec(t[..NUM_HASH_OUT_ELTS].to_vec()), + values: OutputValuesTarget::from_targets(&t[NUM_HASH_OUT_ELTS..]), + count: t[Self::NUM_TARGETS - 2], + num_overflows: t[Self::NUM_TARGETS - 1], + } + } +} + +impl ToTargets for UniversalQueryOutputWires +where + [(); MAX_NUM_RESULTS - 1]:, +{ + fn to_targets(&self) -> Vec { + self.tree_hash + .to_targets() + .into_iter() + .chain(self.values.to_targets()) + .chain([self.count, self.num_overflows]) + .collect() + } +} + +#[derive(Clone, Debug)] +pub(crate) struct UniversalQueryValueWires< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_RESULTS: usize, +> where + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) input_wires: UniversalQueryValueInputWires, + pub(crate) output_wires: UniversalQueryOutputWires, +} +/// Input values for the universal query value gadget +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct UniversalQueryValueInputs< + const 
MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, +> { + #[serde( + serialize_with = "serialize_long_array", + deserialize_with = "deserialize_long_array" + )] + pub(crate) column_values: [U256; MAX_NUM_COLUMNS], + pub(crate) is_dummy_row: bool, +} + +impl< + const MAX_NUM_COLUMNS: usize, + const MAX_NUM_PREDICATE_OPS: usize, + const MAX_NUM_RESULT_OPS: usize, + const MAX_NUM_RESULTS: usize, + > + UniversalQueryValueInputs< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + > +where + [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, + [(); MAX_NUM_RESULTS - 1]:, +{ + pub(crate) fn new(row_cells: &RowCells, is_dummy_row: bool) -> Result { + let num_columns = row_cells.num_columns(); + ensure!( + num_columns <= MAX_NUM_COLUMNS, + "number of columns is higher than the maximum value allowed" + ); + let column_cells = row_cells.to_cells(); + let padded_column_values = column_cells + .iter() + .map(|cell| cell.value) + .chain(repeat(U256::ZERO)) + .take(MAX_NUM_COLUMNS) + .collect_vec(); + Ok(Self { + column_values: padded_column_values.try_into().unwrap(), + is_dummy_row, + }) + } + + pub(crate) fn build>( + b: &mut CBuilder, + hash_input_wires: &UniversalQueryHashInputWires< + MAX_NUM_COLUMNS, + MAX_NUM_PREDICATE_OPS, + MAX_NUM_RESULT_OPS, + MAX_NUM_RESULTS, + T, + >, + min_secondary: &UInt256Target, + max_secondary: &UInt256Target, + num_overflows: &Target, + ) -> UniversalQueryValueWires { + let column_values = ColumnExtractionInputs::build_column_values(b); + let _true = b._true(); + // allocate dummy row flag only if we aren't in universal circuit, i.e., if min_primary.is_some() is true + let is_non_dummy_row = b.add_virtual_bool_target_safe(); + let ColumnExtractionValueWires { tree_hash } = ColumnExtractionInputs::build_tree_hash( + b, + &column_values, + &hash_input_wires.column_extraction_wires, + ); + + // Enforce that the value of primary index for the current row is in the range given by these bounds + let index_value = &column_values[0]; + let less_than_max = + b.is_less_or_equal_than_u256(index_value, &hash_input_wires.max_query_primary); + let greater_than_min = + b.is_less_or_equal_than_u256(&hash_input_wires.min_query_primary, index_value); + b.connect(less_than_max.target, _true.target); + b.connect(greater_than_min.target, _true.target); + + // min and max for secondary indexed column + let node_min = &column_values[1]; + let node_max = node_min; + // determine whether the value of second indexed column for the current record is in + // the range specified by the query + let less_than_max = b.is_less_or_equal_than_u256(node_max, max_secondary); + let greater_than_min = b.is_less_or_equal_than_u256(min_secondary, node_min); + let is_in_range = b.and(less_than_max, greater_than_min); + + // initialize input_values vectors for basic operation components employed to + // evaluate the filtering predicate + let mut input_values = column_values.to_vec(); + let mut num_overflows = *num_overflows; + for i in 0..MAX_NUM_PREDICATE_OPS { + let BasicOperationValueWires { + output_value, + num_overflows: new_num_overflows, + } = BasicOperationInputs::build_values( + b, + &input_values, + &hash_input_wires.filtering_predicate_ops[i], + num_overflows, + ); + // add the output_value computed by the last basic operation component to the input values + // for the next basic operation components employed to evaluate the filtering predicate + input_values.push(output_value); + // update 
the counter of overflows detected + num_overflows = new_num_overflows; + } + // Place the evaluation of the filtering predicate in the `predicate_value` variable; the evaluation and + // the corresponding hash are expected to be the output of the last basic operation component among + // the `MAX_NUM_PREDICATE_OPS` ones employed to evaluate the filtering predicate. This placement is + // done in order to have a fixed slot where we can find the predicate value, without the need for a + // further random_access operation just to extract this value from the set of predicate operations + let predicate_value = input_values.last().unwrap().to_bool_target(); + // filtering predicate must be false if the secondary index value for the current row is not in the + // range specified by the query + let predicate_value = b.and(predicate_value, is_in_range); + // filtering predicate must be false also if this is a dummy row + let predicate_value = b.and(predicate_value, is_non_dummy_row); + + // initialize the input_values vector for the basic operation components employed to + // compute the result values for the current row + let mut input_values = column_values.to_vec(); + for i in 0..MAX_NUM_RESULT_OPS { + let BasicOperationValueWires { + output_value, + num_overflows: new_num_overflows, + } = BasicOperationInputs::build_values( + b, + &input_values, + &hash_input_wires.result_value_ops[i], + num_overflows, + ); + // add the output_value computed by the last basic operation component to the input values + // for the next basic operation components employed to compute the result values + input_values.push(output_value); + // update the counter of overflows detected + num_overflows = new_num_overflows; + } + + // the `possible_output_values` to be provided to the output component are the `MAX_NUM_COLUMNS` column values + // and the `MAX_NUM_RESULT_OPS` results of the result operations, which are all already accumulated + // in the `input_values` vector + let possible_output_values: [UInt256Target; MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS] = + input_values.try_into().unwrap(); + + let output_component_value_wires = T::build_values( + b, + possible_output_values, + &predicate_value, + &hash_input_wires.output_component_wires, + ); + + // compute output_values to be exposed; we build the first output value as a `CurveOrU256Target` + let first_output = CurveOrU256Target::from_targets( + &output_component_value_wires + .first_output_value() + .to_targets(), + ); + // Append the other `MAX_NUM_RESULTS-1` output values + let output_values = OutputValuesTarget { + first_output, + other_outputs: output_component_value_wires + .other_output_values() + .to_vec() + .try_into() + .unwrap(), + }; + + // ensure that `num_overflows` is always 0 in case of dummy rows + let num_overflows = b.mul(num_overflows, is_non_dummy_row.target); + + UniversalQueryValueWires { + input_wires: UniversalQueryValueInputWires { + column_values, + is_non_dummy_row, + }, + output_wires: UniversalQueryOutputWires { + tree_hash, + values: output_values, + count: predicate_value.target, + num_overflows, + }, + } + } + + pub(crate) fn assign( + &self, + pw: &mut PartialWitness<F>, + wires: &UniversalQueryValueInputWires<MAX_NUM_COLUMNS>, + ) { + pw.set_u256_target_arr(&wires.column_values, &self.column_values); + pw.set_bool_target(wires.is_non_dummy_row, !self.is_dummy_row); + } +} + +/// Placeholder to be employed as a dummy placeholder in the universal circuit +fn dummy_placeholder(placeholders: &Placeholders) -> Placeholder { + Placeholder { + value: 
placeholders.get(&dummy_placeholder_id()).unwrap(), // cannot fail since default placeholder is always associated to a value + id: dummy_placeholder_id(), + } +} diff --git a/verifiable-db/src/query/aggregation/mod.rs b/verifiable-db/src/query/utils.rs similarity index 72% rename from verifiable-db/src/query/aggregation/mod.rs rename to verifiable-db/src/query/utils.rs index 431c94276..29e217664 100644 --- a/verifiable-db/src/query/aggregation/mod.rs +++ b/verifiable-db/src/query/utils.rs @@ -1,40 +1,40 @@ -use std::iter::once; +use std::{array, iter::once}; use alloy::primitives::U256; use anyhow::Result; use itertools::Itertools; use mp2_common::{ poseidon::{empty_poseidon_hash, HashPermutation}, - proof::ProofWithVK, - serialization::{deserialize_long_array, serialize_long_array}, - types::HashOutput, - utils::{Fieldable, ToFields}, - F, + serialization::{ + deserialize, deserialize_array, deserialize_long_array, serialize, serialize_array, + serialize_long_array, + }, + types::{CBuilder, HashOutput}, + u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, + utils::{Fieldable, ToFields, ToTargets}, + CHasher, F, }; use plonky2::{ - hash::{hash_types::HashOut, hashing::hash_n_to_hash_no_pad}, + hash::{ + hash_types::{HashOut, HashOutTarget}, + hashing::hash_n_to_hash_no_pad, + }, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, plonk::config::GenericHashOut, }; use serde::{Deserialize, Serialize}; -pub(crate) mod child_proven_single_path_node; -pub(crate) mod embedded_tree_proven_single_path_node; -pub(crate) mod full_node_index_leaf; -pub(crate) mod full_node_with_one_child; -pub(crate) mod full_node_with_two_children; -pub(crate) mod non_existence_inter; -mod output_computation; -pub(crate) mod partial_node; -mod utils; - use super::{ - api::CircuitInput, computational_hash_ids::{ColumnIDs, Identifiers, PlaceholderIdentifier}, universal_circuit::{ universal_circuit_inputs::{BasicOperation, PlaceholderId, Placeholders, ResultStructure}, universal_query_circuit::{ - placeholder_hash, placeholder_hash_without_query_bounds, QueryBound, + placeholder_hash, placeholder_hash_without_query_bounds, UniversalCircuitInput, }, + universal_query_gadget::QueryBound, ComputationalHash, PlaceholderHash, }, }; @@ -216,132 +216,102 @@ impl NodeInfo { ) } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -/// enum to specify whether a node is the left or right child of another node -pub enum ChildPosition { - Left, - Right, -} - -impl ChildPosition { - // convert `self` to a flag specifying whether a node is the left child of another node or not - pub(crate) fn to_flag(self) -> bool { - match self { - ChildPosition::Left => true, - ChildPosition::Right => false, - } - } -} #[derive(Clone, Debug, Serialize, Deserialize)] -pub(crate) struct CommonInputs { - pub(crate) is_rows_tree_node: bool, - pub(crate) min_query: U256, - pub(crate) max_query: U256, +pub(crate) struct NodeInfoTarget { + /// The hash of the embedded tree at this node. It can be the hash of the row tree if this node is a node in + /// the index tree, or it can be a hash of the cells tree if this node is a node in a rows tree + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + pub(crate) embedded_tree_hash: HashOutTarget, + /// Hashes of the children of the current node, first left child and then right child hash. 
The hash of the left/right child + /// is the empty hash (i.e., H("")) if there is no corresponding left/right child for the current node + #[serde( + serialize_with = "serialize_array", + deserialize_with = "deserialize_array" + )] + pub(crate) child_hashes: [HashOutTarget; 2], + /// value stored in the node. It can be a primary index value if the node is a node in the index tree, + /// or a secondary index value if the node is a node in a rows tree + pub(crate) value: UInt256Target, + /// minimum value associated with the current node. It can be a primary index value if the node is a node in the index tree, + /// or a secondary index value if the node is a node in a rows tree + pub(crate) min: UInt256Target, + /// maximum value associated with the current node. It can be a primary index value if the node is a node in the index tree, + /// or a secondary index value if the node is a node in a rows tree + pub(crate) max: UInt256Target, } -impl CommonInputs { - pub(crate) fn new(is_rows_tree_node: bool, query_bounds: &QueryBounds) -> Self { +impl NodeInfoTarget { + #[allow(dead_code)] + pub(crate) fn build(b: &mut CBuilder) -> Self { + let [value, min, max] = b.add_virtual_u256_arr(); + let [left_child_hash, right_child_hash, embedded_tree_hash] = + array::from_fn(|_| b.add_virtual_hash()); Self { - is_rows_tree_node, - min_query: if is_rows_tree_node { - query_bounds.min_query_secondary.value - } else { - query_bounds.min_query_primary - }, - max_query: if is_rows_tree_node { - query_bounds.max_query_secondary.value - } else { - query_bounds.max_query_primary - }, + embedded_tree_hash, + child_hashes: [left_child_hash, right_child_hash], + value, + min, + max, } } -} -/// Input data structure for circuits employed for nodes where both the children and the embedded tree are proven -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct TwoProvenChildNodeInput { - /// Proof for the left child of the node being proven - pub(crate) left_child_proof: ProofWithVK, - /// Proof for the right child of the node being proven - pub(crate) right_child_proof: ProofWithVK, - /// Proof for the embedded tree stored in the current node - pub(crate) embedded_tree_proof: ProofWithVK, - /// Common inputs shared across all the circuits - pub(crate) common: CommonInputs, -} -/// Input data structure for circuits employed for nodes where one child and the embedded tree are proven -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct OneProvenChildNodeInput { - /// Data related to the child not associated with a proof, if any - pub(crate) unproven_child: Option, - /// Proof for the proven child - pub(crate) proven_child_proof: ChildProof, - /// Proof for the embedded tree stored in the current node - pub(crate) embedded_tree_proof: ProofWithVK, - /// Common inputs shared across all the circuits - pub(crate) common: CommonInputs, -} -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -/// Data structure representing a proof for a child node -pub struct ChildProof { - /// Actual proof - pub(crate) proof: ProofWithVK, - /// Flag specifying whether the child associated with `proof` is the left or right child of its parent - pub(crate) child_position: ChildPosition, -} -impl ChildProof { - pub fn new(proof: Vec, child_position: ChildPosition) -> Result { - Ok(Self { - proof: ProofWithVK::deserialize(&proof)?, - child_position, - }) + /// Build an instance of `Self` without range-checking the `UInt256Target`s + pub(crate) fn build_unsafe(b: &mut CBuilder) -> Self { + let [value, min, max] = 
b.add_virtual_u256_arr_unsafe(); + let [left_child_hash, right_child_hash, embedded_tree_hash] = + array::from_fn(|_| b.add_virtual_hash()); + Self { + embedded_tree_hash, + child_hashes: [left_child_hash, right_child_hash], + value, + min, + max, + } } -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -/// Enum employed to specify whether a proof refers to a child node or the embedded tree stored in a node -pub enum SubProof { - /// Proof refer to a child - Child(ChildProof), - /// Proof refer to the embedded tree stored in the node: can be either the proof for a single row - /// (if proving a rows tree node) of the proof for the root node of a rows tree (if proving an index tree node) - Embedded(ProofWithVK), -} -impl SubProof { - /// Initialize a new `SubProof::Child` - pub fn new_child_proof(proof: Vec, child_position: ChildPosition) -> Result { - Ok(SubProof::Child(ChildProof::new(proof, child_position)?)) + pub(crate) fn compute_node_hash(&self, b: &mut CBuilder, index_id: Target) -> HashOutTarget { + let inputs = self.child_hashes[0] + .to_targets() + .into_iter() + .chain(self.child_hashes[1].to_targets()) + .chain(self.min.to_targets()) + .chain(self.max.to_targets()) + .chain(once(index_id)) + .chain(self.value.to_targets()) + .chain(self.embedded_tree_hash.to_targets()) + .collect_vec(); + b.hash_n_to_hash_no_pad::(inputs) } - /// Initialize a new `SubProof::Embedded` - pub fn new_embedded_tree_proof(proof: Vec) -> Result { - Ok(SubProof::Embedded(ProofWithVK::deserialize(&proof)?)) + pub(crate) fn set_target(&self, pw: &mut PartialWitness, inputs: &NodeInfo) { + [ + (self.embedded_tree_hash, inputs.embedded_tree_hash), + (self.child_hashes[0], inputs.child_hashes[0]), + (self.child_hashes[1], inputs.child_hashes[1]), + ] + .into_iter() + .for_each(|(target, value)| pw.set_hash_target(target, value)); + pw.set_u256_target_arr( + &[self.min.clone(), self.max.clone(), self.value.clone()], + &[inputs.min, inputs.max, inputs.value], + ); } } -/// Input data structure for circuits employed for nodes where only one among children node and embedded tree is proven -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct SinglePathInput { - /// Data about the left child of the node being proven, if any - pub(crate) left_child: Option, - /// Data about the right child of the node being proven, if any - pub(crate) right_child: Option, - /// Data about the node being proven - pub(crate) node_info: NodeInfo, - /// Proof of either a child node or of the embedded tree stored in the current node - pub(crate) subtree_proof: SubProof, - /// Common inputs shared across all the circuits - pub(crate) common: CommonInputs, +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +/// enum to specify whether a node is the left or right child of another node +pub enum ChildPosition { + Left, + Right, } /// Data structure containing the computational hash and placeholder hash to be provided as input to /// non-existence circuits. 
These hashes are computed from the query specific data provided as input /// to the initialization method of this data structure pub struct QueryHashNonExistenceCircuits { - computational_hash: ComputationalHash, - placeholder_hash: PlaceholderHash, + pub(crate) computational_hash: ComputationalHash, + pub(crate) placeholder_hash: PlaceholderHash, } impl QueryHashNonExistenceCircuits { @@ -381,7 +351,7 @@ impl QueryHashNonExistenceCircuits { .into(), ) }; - let placeholder_hash_ids = CircuitInput::< + let placeholder_hash_ids = UniversalCircuitInput::< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -446,23 +416,30 @@ pub struct NonExistenceInput { pub(crate) mod tests { use crate::query::{ computational_hash_ids::{AggregationOperation, Identifiers}, - public_inputs::PublicInputs, + public_inputs::PublicInputsQueryCircuits, + universal_circuit::universal_query_gadget::{CurveOrU256, OutputValues}, }; use alloy::primitives::U256; - use mp2_common::{array::ToField, group_hashing::add_curve_point, utils::ToFields, F}; + use itertools::Itertools; + use mp2_common::{ + array::ToField, + group_hashing::add_curve_point, + utils::{FromFields, ToFields}, + F, + }; use plonky2_ecgfp5::curve::curve::Point; - /// Compute the output values and the overflow number at the specified index by - /// the proofs. It's the test function corresponding to `compute_output_item`. - pub(crate) fn compute_output_item_value( + /// Aggregate the i-th output values found in `outputs` according to the aggregation operation + /// with identifier `op`. It's the test function corresponding to `OutputValuesTarget::aggregate_outputs` + pub(crate) fn aggregate_output_values( i: usize, - proofs: &[&PublicInputs], + outputs: &[OutputValues], + op: F, ) -> (Vec, u32) where [(); S - 1]:, { - let proof0 = &proofs[0]; - let op = proof0.operation_ids()[i]; + let out0 = &outputs[0]; let [op_id, op_min, op_max, op_sum, op_avg] = [ AggregationOperation::IdOp, @@ -479,22 +456,17 @@ pub(crate) mod tests { let is_op_sum = op == op_sum; let is_op_avg = op == op_avg; - // Check that the all proofs are employing the same aggregation operation. - proofs[1..] - .iter() - .for_each(|p| assert_eq!(p.operation_ids()[i], op)); - // Compute the SUM, MIN or MAX value. let mut sum_overflow = 0; - let mut output = proof0.value_at_index(i); + let mut output = out0.value_at_index(i); if i == 0 && is_op_id { // If it's the first proof and the operation is ID, // the value is a curve point not a Uint256. output = U256::ZERO; } - for p in proofs[1..].iter() { + for out in outputs[1..].iter() { // Get the current proof value. - let mut value = p.value_at_index(i); + let mut value = out.value_at_index(i); if i == 0 && is_op_id { // If it's the first proof and the operation is ID, // the value is a curve point not a Uint256. @@ -520,14 +492,14 @@ pub(crate) mod tests { if i == 0 { // We always accumulate order-agnostic digest of the proofs for the first item. output = if is_op_id { - let points: Vec<_> = proofs + let points: Vec<_> = outputs .iter() - .map(|p| Point::decode(p.first_value_as_curve_point().encode()).unwrap()) + .map(|out| Point::decode(out.first_value_as_curve_point().encode()).unwrap()) .collect(); add_curve_point(&points).to_fields() } else { // Pad the current output to ``CURVE_TARGET_LEN` for the first item. 
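// Aside: the SUM branch above accumulates with wrap-around semantics and counts
// carries instead of failing. A minimal, self-contained sketch of that pattern
// (`fold_sum` is a hypothetical helper, not part of this crate; it only assumes
// `alloy::primitives::U256`, which is already a workspace dependency):
fn fold_sum(values: &[alloy::primitives::U256]) -> (alloy::primitives::U256, u32) {
    values
        .iter()
        .fold((alloy::primitives::U256::ZERO, 0u32), |(acc, overflows), v| {
            // `overflowing_add` wraps modulo 2^256 and reports whether a carry occurred;
            // the circuit tracks the same information in its `num_overflows` counter.
            let (sum, overflowed) = acc.overflowing_add(*v);
            (sum, overflows + overflowed as u32)
        })
}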
- PublicInputs::<_, S>::pad_slice_to_curve_len(&output) + CurveOrU256::from_slice(&output).to_vec() }; } @@ -541,4 +513,29 @@ pub(crate) mod tests { (output, overflow) } + + /// Compute the output values and the overflow number at the specified index from + /// the proofs. It's the test function corresponding to `compute_output_item`. + pub(crate) fn compute_output_item_value( + i: usize, + proofs: &[&PublicInputsQueryCircuits], + ) -> (Vec, u32) + where + [(); S - 1]:, + { + let proof0 = &proofs[0]; + let op = proof0.operation_ids()[i]; + + // Check that all the proofs are employing the same aggregation operation. + proofs[1..] + .iter() + .for_each(|p| assert_eq!(p.operation_ids()[i], op)); + + let outputs = proofs + .iter() + .map(|p| OutputValues::from_fields(p.to_values_raw())) + .collect_vec(); + + aggregate_output_values(i, &outputs, op) + } } diff --git a/verifiable-db/src/results_tree/binding/binding_results.rs b/verifiable-db/src/results_tree/binding/binding_results.rs index 9431af03c..b177a37f3 100644 --- a/verifiable-db/src/results_tree/binding/binding_results.rs +++ b/verifiable-db/src/results_tree/binding/binding_results.rs @@ -3,12 +3,12 @@ use crate::{ query::{ computational_hash_ids::{AggregationOperation, ResultIdentifier}, - public_inputs::PublicInputs as QueryProofPI, universal_circuit::ComputationalHashTarget, }, results_tree::{ binding::public_inputs::PublicInputs, construction::public_inputs::PublicInputs as ResultsConstructionProofPI, + old_public_inputs::PublicInputs as QueryProofPI, }, }; use mp2_common::{ @@ -99,12 +99,14 @@ impl BindingResultsCircuit { mod tests { use super::*; use crate::{ - query::pi_len as query_pi_len, - results_tree::construction::{ - public_inputs::ResultsConstructionPublicInputs, - tests::{pi_len, random_results_construction_public_inputs}, + results_tree::{ + construction::{ + public_inputs::ResultsConstructionPublicInputs, + tests::{pi_len, random_results_construction_public_inputs}, + }, + tests::random_aggregation_public_inputs, }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, + test_utils::random_aggregation_operations, }; use itertools::Itertools; use mp2_common::{poseidon::H, utils::ToFields, C, D, F}; @@ -117,7 +119,7 @@ mod tests { const S: usize = 20; - const QUERY_PI_LEN: usize = query_pi_len::(); + const QUERY_PI_LEN: usize = QueryProofPI::<F, S>::total_len(); const RESULTS_CONSTRUCTION_PI_LEN: usize = pi_len::(); #[derive(Clone, Debug)] diff --git a/verifiable-db/src/results_tree/construction/results_tree_with_duplicates.rs b/verifiable-db/src/results_tree/construction/results_tree_with_duplicates.rs index 190e1c0a0..c11164e40 100644 --- a/verifiable-db/src/results_tree/construction/results_tree_with_duplicates.rs +++ b/verifiable-db/src/results_tree/construction/results_tree_with_duplicates.rs @@ -11,7 +11,7 @@ use mp2_common::{ }, types::CBuilder, u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{SelectHashBuilder, ToTargets}, + utils::{HashBuilder, ToTargets}, D, F, }; use plonky2::{ diff --git a/verifiable-db/src/results_tree/construction/results_tree_without_duplicates.rs b/verifiable-db/src/results_tree/construction/results_tree_without_duplicates.rs index 7fc2860fa..e2f8dc09a 100644 --- a/verifiable-db/src/results_tree/construction/results_tree_without_duplicates.rs +++ b/verifiable-db/src/results_tree/construction/results_tree_without_duplicates.rs @@ -11,7 +11,7 @@ use mp2_common::{ }, types::CBuilder, u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - 
utils::{SelectHashBuilder, ToTargets}, + utils::{HashBuilder, ToTargets}, D, F, }; use plonky2::{ diff --git a/verifiable-db/src/results_tree/mod.rs b/verifiable-db/src/results_tree/mod.rs index 53396f41a..443a8fd90 100644 --- a/verifiable-db/src/results_tree/mod.rs +++ b/verifiable-db/src/results_tree/mod.rs @@ -1,2 +1,82 @@ pub(crate) mod binding; pub(crate) mod construction; +/// Old query public inputs, moved here because the circuits in this module still expect +/// these public inputs for now +pub(crate) mod old_public_inputs; + +#[cfg(test)] +pub(crate) mod tests { + use std::array; + + use mp2_common::{array::ToField, types::CURVE_TARGET_LEN, utils::ToFields, F}; + use plonky2::{ + field::types::{Field, Sample}, + hash::hash_types::NUM_HASH_OUT_ELTS, + }; + use plonky2_ecgfp5::curve::curve::Point; + use rand::{thread_rng, Rng}; + + use crate::query::computational_hash_ids::{AggregationOperation, Identifiers}; + + use super::old_public_inputs::{PublicInputs, QueryPublicInputs}; + + /// Generate `N` sets of proof public inputs with the specified operations for testing. + /// Each returned set of public inputs can be parsed back with the + /// `PublicInputs::from_slice` function. + pub fn random_aggregation_public_inputs<const N: usize, const S: usize>( + ops: &[F; S], + ) -> [Vec<F>; N] { + let [ops_range, overflow_range, index_ids_range, c_hash_range, p_hash_range] = [ + QueryPublicInputs::OpIds, + QueryPublicInputs::Overflow, + QueryPublicInputs::IndexIds, + QueryPublicInputs::ComputationalHash, + QueryPublicInputs::PlaceholderHash, + ] + .map(PublicInputs::<F, S>::to_range); + + let first_value_start = + PublicInputs::<F, S>::to_range(QueryPublicInputs::OutputValues).start; + let is_first_op_id = + ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field(); + + // Generate the index ids, computational hash and placeholder hash; + // they should be the same for a series of public inputs. + let mut rng = thread_rng(); + let index_ids = (0..2).map(|_| rng.gen()).collect::<Vec<u32>>().to_fields(); + let [computational_hash, placeholder_hash]: [Vec<_>; 2] = array::from_fn(|_| { + (0..NUM_HASH_OUT_ELTS) + .map(|_| rng.gen()) + .collect::<Vec<u32>>() + .to_fields() + }); + + array::from_fn(|_| { + let mut pi = (0..PublicInputs::<F, S>::total_len()) + .map(|_| rng.gen()) + .collect::<Vec<u32>>() + .to_fields(); + + // Copy the specified operations to the proofs. + pi[ops_range.clone()].copy_from_slice(ops); + + // Set the overflow flag to a random boolean. + let overflow = F::from_bool(rng.gen()); + pi[overflow_range.clone()].copy_from_slice(&[overflow]); + + // Set the index ids, computational hash and placeholder hash. + pi[index_ids_range.clone()].copy_from_slice(&index_ids); + pi[c_hash_range.clone()].copy_from_slice(&computational_hash); + pi[p_hash_range.clone()].copy_from_slice(&placeholder_hash); + + // If the first operation is ID, set the value to a random point. 
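// Aside: a hedged, std-only sketch of the layout trick used by this generator and by
// `to_range` in `old_public_inputs.rs` below: section offsets are prefix sums of the
// section sizes, and shared sections are pinned by copying over the random filler.
// The sizes and values here are made up purely for illustration.
fn _sketch_pi_layout() {
    const SIZES: [usize; 3] = [4, 2, 1]; // e.g. a hash, two ids, one flag
    let to_range = |i: usize| {
        let start: usize = SIZES[..i].iter().sum();
        start..start + SIZES[i]
    };
    let mut pi = vec![0u64; SIZES.iter().sum()];
    pi[to_range(1)].copy_from_slice(&[7, 7]); // pin the section shared across proofs
    assert_eq!(pi, vec![0, 0, 0, 0, 7, 7, 0]);
}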
+ if is_first_op_id { + let first_value = Point::sample(&mut rng).to_weierstrass().to_fields(); + pi[first_value_start..first_value_start + CURVE_TARGET_LEN] + .copy_from_slice(&first_value); + } + + pi + }) + } +} diff --git a/verifiable-db/src/results_tree/old_public_inputs.rs b/verifiable-db/src/results_tree/old_public_inputs.rs new file mode 100644 index 000000000..7f6d07b00 --- /dev/null +++ b/verifiable-db/src/results_tree/old_public_inputs.rs @@ -0,0 +1,539 @@ +use std::iter::once; + +use alloy::primitives::U256; +use itertools::Itertools; +use mp2_common::{ + public_inputs::{PublicInputCommon, PublicInputRange}, + types::{CBuilder, CURVE_TARGET_LEN}, + u256::{UInt256Target, NUM_LIMBS}, + utils::{FromFields, FromTargets, TryIntoBool}, + F, +}; +use plonky2::{ + hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + iop::target::{BoolTarget, Target}, +}; +use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; + +use crate::query::universal_circuit::universal_query_gadget::{ + CurveOrU256Target, OutputValues, OutputValuesTarget, +}; + +/// Query circuits public inputs +pub enum QueryPublicInputs { + /// `H`: Hash of the tree + TreeHash, + /// `V`: Set of `S` values representing the cumulative results of the query, where `S` is a parameter + /// specifying the maximum number of cumulative results we support; + /// the first value could be either a `u256` or a `CurveTarget`, depending on the query, and so we always + /// represent this value with `CURVE_TARGET_LEN` elements; all the other `S-1` values are always `u256` + OutputValues, + /// `count`: `F` Number of matching records in the query + NumMatching, + /// `ops` : `[F; S]` Set of identifiers of the aggregation operations for each of the `S` items found in `V` + /// (like "SUM", "MIN", "MAX", "COUNT" operations) + OpIds, + /// `I` : `u256` value of the indexed column for the given node (meaningful only for rows tree nodes) + IndexValue, + /// `min` : `u256` Minimum value of the indexed column among all the records stored in the subtree rooted + /// in the current node; values of secondary indexed column are employed for rows tree nodes, + /// while values of primary indexed column are employed for index tree nodes + MinValue, + /// `max` : Maximum value of the indexed column among all the records stored in the subtree rooted + /// in the current node; values of secondary indexed column are employed for rows tree nodes, + /// while values of primary indexed column are employed for index tree nodes + MaxValue, + /// `index_ids` : `[2]F` Identifiers of indexed columns + IndexIds, + /// `MIN_I`: `u256` Lower bound of the range of indexed column values specified in the query + MinQuery, + /// `MAX_I`: `u256` Upper bound of the range of indexed column values specified in the query + MaxQuery, + /// `overflow` : `bool` Flag specifying whether an overflow error has occurred in arithmetic + Overflow, + /// `C`: computational hash + ComputationalHash, + /// `H_p` : placeholder hash + PlaceholderHash, +} + +#[derive(Clone, Debug)] +pub struct PublicInputs<'a, T, const S: usize> { + h: &'a [T], + v: &'a [T], + ops: &'a [T], + count: &'a T, + i: &'a [T], + min: &'a [T], + max: &'a [T], + ids: &'a [T], + min_q: &'a [T], + max_q: &'a [T], + overflow: &'a T, + ch: &'a [T], + ph: &'a [T], +} + +const NUM_PUBLIC_INPUTS: usize = QueryPublicInputs::PlaceholderHash as usize + 1; + +impl<'a, T: Clone, const S: usize> PublicInputs<'a, T, S> { + const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ + 
Self::to_range(QueryPublicInputs::TreeHash), + Self::to_range(QueryPublicInputs::OutputValues), + Self::to_range(QueryPublicInputs::NumMatching), + Self::to_range(QueryPublicInputs::OpIds), + Self::to_range(QueryPublicInputs::IndexValue), + Self::to_range(QueryPublicInputs::MinValue), + Self::to_range(QueryPublicInputs::MaxValue), + Self::to_range(QueryPublicInputs::IndexIds), + Self::to_range(QueryPublicInputs::MinQuery), + Self::to_range(QueryPublicInputs::MaxQuery), + Self::to_range(QueryPublicInputs::Overflow), + Self::to_range(QueryPublicInputs::ComputationalHash), + Self::to_range(QueryPublicInputs::PlaceholderHash), + ]; + + const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ + // Tree hash + NUM_HASH_OUT_ELTS, + // Output values + CURVE_TARGET_LEN + NUM_LIMBS * (S - 1), + // Number of matching records + 1, + // Operation identifiers + S, + // Index column value + NUM_LIMBS, + // Minimum indexed column value + NUM_LIMBS, + // Maximum indexed column value + NUM_LIMBS, + // Indexed column IDs + 2, + // Lower bound for indexed column specified in query + NUM_LIMBS, + // Upper bound for indexed column specified in query + NUM_LIMBS, + // Overflow flag + 1, + // Computational hash + NUM_HASH_OUT_ELTS, + // Placeholder hash + NUM_HASH_OUT_ELTS, + ]; + + pub const fn to_range(query_pi: QueryPublicInputs) -> PublicInputRange { + let mut i = 0; + let mut offset = 0; + let pi_pos = query_pi as usize; + while i < pi_pos { + offset += Self::SIZES[i]; + i += 1; + } + offset..offset + Self::SIZES[pi_pos] + } + + pub(crate) const fn total_len() -> usize { + Self::to_range(QueryPublicInputs::PlaceholderHash).end + } + + pub(crate) fn to_hash_raw(&self) -> &[T] { + self.h + } + + pub(crate) fn to_values_raw(&self) -> &[T] { + self.v + } + + pub(crate) fn to_count_raw(&self) -> &T { + self.count + } + + pub(crate) fn to_ops_raw(&self) -> &[T] { + self.ops + } + + pub(crate) fn to_index_value_raw(&self) -> &[T] { + self.i + } + + pub(crate) fn to_min_value_raw(&self) -> &[T] { + self.min + } + + pub(crate) fn to_max_value_raw(&self) -> &[T] { + self.max + } + + pub(crate) fn to_index_ids_raw(&self) -> &[T] { + self.ids + } + + pub(crate) fn to_min_query_raw(&self) -> &[T] { + self.min_q + } + + pub(crate) fn to_max_query_raw(&self) -> &[T] { + self.max_q + } + + pub(crate) fn to_overflow_raw(&self) -> &T { + self.overflow + } + + pub(crate) fn to_computational_hash_raw(&self) -> &[T] { + self.ch + } + + pub(crate) fn to_placeholder_hash_raw(&self) -> &[T] { + self.ph + } + + pub fn from_slice(input: &'a [T]) -> Self { + assert!( + input.len() >= Self::total_len(), + "input slice too short to build query public inputs, must be at least {} elements", + Self::total_len() + ); + Self { + h: &input[Self::PI_RANGES[0].clone()], + v: &input[Self::PI_RANGES[1].clone()], + count: &input[Self::PI_RANGES[2].clone()][0], + ops: &input[Self::PI_RANGES[3].clone()], + i: &input[Self::PI_RANGES[4].clone()], + min: &input[Self::PI_RANGES[5].clone()], + max: &input[Self::PI_RANGES[6].clone()], + ids: &input[Self::PI_RANGES[7].clone()], + min_q: &input[Self::PI_RANGES[8].clone()], + max_q: &input[Self::PI_RANGES[9].clone()], + overflow: &input[Self::PI_RANGES[10].clone()][0], + ch: &input[Self::PI_RANGES[11].clone()], + ph: &input[Self::PI_RANGES[12].clone()], + } + } + #[allow(clippy::too_many_arguments)] + pub fn new( + h: &'a [T], + v: &'a [T], + count: &'a [T], + ops: &'a [T], + i: &'a [T], + min: &'a [T], + max: &'a [T], + ids: &'a [T], + min_q: &'a [T], + max_q: &'a [T], + overflow: &'a [T], + ch: &'a [T], + ph: &'a [T], 
+ ) -> Self { + Self { + h, + v, + count: &count[0], + ops, + i, + min, + max, + ids, + min_q, + max_q, + overflow: &overflow[0], + ch, + ph, + } + } + + pub fn to_vec(&self) -> Vec<T> { + self.h + .iter() + .chain(self.v.iter()) + .chain(once(self.count)) + .chain(self.ops.iter()) + .chain(self.i.iter()) + .chain(self.min.iter()) + .chain(self.max.iter()) + .chain(self.ids.iter()) + .chain(self.min_q.iter()) + .chain(self.max_q.iter()) + .chain(once(self.overflow)) + .chain(self.ch.iter()) + .chain(self.ph.iter()) + .cloned() + .collect_vec() + } +} + +impl<const S: usize> PublicInputCommon for PublicInputs<'_, Target, S> { + const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; + + fn register_args(&self, cb: &mut CBuilder) { + cb.register_public_inputs(self.h); + cb.register_public_inputs(self.v); + cb.register_public_input(*self.count); + cb.register_public_inputs(self.ops); + cb.register_public_inputs(self.i); + cb.register_public_inputs(self.min); + cb.register_public_inputs(self.max); + cb.register_public_inputs(self.ids); + cb.register_public_inputs(self.min_q); + cb.register_public_inputs(self.max_q); + cb.register_public_input(*self.overflow); + cb.register_public_inputs(self.ch); + cb.register_public_inputs(self.ph); + } +} + +impl<const S: usize> PublicInputs<'_, Target, S> { + pub fn tree_hash_target(&self) -> HashOutTarget { + HashOutTarget::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } + /// Return the first output value as a `CurveTarget` + pub fn first_value_as_curve_target(&self) -> CurveTarget { + let targets = self.to_values_raw(); + CurveOrU256Target::from_targets(targets).as_curve_target() + } + + /// Return the first output value as a `UInt256Target` + pub fn first_value_as_u256_target(&self) -> UInt256Target { + let targets = self.to_values_raw(); + CurveOrU256Target::from_targets(targets).as_u256_target() + } + + /// Return the `UInt256` targets for the last `S-1` values + pub fn values_target(&self) -> [UInt256Target; S - 1] { + OutputValuesTarget::from_targets(self.to_values_raw()).other_outputs + } + + /// Return the value as a `UInt256Target` at the specified index + pub fn value_target_at_index(&self, i: usize) -> UInt256Target + where + [(); S - 1]:, + { + OutputValuesTarget::from_targets(self.to_values_raw()).value_target_at_index(i) + } + + pub fn num_matching_rows_target(&self) -> Target { + *self.to_count_raw() + } + + pub fn operation_ids_target(&self) -> [Target; S] { + self.to_ops_raw().try_into().unwrap() + } + + pub fn index_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_index_value_raw()) + } + + pub fn min_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_min_value_raw()) + } + + pub fn max_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_max_value_raw()) + } + + pub fn index_ids_target(&self) -> [Target; 2] { + self.to_index_ids_raw().try_into().unwrap() + } + + pub fn min_query_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_min_query_raw()) + } + + pub fn max_query_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.to_max_query_raw()) + } + + pub fn overflow_flag_target(&self) -> BoolTarget { + BoolTarget::new_unsafe(*self.to_overflow_raw()) + } + + pub fn computational_hash_target(&self) -> HashOutTarget { + HashOutTarget::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } + + pub fn placeholder_hash_target(&self) -> HashOutTarget { 
+ HashOutTarget::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } +} + +impl<const S: usize> PublicInputs<'_, F, S> +where + [(); S - 1]:, +{ + pub fn tree_hash(&self) -> HashOut<F> { + HashOut::try_from(self.to_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } + + pub fn first_value_as_curve_point(&self) -> WeierstrassPoint { + OutputValues::<S>::from_fields(self.to_values_raw()).first_value_as_curve_point() + } + + pub fn first_value_as_u256(&self) -> U256 { + OutputValues::<S>::from_fields(self.to_values_raw()).first_value_as_u256() + } + + pub fn values(&self) -> [U256; S - 1] { + OutputValues::<S>::from_fields(self.to_values_raw()).other_outputs + } + + /// Return the value as a UInt256 at the specified index + pub fn value_at_index(&self, i: usize) -> U256 + where + [(); S - 1]:, + { + OutputValues::<S>::from_fields(self.to_values_raw()).value_at_index(i) + } + + pub fn num_matching_rows(&self) -> F { + *self.to_count_raw() + } + + pub fn operation_ids(&self) -> [F; S] { + self.to_ops_raw().try_into().unwrap() + } + + pub fn index_value(&self) -> U256 { + U256::from_fields(self.to_index_value_raw()) + } + + pub fn min_value(&self) -> U256 { + U256::from_fields(self.to_min_value_raw()) + } + + pub fn max_value(&self) -> U256 { + U256::from_fields(self.to_max_value_raw()) + } + + pub fn index_ids(&self) -> [F; 2] { + self.to_index_ids_raw().try_into().unwrap() + } + + pub fn min_query_value(&self) -> U256 { + U256::from_fields(self.to_min_query_raw()) + } + + pub fn max_query_value(&self) -> U256 { + U256::from_fields(self.to_max_query_raw()) + } + + pub fn overflow_flag(&self) -> bool { + (*self.to_overflow_raw()) + .try_into_bool() + .expect("overflow flag public input different from 0 or 1") + } + + pub fn computational_hash(&self) -> HashOut<F> { + HashOut::try_from(self.to_computational_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } + + pub fn placeholder_hash(&self) -> HashOut<F> { + HashOut::try_from(self.to_placeholder_hash_raw()).unwrap() // safe to unwrap as we know the slice has correct length + } +} + +#[cfg(test)] +mod tests { + + use mp2_common::{public_inputs::PublicInputCommon, utils::ToFields, C, D, F}; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::random_vector, + }; + use plonky2::{ + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::circuit_builder::CircuitBuilder, + }; + + use super::QueryPublicInputs; + + use super::PublicInputs; + + const S: usize = 10; + #[derive(Clone, Debug)] + struct TestPublicInputs<'a> { + pis: &'a [F], + } + + impl UserCircuit<F, D> for TestPublicInputs<'_> { + type Wires = Vec<Target>; + + fn build(c: &mut CircuitBuilder<F, D>) -> Self::Wires { + let targets = c.add_virtual_target_arr::<{ PublicInputs::<Target, S>::total_len() }>(); + let pi_targets = PublicInputs::<Target, S>::from_slice(targets.as_slice()); + pi_targets.register_args(c); + pi_targets.to_vec() + } + + fn prove(&self, pw: &mut PartialWitness<F>, wires: &Self::Wires) { + pw.set_target_arr(wires, self.pis) + } + } + + #[test] + fn test_query_public_inputs() { + let pis_raw: Vec<F> = random_vector::<u32>(PublicInputs::<F, S>::total_len()).to_fields(); + let pis = PublicInputs::<F, S>::from_slice(pis_raw.as_slice()); + // check public inputs are constructed correctly + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::TreeHash)], + pis.to_hash_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::OutputValues)], + pis.to_values_raw(), + ); + 
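// Aside: `overflow_flag` above refuses to coerce a malformed field element into a
// boolean. A std-only analogue of that `TryIntoBool` check (hypothetical helper,
// shown only to illustrate the 0/1 validation):
fn _sketch_overflow_flag() {
    fn try_into_bool(v: u64) -> Option<bool> {
        match v {
            0 => Some(false),
            1 => Some(true),
            _ => None, // a malformed flag is rejected, not silently truncated
        }
    }
    assert_eq!(try_into_bool(1), Some(true));
    assert!(try_into_bool(2).is_none());
}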
assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::NumMatching)], + &[*pis.to_count_raw()], + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::OpIds)], + pis.to_ops_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::IndexValue)], + pis.to_index_value_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::MinValue)], + pis.to_min_value_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::MaxValue)], + pis.to_max_value_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::MinQuery)], + pis.to_min_query_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::MaxQuery)], + pis.to_max_query_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::IndexIds)], + pis.to_index_ids_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::Overflow)], + &[*pis.to_overflow_raw()], + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::ComputationalHash)], + pis.to_computational_hash_raw(), + ); + assert_eq!( + &pis_raw[PublicInputs::<F, S>::to_range(QueryPublicInputs::PlaceholderHash)], + pis.to_placeholder_hash_raw(), + ); + // use public inputs in circuit + let test_circuit = TestPublicInputs { pis: &pis_raw }; + let proof = run_circuit::<F, C, D, _>(test_circuit); + assert_eq!(proof.public_inputs, pis_raw); + } +} diff --git a/verifiable-db/src/revelation/api.rs b/verifiable-db/src/revelation/api.rs index 26dea52e3..2fd9ae453 100644 --- a/verifiable-db/src/revelation/api.rs +++ b/verifiable-db/src/revelation/api.rs @@ -12,7 +12,9 @@ use mp2_common::{ C, D, F, }; use plonky2::plonk::{ - circuit_data::VerifierOnlyCircuitData, config::Hasher, proof::ProofWithPublicInputs, + circuit_data::{VerifierCircuitData, VerifierOnlyCircuitData}, + config::Hasher, + proof::ProofWithPublicInputs, }; use recursion_framework::{ circuit_builder::{CircuitWithUniversalVerifier, CircuitWithUniversalVerifierBuilder}, @@ -24,33 +26,34 @@ use serde::{Deserialize, Serialize}; use crate::{ query::{ - self, - aggregation::QueryBounds, - api::{CircuitInput as QueryCircuitInput, Parameters as QueryParams}, computational_hash_ids::ColumnIDs, pi_len as query_pi_len, - universal_circuit::universal_circuit_inputs::{ - BasicOperation, Placeholders, ResultStructure, + universal_circuit::{ + output_no_aggregation::Circuit as OutputNoAggCircuit, + universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure}, + universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitParams}, }, + utils::QueryBounds, }, revelation::{ placeholders_check::CheckPlaceholderGadget, revelation_unproven_offset::{ generate_dummy_row_proof_inputs, RecursiveCircuitWires as RecursiveCircuitWiresUnprovenOffset, + TabularQueryOutputModifiers, }, }, }; use super::{ - num_query_io, pi_len, + pi_len, revelation_unproven_offset::{ - RecursiveCircuitInputs as RecursiveCircuitInputsUnporvenOffset, + CircuitBuilderParams, RecursiveCircuitInputs as RecursiveCircuitInputsUnporvenOffset, RevelationCircuit as RevelationCircuitUnprovenOffset, RowPath, }, revelation_without_results_tree::{ - CircuitBuilderParams, RecursiveCircuitInputs, RecursiveCircuitWires, - RevelationWithoutResultsTreeCircuit, + CircuitBuilderParams as CircuitBuilderParamsNoResultsTree, RecursiveCircuitInputs, + RecursiveCircuitWires, RevelationWithoutResultsTreeCircuit, }, }; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] @@ -150,7 +153,7 @@ pub 
struct Parameters< [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, - [(); num_query_io::()]:, + [(); query_pi_len::()]:, { revelation_no_results_tree: CircuitWithUniversalVerifier< F, @@ -221,7 +224,7 @@ pub enum CircuitInput< >, }, UnprovenOffset { - row_proofs: Vec, + row_proofs: Vec>, preprocessing_proof: ProofWithPublicInputs, revelation_circuit: RevelationCircuitUnprovenOffset< ROW_TREE_MAX_DEPTH, @@ -232,7 +235,7 @@ pub enum CircuitInput< { 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS) }, >, dummy_row_proof_input: Option< - QueryCircuitInput< + UniversalCircuitInput< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -267,7 +270,6 @@ where [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:, - [(); query_pi_len::()]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, { @@ -290,7 +292,7 @@ where ) -> Result { let query_proof = ProofWithVK::deserialize(&query_proof)?; let preprocessing_proof = deserialize_proof(&preprocessing_proof)?; - let placeholder_hash_ids = query::api::CircuitInput::< + let placeholder_hash_ids = UniversalCircuitInput::< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -376,10 +378,10 @@ where .map(|(i, row)| { row_paths[i] = row.path.clone(); result_values[i] = row.result.clone(); - ProofWithVK::deserialize(&row.proof) + deserialize_proof(&row.proof) }) .collect::>>()?; - let placeholder_hash_ids = query::api::CircuitInput::< + let placeholder_hash_ids = UniversalCircuitInput::< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -395,11 +397,14 @@ where let revelation_circuit = RevelationCircuitUnprovenOffset::new( row_paths, + [column_ids.primary, column_ids.secondary], &results_structure.output_ids, result_values, - limit, - offset, - results_structure.distinct.unwrap_or(false), + TabularQueryOutputModifiers::new( + limit, + offset, + results_structure.distinct.unwrap_or_default(), + ), placeholder_inputs, )?; @@ -435,18 +440,18 @@ impl< > where [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:, - [(); num_query_io::()]:, + [(); query_pi_len::()]:, [(); >::HASH_SIZE]:, [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, - [(); query_pi_len::()]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, [(); pi_len::()]:, { pub fn build( query_circuit_set: &RecursiveCircuits, + universal_circuit_vk: VerifierCircuitData, preprocessing_circuit_set: &RecursiveCircuits, preprocessing_vk: &VerifierOnlyCircuitData, ) -> Self { @@ -455,12 +460,17 @@ where D, { pi_len::() }, >::new::(default_config(), REVELATION_CIRCUIT_SET_SIZE); - let build_parameters = CircuitBuilderParams { + let build_parameters = CircuitBuilderParamsNoResultsTree { query_circuit_set: query_circuit_set.clone(), preprocessing_circuit_set: preprocessing_circuit_set.clone(), preprocessing_vk: preprocessing_vk.clone(), }; - let revelation_no_results_tree = builder.build_circuit(build_parameters.clone()); + let revelation_no_results_tree = builder.build_circuit(build_parameters); + let build_parameters = CircuitBuilderParams { + universal_query_vk: universal_circuit_vk, + preprocessing_circuit_set: preprocessing_circuit_set.clone(), + preprocessing_vk: preprocessing_vk.clone(), + }; let revelation_unproven_offset = builder.build_circuit(build_parameters); let 
circuits = vec![ @@ -491,11 +501,12 @@ where >, query_circuit_set: &RecursiveCircuits, query_params: Option< - &QueryParams< + &UniversalQueryCircuitParams< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, MAX_NUM_ITEMS_PER_OUTPUT, + OutputNoAggCircuit, >, >, ) -> Result> { @@ -528,8 +539,11 @@ where dummy_row_proof_input, } => { let row_proofs = if let Some(input) = dummy_row_proof_input { - let proof = query_params.unwrap().generate_proof(input)?; - let proof = ProofWithVK::deserialize(&proof)?; + let proof = if let UniversalCircuitInput::QueryNoAgg(input) = input { + query_params.unwrap().generate_proof(&input)? + } else { + unreachable!("Universal circuit should only be used for queries with no aggregation operations") + }; row_proofs .into_iter() .chain(repeat(proof)) @@ -567,9 +581,12 @@ where #[cfg(test)] mod tests { - use crate::test_utils::{ - TestRevelationData, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, - MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, + use crate::{ + query::pi_len as query_pi_len, + test_utils::{ + TestRevelationData, MAX_NUM_COLUMNS, MAX_NUM_ITEMS_PER_OUTPUT, MAX_NUM_OUTPUTS, + MAX_NUM_PLACEHOLDERS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, + }, }; use itertools::Itertools; use mp2_common::{ @@ -578,7 +595,7 @@ mod tests { types::HashOutput, C, D, F, }; - use mp2_test::log::init_logging; + use mp2_test::{circuit::TestDummyCircuit, log::init_logging}; use plonky2::{ field::types::PrimeField64, hash::hash_types::HashOut, plonk::config::GenericHashOut, }; @@ -588,18 +605,21 @@ mod tests { ivc::PublicInputs as PreprocessingPI, query::{ computational_hash_ids::{ColumnIDs, Identifiers}, - public_inputs::PublicInputs as QueryPI, + public_inputs::PublicInputsQueryCircuits as QueryPI, }, revelation::{ api::{CircuitInput, Parameters}, - num_query_io, - tests::compute_results_from_query_proof, + tests::compute_results_from_query_proof_outputs, PublicInputs, NUM_PREPROCESSING_IO, }, }; #[test] fn test_api() { + use mp2_common::utils::FromFields; + + use crate::query::universal_circuit::universal_query_gadget::OutputValues; + init_logging(); const ROW_TREE_MAX_DEPTH: usize = 10; @@ -609,10 +629,12 @@ mod tests { F, C, D, - { num_query_io::() }, + { query_pi_len::() }, >::default(); let preprocessing_circuits = TestingRecursiveCircuits::::default(); + let dummy_universal_circuit = + TestDummyCircuit::<{ query_pi_len::() }>::build(); println!("building params"); let params = Parameters::< ROW_TREE_MAX_DEPTH, @@ -625,6 +647,7 @@ mod tests { MAX_NUM_PLACEHOLDERS, >::build( query_circuits.get_recursive_circuit_set(), + dummy_universal_circuit.circuit_data().verifier_data(), preprocessing_circuits.get_recursive_circuit_set(), preprocessing_circuits .verifier_data_for_input_proofs::<1>() @@ -636,7 +659,6 @@ mod tests { let test_data = TestRevelationData::sample(42, 76); let query_pi = QueryPI::::from_slice(test_data.query_pi_raw()); - // generate query proof let [query_proof] = query_circuits .generate_input_proofs::<1>([test_data.query_pi_raw().try_into().unwrap()]) @@ -651,7 +673,6 @@ mod tests { .unwrap(); let preprocessing_pi = PreprocessingPI::from_slice(&preprocessing_proof.public_inputs); let preprocessing_proof = serialize_proof(&preprocessing_proof).unwrap(); - let input = CircuitInput::new_revelation_aggregated( query_proof, preprocessing_proof, @@ -683,10 +704,14 @@ mod tests { // check entry count assert_eq!(query_pi.num_matching_rows(), pi.entry_count(),); // check results and overflow - let (result, overflow) = 
compute_results_from_query_proof(&query_pi); + let result = compute_results_from_query_proof_outputs( + query_pi.num_matching_rows(), + OutputValues::::from_fields(query_pi.to_values_raw()), + &query_pi.operation_ids(), + ); assert_eq!(pi.num_results().to_canonical_u64(), 1,); assert_eq!(pi.result_values()[0], result,); - assert_eq!(pi.overflow_flag(), overflow,); + assert_eq!(pi.overflow_flag(), query_pi.overflow_flag(),); // check computational hash // first, compute the final computational hash let metadata_hash = HashOut::::from_partial(preprocessing_pi.metadata_hash()); diff --git a/verifiable-db/src/revelation/batching.rs b/verifiable-db/src/revelation/batching.rs new file mode 100644 index 000000000..2642caf7d --- /dev/null +++ b/verifiable-db/src/revelation/batching.rs @@ -0,0 +1,498 @@ +use mp2_common::{types::CBuilder, u256::CircuitBuilderU256, utils::FromTargets, F}; +use plonky2::iop::{target::Target, witness::PartialWitness}; +use serde::{Deserialize, Serialize}; + +use crate::{ + ivc::PublicInputs as OriginalTreePublicInputs, + query::{ + batching::public_inputs::PublicInputs as QueryProofPublicInputs, + universal_circuit::universal_query_gadget::OutputValuesTarget, + }, +}; + +use super::revelation_without_results_tree::{ + QueryProofInputWires, RevelationWithoutResultsTreeCircuit, RevelationWithoutResultsTreeWires, +}; + +impl<'a, const S: usize> From<&'a QueryProofPublicInputs<'a, Target, S>> for QueryProofInputWires +where + [(); S - 1]:, +{ + fn from(value: &'a QueryProofPublicInputs) -> Self { + Self { + tree_hash: value.tree_hash_target(), + results: OutputValuesTarget::from_targets(value.to_values_raw()), + entry_count: value.num_matching_rows_target(), + overflow: value.overflow_flag_target().target, + placeholder_hash: value.placeholder_hash_target(), + computational_hash: value.computational_hash_target(), + min_primary: value.min_primary_target(), + max_primary: value.max_primary_target(), + ops: value.operation_ids_target(), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RevelationCircuitBatching< + const L: usize, + const S: usize, + const PH: usize, + const PP: usize, +>(RevelationWithoutResultsTreeCircuit); + +impl + RevelationCircuitBatching +where + [(); S - 1]:, +{ + pub(crate) fn build( + b: &mut CBuilder, + query_proof: &QueryProofPublicInputs, + original_tree_proof: &OriginalTreePublicInputs, + ) -> RevelationWithoutResultsTreeWires { + let wires = RevelationWithoutResultsTreeCircuit::build_core( + b, + query_proof.into(), + original_tree_proof, + ); + // additional constraints on boundary rows to ensure completeness of proven rows + // (i.e., that we look at all the rows with primary and secondary index values in the query range) + + let left_boundary_row = query_proof.left_boundary_row_target(); + + // 1. Either the index tree node of left boundary row has no predecessor, or + // the value of the predecessor is smaller than MIN_primary + let smaller_than_min_primary = b.is_less_than_u256( + &left_boundary_row.index_node_info.predecessor_info.value, + &query_proof.min_primary_target(), + ); + // assert not pQ.left_boundary_row.index_node_data.predecessor_info.is_found or + // pQ.left_boundary_row.index_node_data.predecessor_value < pQ.MIN_primary + let constraint = b.and( + left_boundary_row.index_node_info.predecessor_info.is_found, + smaller_than_min_primary, + ); + b.connect( + left_boundary_row + .index_node_info + .predecessor_info + .is_found + .target, + constraint.target, + ); + + // 2. 
Either the rows tree node storing left boundary row has no predecessor, or + // the value of the predecessor is smaller than MIN_secondary + let smaller_than_min_secondary = b.is_less_than_u256( + &left_boundary_row.row_node_info.predecessor_info.value, + &query_proof.min_secondary_target(), + ); + // assert not pQ.left_boundary_row.row_node_data.predecessor_info.is_found or + // pQ.left_boundary_row.row_node_data.predecessor_value < pQ.MIN_secondary + let constraint = b.and( + left_boundary_row.row_node_info.predecessor_info.is_found, + smaller_than_min_secondary, + ); + b.connect( + left_boundary_row + .row_node_info + .predecessor_info + .is_found + .target, + constraint.target, + ); + + let right_boundary_row = query_proof.right_boundary_row_target(); + + // 3. Either the index tree node of right boundary row has no successor, or + // the value of the successor is greater than MAX_primary + let greater_than_max_primary = b.is_greater_than_u256( + &right_boundary_row.index_node_info.successor_info.value, + &query_proof.max_primary_target(), + ); + // assert not pQ.right_boundary_row.index_node_data.successor_info.is_found or + // pQ.right_boundary_row.index_node_data.successor_value > pQ.MAX_primary + let constraint = b.and( + right_boundary_row.index_node_info.successor_info.is_found, + greater_than_max_primary, + ); + b.connect( + right_boundary_row + .index_node_info + .successor_info + .is_found + .target, + constraint.target, + ); + + // 4. Either the rows tree node storing right boundary row has no successor, or + // the value of the successor is greater than MAX_secondary + let greater_than_max_secondary = b.is_greater_than_u256( + &right_boundary_row.row_node_info.successor_info.value, + &query_proof.max_secondary_target(), + ); + // assert not pQ.right_boundary_row.row_node_data.successor_info.is_found or + // pQ.right_boundary_row.row_node_data.successor_value > pQ.MAX_secondary + let constraint = b.and( + right_boundary_row.row_node_info.successor_info.is_found, + greater_than_max_secondary, + ); + b.connect( + right_boundary_row + .row_node_info + .successor_info + .is_found + .target, + constraint.target, + ); + wires + } + + pub(crate) fn assign( + &self, + pw: &mut PartialWitness, + wires: &RevelationWithoutResultsTreeWires, + ) { + self.0.assign(pw, wires) + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use alloy::primitives::U256; + use itertools::Itertools; + use mp2_common::{ + array::ToField, + poseidon::{flatten_poseidon_hash_value, H}, + types::CBuilder, + utils::{FromFields, ToFields}, + C, D, F, + }; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::{ + field::types::Field, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::config::Hasher, + }; + use rand::{seq::SliceRandom, thread_rng, Rng}; + + use crate::{ + ivc::PublicInputs as OriginalTreePublicInputs, + query::{ + aggregation::{QueryBoundSource, QueryBounds}, + batching::{ + public_inputs::{ + tests::gen_values_in_range, PublicInputs as QueryProofPublicInputs, + QueryPublicInputs, + }, + row_chunk::tests::BoundaryRowData, + }, + computational_hash_ids::AggregationOperation, + universal_circuit::{ + universal_circuit_inputs::Placeholders, universal_query_gadget::OutputValues, + }, + }, + revelation::{ + revelation_without_results_tree::{ + RevelationWithoutResultsTreeCircuit, RevelationWithoutResultsTreeWires, + }, + tests::{compute_results_from_query_proof_outputs, TestPlaceholders}, + PublicInputs, NUM_PREPROCESSING_IO, + }, + 
test_utils::{random_aggregation_operations, random_original_tree_proof}, + }; + + use super::RevelationCircuitBatching; + + // L: maximum number of results + // S: maximum number of items in each result + // PH: maximum number of unique placeholder IDs and values bound for query + // PP: maximum number of placeholders present in query (may be duplicate, PP >= PH) + const L: usize = 5; + const S: usize = 10; + const PH: usize = 10; + const PP: usize = 20; + + // Real number of the placeholders + const NUM_PLACEHOLDERS: usize = 6; + + const QUERY_PI_LEN: usize = QueryProofPublicInputs::::total_len(); + + #[derive(Clone, Debug)] + struct TestRevelationBatchingCircuit<'a> { + c: RevelationWithoutResultsTreeCircuit, + query_proof: &'a [F], + original_tree_proof: &'a [F], + } + + impl UserCircuit for TestRevelationBatchingCircuit<'_> { + // Circuit wires + query proof + original tree proof (IVC proof) + type Wires = ( + RevelationWithoutResultsTreeWires, + Vec, + Vec, + ); + + fn build(b: &mut CBuilder) -> Self::Wires { + let query_proof = b.add_virtual_target_arr::().to_vec(); + let original_tree_proof = b.add_virtual_target_arr::().to_vec(); + + let query_pi = QueryProofPublicInputs::from_slice(&query_proof); + let original_tree_pi = OriginalTreePublicInputs::from_slice(&original_tree_proof); + + let wires = RevelationCircuitBatching::build(b, &query_pi, &original_tree_pi); + + (wires, query_proof, original_tree_proof) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.c.assign(pw, &wires.0); + pw.set_target_arr(&wires.1, self.query_proof); + pw.set_target_arr(&wires.2, self.original_tree_proof); + } + } + + /// Generate a random query proof. + fn random_query_proof( + entry_count: u32, + ops: &[F; S], + test_placeholders: &TestPlaceholders, + ) -> Vec { + let [mut proof] = QueryProofPublicInputs::sample_from_ops(ops); + + let [count_range, min_query_primary, max_query_primary, min_query_secondary, max_query_secondary, p_hash_range, left_row_range, right_row_range] = + [ + QueryPublicInputs::NumMatching, + QueryPublicInputs::MinPrimary, + QueryPublicInputs::MaxPrimary, + QueryPublicInputs::MinSecondary, + QueryPublicInputs::MaxSecondary, + QueryPublicInputs::PlaceholderHash, + QueryPublicInputs::LeftBoundaryRow, + QueryPublicInputs::RightBoundaryRow, + ] + .map(QueryProofPublicInputs::::to_range); + + // Set the count, minimum, maximum query and the placeholder hash. 
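// Aside: the four completeness constraints built in `RevelationCircuitBatching::build`
// above all share one shape: either the neighbouring node is absent, or its value lies
// strictly outside the query range. A plain-boolean sketch of constraint 1 (hypothetical
// helper with u64 stand-ins for U256):
fn _sketch_completeness() {
    fn left_boundary_ok(predecessor: Option<u64>, min_primary: u64) -> bool {
        match predecessor {
            None => true,               // no predecessor at all
            Some(v) => v < min_primary, // predecessor strictly below the range
        }
    }
    assert!(left_boundary_ok(None, 10));
    assert!(left_boundary_ok(Some(9), 10));
    // an in-range predecessor means a matching row was skipped, so it must be rejected
    assert!(!left_boundary_ok(Some(10), 10));
}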
+ [ + (count_range, vec![entry_count.to_field()]), + (min_query_primary, test_placeholders.min_query.to_fields()), + (max_query_primary, test_placeholders.max_query.to_fields()), + ( + p_hash_range, + test_placeholders.query_placeholder_hash.to_fields(), + ), + ] + .into_iter() + .for_each(|(range, fields)| proof[range].copy_from_slice(&fields)); + + // Set boundary rows to satisfy constraints for completeness + let rng = &mut thread_rng(); + let min_secondary = U256::from_fields(&proof[min_query_secondary]); + let max_secondary = U256::from_fields(&proof[max_query_secondary]); + let placeholders = + Placeholders::new_empty(test_placeholders.min_query, test_placeholders.max_query); + let query_bounds = QueryBounds::new( + &placeholders, + Some(QueryBoundSource::Constant(min_secondary)), + Some(QueryBoundSource::Constant(max_secondary)), + ) + .unwrap(); + let mut left_boundary_row = BoundaryRowData::sample(rng, &query_bounds); + // for the predecessor of `left_boundary_row` in the index tree, we need to either mark it as + // non-existent or make its value out of range + if rng.gen() || query_bounds.min_query_primary() == U256::ZERO { + left_boundary_row.index_node_info.predecessor_info.is_found = false; + } else { + let [predecessor_value] = gen_values_in_range( + rng, + U256::ZERO, + query_bounds.min_query_primary() - U256::from(1), + ); + left_boundary_row.index_node_info.predecessor_info.value = predecessor_value; + } + // for the predecessor of `left_boundary_row` in the rows tree, we need to either mark it as + // non-existent or make its value out of range + if rng.gen() || min_secondary == U256::ZERO { + left_boundary_row.row_node_info.predecessor_info.is_found = false; + } else { + let [predecessor_value] = + gen_values_in_range(rng, U256::ZERO, min_secondary - U256::from(1)); + left_boundary_row.row_node_info.predecessor_info.value = predecessor_value; + } + let mut right_boundary_row = BoundaryRowData::sample(rng, &query_bounds); + // for the successor of `right_boundary_row` in the index tree, we need to either mark it as + // non-existent or make its value out of range + if rng.gen() || query_bounds.max_query_primary() == U256::MAX { + right_boundary_row.index_node_info.successor_info.is_found = false; + } else { + let [successor_value] = gen_values_in_range( + rng, + query_bounds.max_query_primary() + U256::from(1), + U256::MAX, + ); + right_boundary_row.index_node_info.successor_info.value = successor_value; + } + // for the successor of `right_boundary_row` in the rows tree, we need to either mark it as + // non-existent or make its value out of range + if rng.gen() || max_secondary == U256::MAX { + right_boundary_row.row_node_info.successor_info.is_found = false; + } else { + let [successor_value] = + gen_values_in_range(rng, max_secondary + U256::from(1), U256::MAX); + right_boundary_row.row_node_info.successor_info.value = successor_value; + } + + proof[left_row_range].copy_from_slice(&left_boundary_row.to_fields()); + proof[right_row_range].copy_from_slice(&right_boundary_row.to_fields()); + + proof + } + + /// Utility function for testing the batching revelation circuit (without results tree) + fn test_revelation_batching_circuit(ops: &[F; S], entry_count: Option<u32>) { + let rng = &mut thread_rng(); + + // Generate the testing placeholder data. + let test_placeholders = TestPlaceholders::sample(NUM_PLACEHOLDERS); + + // Generate the query proof.
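The sampling above must special-case degenerate bounds: there is no `U256` strictly below `U256::ZERO` or strictly above `U256::MAX`, so in those cases the predecessor (resp. successor) has to be marked as absent rather than given an out-of-range value. A hedged sketch of the same pattern, with `u128` standing in for `U256` and a hypothetical `sample_predecessor` helper:

```rust
use rand::Rng;

// Either drop the predecessor or sample it strictly below `min_bound`.
// When `min_bound == 0` the interval [0, min_bound - 1] is empty, so the
// only way to satisfy the completeness constraint is "no predecessor".
fn sample_predecessor(rng: &mut impl Rng, min_bound: u128) -> Option<u128> {
    if min_bound == 0 || rng.gen() {
        None
    } else {
        Some(rng.gen_range(0..min_bound))
    }
}
```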
+ let entry_count = entry_count.unwrap_or_else(|| rng.gen()); + let query_proof = random_query_proof(entry_count, ops, &test_placeholders); + let query_pi = QueryProofPublicInputs::<_, S>::from_slice(&query_proof); + + // Generate the original tree proof (IVC proof). + let original_tree_proof = random_original_tree_proof(query_pi.tree_hash()); + let original_tree_pi = OriginalTreePublicInputs::from_slice(&original_tree_proof); + + // Construct the test circuit. + let test_circuit = TestRevelationBatchingCircuit { + c: (&test_placeholders).into(), + query_proof: &query_proof, + original_tree_proof: &original_tree_proof, + }; + + // Prove for the test circuit. + let proof = run_circuit::(test_circuit); + let pi = PublicInputs::<_, L, S, PH>::from_slice(&proof.public_inputs); + + let entry_count = query_pi.num_matching_rows(); + + // Check the public inputs. + // Original block hash + assert_eq!( + pi.original_block_hash(), + original_tree_pi.block_hash_fields() + ); + // Computational hash + { + // H(pQ.C || placeholder_ids_hash || pQ.M) + let inputs = query_pi + .to_computational_hash_raw() + .iter() + .chain(&test_placeholders.placeholder_ids_hash.to_fields()) + .chain(original_tree_pi.metadata_hash()) + .cloned() + .collect_vec(); + let exp_hash = H::hash_no_pad(&inputs); + + assert_eq!( + pi.flat_computational_hash(), + flatten_poseidon_hash_value(exp_hash), + ); + } + // Number of placeholders + assert_eq!( + pi.num_placeholders(), + test_placeholders + .check_placeholder_inputs + .num_placeholders + .to_field() + ); + // Placeholder values + assert_eq!( + pi.placeholder_values(), + test_placeholders + .check_placeholder_inputs + .placeholder_values + ); + // Entry count + assert_eq!(pi.entry_count(), entry_count); + // check results + let result = compute_results_from_query_proof_outputs( + query_pi.num_matching_rows(), + OutputValues::::from_fields(query_pi.to_values_raw()), + &query_pi.operation_ids(), + ); + let mut exp_results = [[U256::ZERO; S]; L]; + exp_results[0] = result; + assert_eq!(pi.result_values(), exp_results); + // overflow flag + assert_eq!(pi.overflow_flag(), query_pi.overflow_flag()); + // Query limit + assert_eq!(pi.query_limit(), F::ZERO); + // Query offset + assert_eq!(pi.query_offset(), F::ZERO); + } + + #[test] + fn test_revelation_batching_simple() { + // Generate the random operations and set the first operation to SUM + // (not ID which should not be present in the aggregation). + let mut ops: [_; S] = random_aggregation_operations(); + ops[0] = AggregationOperation::SumOp.to_field(); + + test_revelation_batching_circuit(&ops, None); + } + + // Test for COUNT operation. + #[test] + fn test_revelation_batching_for_op_count() { + // Set the first operation to COUNT. + let mut ops: [_; S] = random_aggregation_operations(); + ops[0] = AggregationOperation::CountOp.to_field(); + + test_revelation_batching_circuit(&ops, None); + } + + // Test for AVG operation. + #[test] + fn test_revelation_batching_for_op_avg() { + // Set the first operation to AVG. + let mut ops: [_; S] = random_aggregation_operations(); + ops[0] = AggregationOperation::AvgOp.to_field(); + + test_revelation_batching_circuit(&ops, None); + } + + // Test for AVG operation with zero entry count. + #[test] + fn test_revelation_batching_for_op_avg_with_no_entries() { + // Set the first operation to AVG. 
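The expected results checked above follow the same per-slot rule the circuit applies: an AVG slot divides the proven value by the entry count and yields zero when the count is zero (matching the circuit's zero-divisor convention), a COUNT slot is replaced by the entry count, and any other slot passes the value through. A compact stand-in, with `u64` in place of `U256` and a hypothetical `Op` enum:

```rust
#[derive(Clone, Copy, PartialEq)]
enum Op {
    Sum,
    Avg,
    Count,
}

// Mirrors the per-slot rule of compute_results_from_query_proof_outputs.
fn expected_result(op: Op, value: u64, entry_count: u64) -> u64 {
    match op {
        Op::Avg => value.checked_div(entry_count).unwrap_or(0), // AVG with 0 entries -> 0
        Op::Count => entry_count,
        Op::Sum => value,
    }
}
```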
+ let mut ops: [_; S] = random_aggregation_operations(); + ops[0] = AggregationOperation::AvgOp.to_field(); + + test_revelation_batching_circuit(&ops, Some(0)); + } + + // Test for no AVG operation with zero entry count. + #[test] + fn test_revelation_batching_for_no_op_avg_with_no_entries() { + // Initialize all the operations to SUM or COUNT (not AVG). + let mut rng = thread_rng(); + let ops = array::from_fn(|_| { + [AggregationOperation::SumOp, AggregationOperation::CountOp] + .choose(&mut rng) + .unwrap() + .to_field() + }); + + test_revelation_batching_circuit(&ops, Some(0)); + } +} diff --git a/verifiable-db/src/revelation/mod.rs b/verifiable-db/src/revelation/mod.rs index eaa4045ff..22d5dcb46 100644 --- a/verifiable-db/src/revelation/mod.rs +++ b/verifiable-db/src/revelation/mod.rs @@ -1,6 +1,6 @@ //! Module including the revelation circuits for query -use crate::{ivc::NUM_IO, query::pi_len as query_pi_len}; +use crate::ivc::NUM_IO; use mp2_common::F; pub mod api; @@ -20,15 +20,12 @@ pub const fn pi_len() -> usize } pub const NUM_PREPROCESSING_IO: usize = NUM_IO; -pub const fn num_query_io() -> usize { - query_pi_len::() -} #[cfg(test)] pub(crate) mod tests { use super::*; use crate::query::{ computational_hash_ids::{AggregationOperation, PlaceholderIdentifier}, - public_inputs::PublicInputs as QueryProofPublicInputs, + universal_circuit::universal_query_gadget::OutputValues, }; use alloy::primitives::U256; use itertools::Itertools; @@ -144,10 +141,7 @@ pub(crate) mod tests { // Re-compute the placeholder hash from placeholder_pairs and minmum, // maximum query bounds. Then check it should be same with the specified // final placeholder hash. - let (min_i1, max_i1) = ( - check_placeholder_inputs.placeholder_values[0], - check_placeholder_inputs.placeholder_values[1], - ); + let (min_i1, max_i1) = check_placeholder_inputs.primary_query_bounds(); let placeholder_hash = H::hash_no_pad(&placeholder_hash_payload); // query_placeholder_hash = H(placeholder_hash || min_i2 || max_i2) let inputs = placeholder_hash @@ -186,23 +180,24 @@ pub(crate) mod tests { } } - /// Compute the query results from the proof, and it returns the results and overflow flag. - pub(crate) fn compute_results_from_query_proof( - query_pi: &QueryProofPublicInputs, - ) -> ([U256; S], bool) + /// Compute the query results from the query proof outputs, returning the results. + pub(crate) fn compute_results_from_query_proof_outputs( + entry_count: F, + output_values: OutputValues, + ops: &[F; S], + ) -> [U256; S] + where [(); S - 1]:, { // Convert the entry count to an Uint256. - let entry_count = U256::from(query_pi.num_matching_rows().to_canonical_u64()); + let entry_count = U256::from(entry_count.to_canonical_u64()); let [op_avg, op_count] = [AggregationOperation::AvgOp, AggregationOperation::CountOp].map(|op| op.to_field()); // Compute the results array, and deal with AVG and COUNT operations if any. - let ops = query_pi.operation_ids(); - let result = array::from_fn(|i| { - let value = query_pi.value_at_index(i); + array::from_fn(|i| { + let value = output_values.value_at_index(i); let op = ops[i]; if op == op_avg { @@ -212,8 +207,6 @@ } else { value } - }); - - (result, query_pi.overflow_flag()) + }) } } diff --git a/verifiable-db/src/revelation/placeholders_check.rs b/verifiable-db/src/revelation/placeholders_check.rs index 84c85fcfa..928ab3a68 100644 --- a/verifiable-db/src/revelation/placeholders_check.rs +++ b/verifiable-db/src/revelation/placeholders_check.rs @@ -2,12 +2,12 @@ //!
compute and return the `num_placeholders` and the `placeholder_ids_hash`. use crate::query::{ - aggregation::QueryBounds, computational_hash_ids::PlaceholderIdentifier, universal_circuit::{ universal_circuit_inputs::{PlaceholderId, Placeholders}, - universal_query_circuit::QueryBound, + universal_query_gadget::QueryBound, }, + utils::QueryBounds, }; use alloy::primitives::U256; use anyhow::{ensure, Result}; @@ -20,7 +20,7 @@ use mp2_common::{ }, types::CBuilder, u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{FromFields, SelectHashBuilder, ToFields, ToTargets}, + utils::{FromFields, HashBuilder, ToFields, ToTargets}, F, }; use plonky2::{ @@ -206,7 +206,7 @@ impl CheckPlaceholderGadget { }; let to_be_checked_placeholders = placeholder_hash_ids .into_iter() - .map(&compute_checked_placeholder_for_id) + .map(compute_checked_placeholder_for_id) .collect::>>()?; // compute placeholders data to be hashed for secondary query bounds let min_query_secondary = @@ -302,6 +302,11 @@ impl CheckPlaceholderGadget { .zip(&self.secondary_query_bound_placeholders) .for_each(|(t, v)| v.assign(pw, t)); } + // Return the query bounds on the primary index, taken from the placeholder values + #[cfg(test)] // used only in test for now + pub(crate) fn primary_query_bounds(&self) -> (U256, U256) { + (self.placeholder_values[0], self.placeholder_values[1]) + } } /// This gadget checks that the placeholders identifiers and values employed to diff --git a/verifiable-db/src/revelation/revelation_unproven_offset.rs b/verifiable-db/src/revelation/revelation_unproven_offset.rs index 83207862c..aac6d69aa 100644 --- a/verifiable-db/src/revelation/revelation_unproven_offset.rs +++ b/verifiable-db/src/revelation/revelation_unproven_offset.rs @@ -16,7 +16,7 @@ use mp2_common::{ default_config, group_hashing::CircuitBuilderGroupHashing, poseidon::{flatten_poseidon_hash_target, H}, - proof::ProofWithVK, + proof::verify_proof_fixed_circuit, public_inputs::PublicInputCommon, serialization::{ deserialize, deserialize_array, deserialize_long_array, serialize, serialize_array, @@ -24,17 +24,17 @@ use mp2_common::{ }, types::{CBuilder, HashOutput}, u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{Fieldable, SelectHashBuilder, ToTargets}, + utils::{Fieldable, HashBuilder, ToTargets}, C, D, F, }; use plonky2::{ - field::types::PrimeField64, hash::hash_types::HashOutTarget, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, plonk::{ + circuit_data::{VerifierCircuitData, VerifierOnlyCircuitData}, config::Hasher, proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}, }, @@ -42,35 +42,82 @@ use plonky2::{ use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; use recursion_framework::{ circuit_builder::CircuitLogicWires, - framework::{ - RecursiveCircuits, RecursiveCircuitsVerifierGagdet, RecursiveCircuitsVerifierTarget, - }, + framework::{RecursiveCircuits, RecursiveCircuitsVerifierGagdet}, }; use serde::{Deserialize, Serialize}; use crate::{ ivc::PublicInputs as OriginalTreePublicInputs, query::{ - aggregation::{ChildPosition, NodeInfo, QueryBounds, QueryHashNonExistenceCircuits}, - api::CircuitInput as QueryCircuitInput, - computational_hash_ids::{AggregationOperation, ColumnIDs, ResultIdentifier}, + computational_hash_ids::{ColumnIDs, ResultIdentifier}, merkle_path::{MerklePathGadget, MerklePathTargetInputs}, - pi_len, - public_inputs::PublicInputs as QueryProofPublicInputs, + public_inputs::PublicInputsUniversalCircuit as QueryProofPublicInputs, 
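The `.map(&compute_checked_placeholder_for_id)` to `.map(compute_checked_placeholder_for_id)` change above is purely stylistic: `&F` implements the `Fn*` traits whenever `F` does, so both forms compile, and passing the closure by value is simply the idiomatic spelling. A two-line demonstration (not from the patch):

```rust
fn main() {
    let double = |x: i32| x * 2;
    let by_ref: Vec<i32> = (1..=3).map(&double).collect(); // works: &F: Fn when F: Fn
    let by_val: Vec<i32> = (1..=3).map(double).collect(); // idiomatic
    assert_eq!(by_ref, by_val);
}
```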
universal_circuit::{ build_cells_tree, - universal_circuit_inputs::{BasicOperation, Placeholders, ResultStructure}, + universal_circuit_inputs::{ + BasicOperation, ColumnCell, Placeholders, ResultStructure, RowCells, + }, + universal_query_circuit::{UniversalCircuitInput, UniversalQueryCircuitInputs}, }, + utils::{ChildPosition, NodeInfo, QueryBounds}, }, }; use super::{ - num_query_io, pi_len as revelation_pi_len, + pi_len as revelation_pi_len, placeholders_check::{CheckPlaceholderGadget, CheckPlaceholderInputWires}, - revelation_without_results_tree::CircuitBuilderParams, PublicInputs, NUM_PREPROCESSING_IO, }; +#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)] +/// Data structure containing all the information needed to verify the membership of +/// a row in a tree representing a table +pub struct RowPath { + /// Info about the node of the row tree storing the row + pub(crate) row_node_info: NodeInfo, + /// Info about the nodes in the path of the rows tree for the node storing the row; The `ChildPosition` refers to + /// the position of the previous node in the path as a child of the current node + pub(crate) row_tree_path: Vec<(NodeInfo, ChildPosition)>, + /// Hash of the siblings of the node in the rows tree path (except for the root) + pub(crate) row_path_siblings: Vec>, + /// Info about the node of the index tree storing the rows tree containing the row + pub(crate) index_node_info: NodeInfo, + /// Info about the nodes in the path of the index tree for the index_node; The `ChildPosition` refers to + /// the position of the previous node in the path as a child of the current node + pub(crate) index_tree_path: Vec<(NodeInfo, ChildPosition)>, + /// Hash of the siblings of the nodes in the index tree path (except for the root) + pub(crate) index_path_siblings: Vec>, +} + +impl RowPath { + /// Instantiate a new instance of `RowPath` for a given proven row from the following input data: + /// - `row_node_info`: data about the node of the row tree storing the row + /// - `row_tree_path`: data about the nodes in the path of the rows tree for the node storing the row; + /// The `ChildPosition` refers to the position of the previous node in the path as a child of the current node + /// - `row_path_siblings`: hash of the siblings of the node in the rows tree path (except for the root) + /// - `index_node_info`: data about the node of the index tree storing the rows tree containing the row + /// - `index_tree_path`: data about the nodes in the path of the index tree for the index_node; + /// The `ChildPosition` refers to the position of the previous node in the path as a child of the current node + /// - `index_path_siblings`: hash of the siblings of the nodes in the index tree path (except for the root) + pub fn new( + row_node_info: NodeInfo, + row_tree_path: Vec<(NodeInfo, ChildPosition)>, + row_path_siblings: Vec>, + index_node_info: NodeInfo, + index_tree_path: Vec<(NodeInfo, ChildPosition)>, + index_path_siblings: Vec>, + ) -> Self { + Self { + row_node_info, + row_tree_path, + row_path_siblings, + index_node_info, + index_tree_path, + index_path_siblings, + } + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] /// Target for all the information about nodes in the path needed by this revelation circuit struct NodeInfoTarget { @@ -105,6 +152,26 @@ impl NodeInfoTarget { } } +/// Data structure containing the parameters found in tabular +/// queries that specify which outputs should be returned +#[derive(Clone, Debug)] +pub(crate) struct TabularQueryOutputModifiers 
{ + limit: u32, + offset: u32, + /// Boolean flag specifying whether DISTINCT keyword must be applied to results + distinct: bool, +} + +impl TabularQueryOutputModifiers { + pub(crate) fn new(limit: u32, offset: u32, distinct: bool) -> Self { + Self { + limit, + offset, + distinct, + } + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub(crate) struct RevelationWires< const ROW_TREE_MAX_DEPTH: usize, @@ -143,6 +210,7 @@ pub(crate) struct RevelationWires< deserialize_with = "deserialize_array" )] is_row_node_leaf: [BoolTarget; L], + index_column_ids: [Target; 2], #[serde( serialize_with = "serialize_array", deserialize_with = "deserialize_array" @@ -203,6 +271,8 @@ pub struct RevelationCircuit< /// Info about the nodes of the index tree that stores the rows trees where each of /// the L rows being proven are located index_node_info: [NodeInfo; L], + /// Identifiers of the indexed columns + index_column_ids: [F; 2], /// Actual number of items per-row included in the results. num_actual_items_per_row: usize, /// Ids of the output items included in the results for each row @@ -226,55 +296,6 @@ pub struct RevelationCircuit< check_placeholder_inputs: CheckPlaceholderGadget, } -#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)] -/// Data structure containing all the information needed to verify the membership of -/// a row in a tree representing a table -pub struct RowPath { - /// Info about the node of the row tree storing the row - row_node_info: NodeInfo, - /// Info about the nodes in the path of the rows tree for the node storing the row; The `ChildPosition` refers to - /// the position of the previous node in the path as a child of the current node - row_tree_path: Vec<(NodeInfo, ChildPosition)>, - /// Hash of the siblings of the node in the rows tree path (except for the root) - row_path_siblings: Vec>, - /// Info about the node of the index tree storing the rows tree containing the row - index_node_info: NodeInfo, - /// Info about the nodes in the path of the index tree for the index_node; The `ChildPosition` refers to - /// the position of the previous node in the path as a child of the current node - index_tree_path: Vec<(NodeInfo, ChildPosition)>, - /// Hash of the siblings of the nodes in the index tree path (except for the root) - index_path_siblings: Vec>, -} - -impl RowPath { - /// Instantiate a new instance of `RowPath` for a given proven row from the following input data: - /// - `row_node_info`: data about the node of the row tree storing the row - /// - `row_tree_path`: data about the nodes in the path of the rows tree for the node storing the row; - /// The `ChildPosition` refers to the position of the previous node in the path as a child of the current node - /// - `row_path_siblings`: hash of the siblings of the node in the rows tree path (except for the root) - /// - `index_node_info`: data about the node of the index tree storing the rows tree containing the row - /// - `index_tree_path`: data about the nodes in the path of the index tree for the index_node; - /// The `ChildPosition` refers to the position of the previous node in the path as a child of the current node - /// - `index_path_siblings`: hash of the siblings of the nodes in the index tree path (except for the root) - pub fn new( - row_node_info: NodeInfo, - row_tree_path: Vec<(NodeInfo, ChildPosition)>, - row_path_siblings: Vec>, - index_node_info: NodeInfo, - index_tree_path: Vec<(NodeInfo, ChildPosition)>, - index_path_siblings: Vec>, - ) -> Self { - Self { - row_node_info, - 
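Bundling `limit`, `offset`, and `distinct` into `TabularQueryOutputModifiers` (defined above) removes three loose positional parameters from `RevelationCircuit::new`, two of which shared a type and were easy to swap silently. The call-site shape, as used later in the tests (no LIMIT, no OFFSET, no DISTINCT):

```rust
// Replaces the old positional `0, 0, false` arguments at the call site.
let modifiers = TabularQueryOutputModifiers::new(0, 0, false);
```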
row_tree_path, - row_path_siblings, - index_node_info, - index_tree_path, - index_path_siblings, - } - } -} - impl< const ROW_TREE_MAX_DEPTH: usize, const INDEX_TREE_MAX_DEPTH: usize, @@ -290,11 +311,10 @@ where { pub(crate) fn new( row_paths: [RowPath; L], + index_column_ids: [F; 2], item_ids: &[F], results: [Vec; L], - limit: u32, - offset: u32, - distinct: bool, + query_modifiers: TabularQueryOutputModifiers, placeholder_inputs: CheckPlaceholderGadget, ) -> Result { let mut row_tree_paths = [MerklePathGadget::::default(); L]; @@ -338,12 +358,13 @@ where index_tree_paths, row_node_info, index_node_info, + index_column_ids, num_actual_items_per_row, ids: padded_ids.try_into().unwrap(), results: results.try_into().unwrap(), - limit, - offset, - distinct, + limit: query_modifiers.limit, + offset: query_modifiers.offset, + distinct: query_modifiers.distinct, check_placeholder_inputs: placeholder_inputs, }) } @@ -366,8 +387,12 @@ where // computed by the universal query circuit // closure to access the output items of the i-th result let get_result = |i| &results[S * i..S * (i + 1)]; - let [min_query, max_query] = b.add_virtual_u256_arr_unsafe(); // unsafe should be ok since they are later included in placeholder hash + let (min_query_primary, max_query_primary) = ( + row_proofs[0].min_primary_target(), + row_proofs[0].max_primary_target(), + ); let [limit, offset] = b.add_virtual_target_arr(); + let index_column_ids = b.add_virtual_target_arr(); let tree_hash = original_tree_proof.merkle_hash(); let zero = b.zero(); let one = b.one(); @@ -385,7 +410,6 @@ where // this is a requirement to ensure that the check for DISTINCT is sound let mut only_matching_rows = _true; row_proofs.iter().enumerate().for_each(|(i, row_proof)| { - let index_ids = row_proof.index_ids_target(); let is_matching_row = b.is_equal(row_proof.num_matching_rows_target(), one); // ensure that once `is_matching_row = false`, then it will be false for all // subsequent iterations @@ -401,8 +425,8 @@ where .flat_map(|hash| hash.to_targets()) .chain(row_node_info[i].node_min.to_targets()) .chain(row_node_info[i].node_max.to_targets()) - .chain(once(index_ids[1])) - .chain(row_proof.min_value_target().to_targets()) + .chain(once(index_column_ids[1])) + .chain(row_proof.secondary_index_value_target().to_targets()) .chain(row_proof.tree_hash_target().to_targets()) .collect_vec(); let row_node_hash = b.hash_n_to_hash_no_pad::(inputs); @@ -412,7 +436,7 @@ where &row_node_hash, ) }; - let row_path_wires = MerklePathGadget::build(b, row_node_hash, index_ids[1]); + let row_path_wires = MerklePathGadget::build(b, row_node_hash, index_column_ids[1]); let row_tree_root = row_path_wires.root; // compute hash of the index node storing the rows tree containing the current row let index_node_hash = { @@ -422,13 +446,13 @@ where .flat_map(|hash| hash.to_targets()) .chain(index_node_info[i].node_min.to_targets()) .chain(index_node_info[i].node_max.to_targets()) - .chain(once(index_ids[0])) - .chain(row_proof.index_value_target().to_targets()) + .chain(once(index_column_ids[0])) + .chain(row_proof.primary_index_value_target().to_targets()) .chain(row_tree_root.to_targets()) .collect_vec(); b.hash_n_to_hash_no_pad::(inputs) }; - let index_path_wires = MerklePathGadget::build(b, index_node_hash, index_ids[0]); + let index_path_wires = MerklePathGadget::build(b, index_node_hash, index_column_ids[0]); // if the current row is valid, check that the root is the same of the original tree, completing // membership proof for the current row; otherwise, 
we don't care let root = b.select_hash(is_matching_row, &index_path_wires.root, &tree_hash); @@ -436,14 +460,6 @@ where row_paths.push(row_path_wires.inputs); index_paths.push(index_path_wires.inputs); - // check that the primary index value for the current row is within the query - // bounds (only if the row is valid) - let index_value = row_proof.index_value_target(); - let greater_than_min = b.is_less_or_equal_than_u256(&min_query, &index_value); - let smaller_than_max = b.is_less_or_equal_than_u256(&index_value, &max_query); - let in_range = b.and(greater_than_min, smaller_than_max); - let in_range = b.and(is_matching_row, in_range); - b.connect(in_range.target, is_matching_row.target); // enforce DISTINCT only for actual results: we enforce the i-th actual result is strictly smaller // than the (i+1)-th actual result @@ -489,6 +505,9 @@ where // the proofs b.connect_hashes(row_proof.computational_hash_target(), computational_hash); b.connect_hashes(row_proof.placeholder_hash_target(), placeholder_hash); + // check that query bounds on primary index are the same for all the proofs + b.enforce_equal_u256(&row_proof.min_primary_target(), &min_query_primary); + b.enforce_equal_u256(&row_proof.max_primary_target(), &max_query_primary); overflow = b.or(overflow, row_proof.overflow_flag_target()); }); @@ -499,19 +518,19 @@ where let inputs = placeholder_hash .to_targets() .into_iter() - .chain(min_query.to_targets()) - .chain(max_query.to_targets()) + .chain(min_query_primary.to_targets()) + .chain(max_query_primary.to_targets()) .collect_vec(); b.hash_n_to_hash_no_pad::(inputs) }; let check_placeholder_wires = CheckPlaceholderGadget::build(b, &final_placeholder_hash); b.enforce_equal_u256( - &min_query, + &min_query_primary, &check_placeholder_wires.input_wires.placeholder_values[0], ); b.enforce_equal_u256( - &max_query, + &max_query_primary, &check_placeholder_wires.input_wires.placeholder_values[1], ); @@ -566,6 +585,7 @@ where row_node_info, index_node_info, is_row_node_leaf, + index_column_ids, is_item_included, ids, results, @@ -614,6 +634,7 @@ where .zip(wires.results.iter()) .for_each(|(&value, target)| pw.set_u256_target(target, value)); pw.set_target_arr(&wires.ids, &self.ids); + pw.set_target_arr(&wires.index_column_ids, &self.index_column_ids); pw.set_target(wires.limit, self.limit.to_field()); pw.set_target(wires.offset, self.offset.to_field()); pw.set_bool_target(wires.distinct, self.distinct); @@ -637,7 +658,7 @@ pub(crate) fn generate_dummy_row_proof_inputs< placeholders: &Placeholders, query_bounds: &QueryBounds, ) -> Result< - QueryCircuitInput< + UniversalCircuitInput< MAX_NUM_COLUMNS, MAX_NUM_PREDICATE_OPS, MAX_NUM_RESULT_OPS, @@ -648,55 +669,45 @@ where [(); MAX_NUM_COLUMNS + MAX_NUM_RESULT_OPS]:, [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, [(); MAX_NUM_ITEMS_PER_OUTPUT - 1]:, - [(); pi_len::()]:, [(); >::HASH_SIZE]:, { - // we generate a dummy proof for a dummy node of the index tree with an index value out of range - let query_hashes = QueryHashNonExistenceCircuits::new::< - MAX_NUM_COLUMNS, - MAX_NUM_PREDICATE_OPS, - MAX_NUM_RESULT_OPS, - MAX_NUM_ITEMS_PER_OUTPUT, - >( - column_ids, + // we generate dummy column cells; we can use all dummy values, except for the + // primary index value which must be in the query range + let primary_index_value = query_bounds.min_query_primary(); + let primary_index_column = ColumnCell { + value: primary_index_value, + id: column_ids.primary, + }; + let secondary_index_column = ColumnCell { + value: U256::default(), + 
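The final placeholder hash recomputed in-circuit above has a direct out-of-circuit counterpart that is handy when preparing witnesses: hash the proof's placeholder hash together with the primary query bounds. A sketch assuming the `H::hash_no_pad` / `ToFields` helpers this crate already uses (types and paths taken from the surrounding imports):

```rust
use alloy::primitives::U256;
use mp2_common::{poseidon::H, utils::ToFields, F};
use plonky2::{hash::hash_types::HashOut, plonk::config::Hasher};

// final_placeholder_hash = H(placeholder_hash || MIN_primary || MAX_primary)
fn final_placeholder_hash(ph: HashOut<F>, min_primary: U256, max_primary: U256) -> HashOut<F> {
    let inputs: Vec<F> = ph
        .elements
        .into_iter()
        .chain(min_primary.to_fields())
        .chain(max_primary.to_fields())
        .collect();
    H::hash_no_pad(&inputs)
}
```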
id: column_ids.secondary, + }; + let non_indexed_columns = column_ids + .non_indexed_columns() + .iter() + .map(|id| ColumnCell::new(*id, U256::default())) + .collect_vec(); + let cells = RowCells::new( + primary_index_column, + secondary_index_column, + non_indexed_columns, + ); + let universal_query_circuit = UniversalQueryCircuitInputs::new( + &cells, predicate_operations, - results, placeholders, - query_bounds, - false, - )?; - // we generate info about the proven index-tree node; we can use all dummy values, except for the - // node value which must be out of the query range - let node_value = query_bounds.max_query_primary() + U256::from(1); - let node_info = NodeInfo::new( - &HashOutput::default(), - None, // no children, for simplicity - None, - node_value, - U256::default(), - U256::default(), - ); - // The query has no aggregation operations, so by construction of the circuits we - // know that the first aggregate operation is ID, while the remaining ones are dummies - let aggregation_ops = once(AggregationOperation::IdOp) - .chain(repeat(AggregationOperation::default())) - .take(MAX_NUM_ITEMS_PER_OUTPUT) - .collect_vec(); - QueryCircuitInput::new_non_existence_input( - node_info, - None, - None, - node_value, - &[ - column_ids.primary.to_canonical_u64(), - column_ids.secondary.to_canonical_u64(), - ], - &aggregation_ops, - query_hashes, false, query_bounds, - placeholders, - ) + results, + true, // we generate proof for a dummy row + )?; + Ok(UniversalCircuitInput::QueryNoAgg(universal_query_circuit)) +} + +pub struct CircuitBuilderParams { + pub(crate) universal_query_vk: VerifierCircuitData, + pub(crate) preprocessing_circuit_set: RecursiveCircuits, + pub(crate) preprocessing_vk: VerifierOnlyCircuitData, } #[derive(Serialize, Deserialize, Clone, Debug)] @@ -714,10 +725,10 @@ pub struct RecursiveCircuitWires< { revelation_circuit: RevelationWires, #[serde( - serialize_with = "serialize_long_array", - deserialize_with = "deserialize_long_array" + serialize_with = "serialize_array", + deserialize_with = "deserialize_array" )] - row_verifiers: [RecursiveCircuitsVerifierTarget; L], + row_verifiers: [ProofWithPublicInputsTarget; L], #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] preprocessing_proof: ProofWithPublicInputsTarget, } @@ -740,7 +751,7 @@ pub struct RecursiveCircuitInputs< serialize_with = "serialize_long_array", deserialize_with = "deserialize_long_array" )] - pub(crate) row_proofs: [ProofWithVK; L], + pub(crate) row_proofs: [ProofWithPublicInputs; L], pub(crate) preprocessing_proof: ProofWithPublicInputs, pub(crate) query_circuit_set: RecursiveCircuits, } @@ -758,7 +769,6 @@ where [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); S * L]:, - [(); num_query_io::()]:, [(); >::HASH_SIZE]:, { type CircuitBuilderParams = CircuitBuilderParams; @@ -772,11 +782,8 @@ where _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - let row_verifier = RecursiveCircuitsVerifierGagdet::() }>::new( - default_config(), - &builder_parameters.query_circuit_set, - ); - let row_verifiers = [0; L].map(|_| row_verifier.verify_proof_in_circuit_set(builder)); + let row_verifiers = [0; L] + .map(|_| verify_proof_fixed_circuit(builder, &builder_parameters.universal_query_vk)); let preprocessing_verifier = RecursiveCircuitsVerifierGagdet::::new( default_config(), @@ -788,11 +795,7 @@ where ); let row_pis = row_verifiers .iter() - .map(|verifier| { - QueryProofPublicInputs::from_slice( - 
verifier.get_public_input_targets::() }>(), - ) - }) + .map(|verifier| QueryProofPublicInputs::from_slice(&verifier.public_inputs)) .collect_vec(); let preprocessing_pi = OriginalTreePublicInputs::from_slice(&preprocessing_proof.public_inputs); @@ -808,8 +811,7 @@ where fn assign_input(&self, inputs: Self::Inputs, pw: &mut PartialWitness) -> Result<()> { for (verifier_target, row_proof) in self.row_verifiers.iter().zip(inputs.row_proofs) { - let (proof, verifier_data) = (&row_proof).into(); - verifier_target.set_target(pw, &inputs.query_circuit_set, proof, verifier_data)?; + pw.set_proof_with_pis_target(verifier_target, &row_proof); } pw.set_proof_with_pis_target(&self.preprocessing_proof, &inputs.preprocessing_proof); inputs.inputs.assign(pw, &self.revelation_circuit); @@ -853,14 +855,19 @@ mod tests { PublicInputs as OriginalTreePublicInputs, }, query::{ - aggregation::{ChildPosition, NodeInfo}, - public_inputs::{PublicInputs as QueryProofPublicInputs, QueryPublicInputs}, + pi_len as query_pi_len, + public_inputs::{ + PublicInputsUniversalCircuit as QueryProofPublicInputs, + QueryPublicInputsUniversalCircuit, + }, + utils::{ChildPosition, NodeInfo}, }, revelation::{ - num_query_io, revelation_unproven_offset::RowPath, tests::TestPlaceholders, + revelation_unproven_offset::{RowPath, TabularQueryOutputModifiers}, + tests::TestPlaceholders, NUM_PREPROCESSING_IO, }, - test_utils::{random_aggregation_operations, random_aggregation_public_inputs}, + test_utils::random_aggregation_operations, }; use super::{RevelationCircuit, RevelationWires}; @@ -907,7 +914,7 @@ mod tests { fn build(c: &mut CircuitBuilder) -> Self::Wires { let row_pis_raw: [Vec; L] = (0..L) - .map(|_| c.add_virtual_targets(num_query_io::())) + .map(|_| c.add_virtual_targets(query_pi_len::())) .collect_vec() .try_into() .unwrap(); @@ -935,7 +942,7 @@ mod tests { // test function for this revelation circuit. 
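Switching from `RecursiveCircuitsVerifierGagdet` to `verify_proof_fixed_circuit` changes what is being proven: the row proofs are no longer checked for membership in a circuit set, they must instead come from the single universal-query circuit whose verifier key is baked in at build time. The build/assign pairing, pulled together from the surrounding code (a fragment for orientation, not a standalone example):

```rust
// Build time: one proof target per row, hard-wired to the universal query vk.
let row_verifiers = [0; L]
    .map(|_| verify_proof_fixed_circuit(builder, &builder_parameters.universal_query_vk));

// Proving time: assign concrete proofs; no circuit-set membership data needed.
for (verifier_target, row_proof) in self.row_verifiers.iter().zip(inputs.row_proofs) {
    pw.set_proof_with_pis_target(verifier_target, &row_proof);
}
```

The trade-off is the usual one for fixed-circuit recursion: a simpler verifier and witness, at the cost of accepting proofs from exactly one circuit.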
If `distinct` is true, then the // results are enforced to be distinct - async fn test_revelation_unproven_offset_circuit() { + async fn test_revelation_unproven_offset_circuit(distinct: bool) { const ROW_TREE_MAX_DEPTH: usize = 10; const INDEX_TREE_MAX_DEPTH: usize = 10; const L: usize = 5; @@ -943,51 +950,44 @@ mod tests { const PH: usize = 10; const PP: usize = 30; let ops = random_aggregation_operations::(); - let mut row_pis = random_aggregation_public_inputs(&ops); + let mut row_pis = QueryProofPublicInputs::sample_from_ops(&ops); let rng = &mut thread_rng(); let mut original_tree_pis = (0..NUM_PREPROCESSING_IO) .map(|_| rng.gen()) .collect::>() .to_fields(); + let index_ids = F::rand_array(); const NUM_PLACEHOLDERS: usize = 5; let test_placeholders = TestPlaceholders::sample(NUM_PLACEHOLDERS); - let (index_ids, computational_hash) = { + let computational_hash = { let row_pi_0 = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[0]); - let index_ids = row_pi_0.index_ids(); - let computational_hash = row_pi_0.computational_hash(); - - (index_ids, computational_hash) + row_pi_0.computational_hash() }; let placeholder_hash = test_placeholders.query_placeholder_hash; - // set same index_ids, computational hash and placeholder hash for all proofs; set also num matching rows to 1 - // for all proofs + let min_query_primary = test_placeholders.min_query; + let max_query_primary = test_placeholders.max_query; + // set same primary index query bounds, computational hash and placeholder hash for all proofs; + // set also num matching rows to 1 for all proofs row_pis.iter_mut().for_each(|pis| { - let [index_id_range, ch_range, ph_range, count_range] = [ - QueryPublicInputs::IndexIds, - QueryPublicInputs::ComputationalHash, - QueryPublicInputs::PlaceholderHash, - QueryPublicInputs::NumMatching, + let [min_primary_range, max_primary_range, ch_range, ph_range, count_range] = [ + QueryPublicInputsUniversalCircuit::MinPrimary, + QueryPublicInputsUniversalCircuit::MaxPrimary, + QueryPublicInputsUniversalCircuit::ComputationalHash, + QueryPublicInputsUniversalCircuit::PlaceholderHash, + QueryPublicInputsUniversalCircuit::NumMatching, ] .map(QueryProofPublicInputs::::to_range); - pis[index_id_range].copy_from_slice(&index_ids); + pis[min_primary_range].copy_from_slice(&min_query_primary.to_fields()); + pis[max_primary_range].copy_from_slice(&max_query_primary.to_fields()); pis[ch_range].copy_from_slice(&computational_hash.to_fields()); pis[ph_range].copy_from_slice(&placeholder_hash.to_fields()); pis[count_range].copy_from_slice(&[F::ONE]); }); - let index_value_range = - QueryProofPublicInputs::::to_range(QueryPublicInputs::IndexValue); - let hash_range = QueryProofPublicInputs::::to_range(QueryPublicInputs::TreeHash); - let min_query = test_placeholders.min_query; - let max_query = test_placeholders.max_query; - // closure that modifies a set of row public inputs to ensure that the index value lies - // within the query bounds; the new index value set in the public inputs is returned by the closure - let enforce_index_value_in_query_range = |pis: &mut [F], index_value: U256| { - let query_range_size = max_query - min_query + U256::from(1); - let new_index_value = min_query + index_value % query_range_size; - pis[index_value_range.clone()].copy_from_slice(&new_index_value.to_fields()); - assert!(new_index_value >= min_query && new_index_value <= max_query); - new_index_value - }; + let hash_range = + QueryProofPublicInputs::::to_range(QueryPublicInputsUniversalCircuit::TreeHash); + let 
index_value_range = QueryProofPublicInputs::::to_range( + QueryPublicInputsUniversalCircuit::PrimaryIndexValue, + ); // build a test tree containing the rows 0..5 found in row_pis // Index tree: // A @@ -1004,7 +1004,7 @@ mod tests { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[1]); let embedded_tree_hash = HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, None, @@ -1019,10 +1019,10 @@ mod tests { row_pis[1][hash_range.clone()].copy_from_slice(&node_1_hash.to_fields()); let node_0 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[0]); - let embedded_tree_hash = HashOutput::try_from(row_pi.tree_hash().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + let embedded_tree_hash = HashOutput::from(row_pi.tree_hash()); + let node_value = row_pi.secondary_index_value(); // left child is node 1 - let left_child_hash = HashOutput::try_from(node_1_hash.to_bytes()).unwrap(); + let left_child_hash = HashOutput::from(node_1_hash); NodeInfo::new( &embedded_tree_hash, Some(&left_child_hash), @@ -1034,9 +1034,8 @@ mod tests { }; let node_2 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[2]); - let embedded_tree_hash = - HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + let embedded_tree_hash = HashOutput::from(gen_random_field_hash::()); + let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, None, @@ -1051,9 +1050,8 @@ mod tests { row_pis[2][hash_range.clone()].copy_from_slice(&node_2_hash.to_fields()); let node_4 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[4]); - let embedded_tree_hash = - HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + let embedded_tree_hash = HashOutput::from(gen_random_field_hash::()); + let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, None, @@ -1068,8 +1066,7 @@ mod tests { row_pis[4][hash_range.clone()].copy_from_slice(&node_4_hash.to_fields()); let node_5 = { // can use all dummy values for this node, since there is no proof associated to it - let embedded_tree_hash = - HashOutput::try_from(gen_random_field_hash::().to_bytes()).unwrap(); + let embedded_tree_hash = HashOutput::from(gen_random_field_hash::()); let [node_value, node_min, node_max] = array::from_fn(|_| gen_random_u256(rng)); NodeInfo::new( &embedded_tree_hash, @@ -1080,13 +1077,12 @@ mod tests { node_max, ) }; - let node_4_hash = HashOutput::try_from(node_4_hash.to_bytes()).unwrap(); - let node_5_hash = - HashOutput::try_from(node_5.compute_node_hash(index_ids[1]).to_bytes()).unwrap(); + let node_4_hash = HashOutput::from(node_4_hash); + let node_5_hash = HashOutput::from(node_5.compute_node_hash(index_ids[1])); let node_3 = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[3]); - let embedded_tree_hash = HashOutput::try_from(row_pi.tree_hash().to_bytes()).unwrap(); - let node_value = row_pi.min_value(); + let embedded_tree_hash = HashOutput::from(row_pi.tree_hash()); + let node_value = row_pi.secondary_index_value(); NodeInfo::new( &embedded_tree_hash, Some(&node_4_hash), // left child is node 4 @@ -1098,10 +1094,8 @@ mod tests { }; let node_b = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[2]); - let embedded_tree_hash = - 
HashOutput::try_from(node_2.compute_node_hash(index_ids[1]).to_bytes()).unwrap(); - let index_value = row_pi.index_value(); - let node_value = enforce_index_value_in_query_range(&mut row_pis[2], index_value); + let embedded_tree_hash = HashOutput::from(node_2.compute_node_hash(index_ids[1])); + let node_value = row_pi.primary_index_value(); NodeInfo::new( &embedded_tree_hash, None, @@ -1112,13 +1106,12 @@ mod tests { ) }; let node_c = { - let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[4]); - let embedded_tree_hash = - HashOutput::try_from(node_3.compute_node_hash(index_ids[1]).to_bytes()).unwrap(); - let index_value = row_pi.index_value(); - let node_value = enforce_index_value_in_query_range(&mut row_pis[4], index_value); - // we need also to set index value PI in row_pis[3] to the same value of row_pis[4], as they are in the same index tree - row_pis[3][index_value_range.clone()].copy_from_slice(&node_value.to_fields()); + let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[3]); + let embedded_tree_hash = HashOutput::from(node_3.compute_node_hash(index_ids[1])); + let node_value = row_pi.primary_index_value(); + // we need to set index value in `row_pis[4]` to the same value of `row_pis[3]`, as + // they are in the same index tree + row_pis[4][index_value_range.clone()].copy_from_slice(&node_value.to_fields()); NodeInfo::new( &embedded_tree_hash, None, @@ -1128,17 +1121,14 @@ mod tests { node_value, ) }; - let node_b_hash = - HashOutput::try_from(node_b.compute_node_hash(index_ids[0]).to_bytes()).unwrap(); - let node_c_hash = - HashOutput::try_from(node_c.compute_node_hash(index_ids[0]).to_bytes()).unwrap(); + let node_b_hash = HashOutput::from(node_b.compute_node_hash(index_ids[0])); + let node_c_hash = HashOutput::from(node_c.compute_node_hash(index_ids[0])); let node_a = { let row_pi = QueryProofPublicInputs::<_, S>::from_slice(&row_pis[0]); - let embedded_tree_hash = - HashOutput::try_from(node_0.compute_node_hash(index_ids[1]).to_bytes()).unwrap(); - let index_value = row_pi.index_value(); - let node_value = enforce_index_value_in_query_range(&mut row_pis[0], index_value); - // we need also to set index value PI in row_pis[1] to the same value of row_pis[0], as they are in the same index tree + let embedded_tree_hash = HashOutput::from(node_0.compute_node_hash(index_ids[1])); + let node_value = row_pi.primary_index_value(); + // we need to set index value in `row_pis[1]` to the same value of `row_pis[0]`, as + // they are in the same index tree row_pis[1][index_value_range].copy_from_slice(&node_value.to_fields()); NodeInfo::new( &embedded_tree_hash, @@ -1155,8 +1145,17 @@ mod tests { // sample final results and set order-agnostic digests in row_pis proofs accordingly const NUM_ACTUAL_ITEMS_PER_OUTPUT: usize = 4; - let mut results: [[U256; NUM_ACTUAL_ITEMS_PER_OUTPUT]; L] = - array::from_fn(|_| array::from_fn(|_| gen_random_u256(rng))); + let mut results: [[U256; NUM_ACTUAL_ITEMS_PER_OUTPUT]; L] = if distinct { + // generate all the output values distinct from each other; generating at + // random will make them distinct with overwhelming probability + array::from_fn(|_| array::from_fn(|_| gen_random_u256(rng))) + } else { + // generate some values which are the same + let mut res = array::from_fn(|_| array::from_fn(|_| gen_random_u256(rng))); + res[L - 1] = res[0]; + res + }; + // sort them to ensure that DISTINCT constraints are satisfied results.sort_by(|a, b| { let (is_smaller, is_eq) = is_less_than_or_equal_to_u256_arr(a, b); @@ -1202,8 +1201,9 @@ 
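Sorting the sampled results before witnessing matters because DISTINCT is enforced in-circuit as a strict ordering between consecutive actual results (the comparator used for sorting, `is_less_than_or_equal_to_u256_arr`, suggests a lexicographic order over the result items). A plain stand-in for the property the assigned witnesses must satisfy, with `u128` in place of `U256` (a sketch, not the circuit's comparator):

```rust
// Strict lexicographic "less than" over fixed-width result rows.
fn lex_lt<const N: usize>(a: &[u128; N], b: &[u128; N]) -> bool {
    match (0..N).find(|&i| a[i] != b[i]) {
        Some(i) => a[i] < b[i],
        None => false, // equal rows are not strictly increasing
    }
}

// What the DISTINCT constraint demands of consecutive actual results.
fn satisfies_distinct<const N: usize>(rows: &[[u128; N]]) -> bool {
    rows.windows(2).all(|w| lex_lt(&w[0], &w[1]))
}
```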
mod tests { .await; row_pis.iter_mut().zip(digests).for_each(|(pis, digest)| { - let values_range = - QueryProofPublicInputs::::to_range(QueryPublicInputs::OutputValues); + let values_range = QueryProofPublicInputs::::to_range( + QueryPublicInputsUniversalCircuit::OutputValues, + ); pis[values_range.start..values_range.start + CURVE_TARGET_LEN] .copy_from_slice(&digest.to_fields()) }); @@ -1254,11 +1254,10 @@ mod tests { TestRevelationCircuit:: { circuit: RevelationCircuit::new( [row_path_0, row_path_1, row_path_2, row_path_3, row_path_4], + index_ids, &ids, results.map(|res| res.to_vec()), - 0, - 0, - false, + TabularQueryOutputModifiers::new(0, 0, false), test_placeholders.check_placeholder_inputs, ) .unwrap(), @@ -1271,11 +1270,11 @@ mod tests { #[tokio::test] async fn test_revelation_unproven_offset_circuit_no_distinct() { - test_revelation_unproven_offset_circuit().await + test_revelation_unproven_offset_circuit(false).await } #[tokio::test] async fn test_revelation_unproven_offset_circuit_distinct() { - test_revelation_unproven_offset_circuit().await + test_revelation_unproven_offset_circuit(true).await } } diff --git a/verifiable-db/src/revelation/revelation_without_results_tree.rs b/verifiable-db/src/revelation/revelation_without_results_tree.rs index 77af8f9d7..45ee1bb91 100644 --- a/verifiable-db/src/revelation/revelation_without_results_tree.rs +++ b/verifiable-db/src/revelation/revelation_without_results_tree.rs @@ -3,8 +3,8 @@ use crate::{ ivc::PublicInputs as OriginalTreePublicInputs, query::{ - computational_hash_ids::AggregationOperation, - public_inputs::PublicInputs as QueryProofPublicInputs, + computational_hash_ids::AggregationOperation, pi_len as query_pi_len, + public_inputs::PublicInputsQueryCircuits as QueryProofPublicInputs, }, revelation::PublicInputs, }; @@ -43,7 +43,7 @@ use recursion_framework::{ use serde::{Deserialize, Serialize}; use super::{ - num_query_io, pi_len as revelation_pi_len, + pi_len as revelation_pi_len, placeholders_check::{CheckPlaceholderGadget, CheckPlaceholderInputWires}, NUM_PREPROCESSING_IO, }; @@ -92,30 +92,31 @@ where .map(|op| b.constant(op.to_field())); // Convert the entry count to an Uint256. - let entry_count = query_proof.num_matching_rows_target(); - let entry_count = UInt256Target::new_from_target(b, entry_count); + let entry_count = UInt256Target::new_from_target(b, query_proof.num_matching_rows_target()); // Compute the output results array, and deal with AVG and COUNT operations if any. - let ops = query_proof.operation_ids_target(); - assert_eq!(ops.len(), S); let mut results = Vec::with_capacity(L * S); // flag to determine whether entry count is zero let is_entry_count_zero = b.add_virtual_bool_target_unsafe(); - ops.into_iter().enumerate().for_each(|(i, op)| { - let is_op_avg = b.is_equal(op, op_avg); - let is_op_count = b.is_equal(op, op_count); - let result = query_proof.value_target_at_index(i); + query_proof + .operation_ids_target() + .into_iter() + .enumerate() + .for_each(|(i, op)| { + let is_op_avg = b.is_equal(op, op_avg); + let is_op_count = b.is_equal(op, op_count); + let result = query_proof.value_target_at_index(i); - // Compute the AVG result (and it's set to zero if the divisor is zero). - let (avg_result, _, is_divisor_zero) = b.div_u256(&result, &entry_count); + // Compute the AVG result (and it's set to zero if the divisor is zero). 
+ let (avg_result, _, is_divisor_zero) = b.div_u256(&result, &entry_count); - let result = b.select_u256(is_op_avg, &avg_result, &result); - let result = b.select_u256(is_op_count, &entry_count, &result); + let result = b.select_u256(is_op_avg, &avg_result, &result); + let result = b.select_u256(is_op_count, &entry_count, &result); - b.connect(is_divisor_zero.target, is_entry_count_zero.target); + b.connect(is_divisor_zero.target, is_entry_count_zero.target); - results.push(result); - }); + results.push(result); + }); results.resize(L * S, u256_zero); // Pre-compute the final placeholder hash then check it in the @@ -125,8 +126,8 @@ where .placeholder_hash_target() .to_targets() .into_iter() - .chain(query_proof.min_query_target().to_targets()) - .chain(query_proof.max_query_target().to_targets()) + .chain(query_proof.min_primary_target().to_targets()) + .chain(query_proof.max_primary_target().to_targets()) .collect(); let final_placeholder_hash = b.hash_n_to_hash_no_pad::(inputs); @@ -145,7 +146,8 @@ where // hash to the computational hash: // H(pQ.C || placeholder_ids_hash || pQ.M) let inputs = query_proof - .to_computational_hash_raw() + .computational_hash_target() + .to_targets() .iter() .chain(&check_placeholder_wires.placeholder_id_hash.to_targets()) .chain(original_tree_proof.metadata_hash()) @@ -165,6 +167,97 @@ where let flat_computational_hash = flatten_poseidon_hash_target(b, computational_hash); + // additional constraints on boundary rows to ensure completeness of proven rows + // (i.e., that we look at all the rows with primary and secondary index values in the query range) + + let left_boundary_row = query_proof.left_boundary_row_target(); + + // 1. Either the index tree node of left boundary row has no predecessor, or + // the value of the predecessor is smaller than MIN_primary + let smaller_than_min_primary = b.is_less_than_u256( + &left_boundary_row.index_node_info.predecessor_info.value, + &query_proof.min_primary_target(), + ); + // assert not pQ.left_boundary_row.index_node_data.predecessor_info.is_found or + // pQ.left_boundary_row.index_node_data.predecessor_value < pQ.MIN_primary + let constraint = b.and( + left_boundary_row.index_node_info.predecessor_info.is_found, + smaller_than_min_primary, + ); + b.connect( + left_boundary_row + .index_node_info + .predecessor_info + .is_found + .target, + constraint.target, + ); + + // 2. Either the rows tree node storing left boundary row has no predecessor, or + // the value of the predecessor is smaller than MIN_secondary + let smaller_than_min_secondary = b.is_less_than_u256( + &left_boundary_row.row_node_info.predecessor_info.value, + &query_proof.min_secondary_target(), + ); + // assert not pQ.left_boundary_row.row_node_data.predecessor_info.is_found or + // pQ.left_boundary_row.row_node_data.predecessor_value < pQ.MIN_secondary + let constraint = b.and( + left_boundary_row.row_node_info.predecessor_info.is_found, + smaller_than_min_secondary, + ); + b.connect( + left_boundary_row + .row_node_info + .predecessor_info + .is_found + .target, + constraint.target, + ); + + let right_boundary_row = query_proof.right_boundary_row_target(); + + // 3. 
Either the index tree node of right boundary row has no successor, or + // the value of the successor is greater than MAX_primary + let greater_than_max_primary = b.is_greater_than_u256( + &right_boundary_row.index_node_info.successor_info.value, + &query_proof.max_primary_target(), + ); + // assert not pQ.right_boundary_row.index_node_data.successor_info.is_found or + // pQ.right_boundary_row.index_node_data.successor_value > pQ.MAX_primary + let constraint = b.and( + right_boundary_row.index_node_info.successor_info.is_found, + greater_than_max_primary, + ); + b.connect( + right_boundary_row + .index_node_info + .successor_info + .is_found + .target, + constraint.target, + ); + + // 4. Either the rows tree node storing right boundary row has no successor, or + // the value of the successor is greater than MAX_secondary + let greater_than_max_secondary = b.is_greater_than_u256( + &right_boundary_row.row_node_info.successor_info.value, + &query_proof.max_secondary_target(), + ); + // assert not pQ.right_boundary_row.row_node_data.successor_info.is_found or + // pQ.right_boundary_row.row_node_data.successor_value > pQ.MAX_secondary + let constraint = b.and( + right_boundary_row.row_node_info.successor_info.is_found, + greater_than_max_secondary, + ); + b.connect( + right_boundary_row + .row_node_info + .successor_info + .is_found + .target, + constraint.target, + ); + // Register the public inputs. PublicInputs::<_, L, S, PH>::new( &original_tree_proof.block_hash(), @@ -188,7 +281,7 @@ where } } - fn assign( + pub(crate) fn assign( &self, pw: &mut PartialWitness, wires: &RevelationWithoutResultsTreeWires, @@ -224,7 +317,7 @@ impl CircuitLo for RecursiveCircuitWires where [(); S - 1]:, - [(); num_query_io::()]:, + [(); query_pi_len::()]:, [(); >::HASH_SIZE]:, { type CircuitBuilderParams = CircuitBuilderParams; @@ -239,7 +332,7 @@ where builder_parameters: Self::CircuitBuilderParams, ) -> Self { let query_verifier = - RecursiveCircuitsVerifierGagdet::() }>::new( + RecursiveCircuitsVerifierGagdet::() }>::new( default_config(), &builder_parameters.query_circuit_set, ); @@ -253,13 +346,14 @@ where builder, &builder_parameters.preprocessing_vk, ); - let query_pi = QueryProofPublicInputs::from_slice( - query_verifier.get_public_input_targets::() }>(), - ); let preprocessing_pi = OriginalTreePublicInputs::from_slice(&preprocessing_proof.public_inputs); - let revelation_circuit = - RevelationWithoutResultsTreeCircuit::build(builder, &query_pi, &preprocessing_pi); + let revelation_circuit = { + let query_pi = QueryProofPublicInputs::from_slice( + query_verifier.get_public_input_targets::() }>(), + ); + RevelationWithoutResultsTreeCircuit::build(builder, &query_pi, &preprocessing_pi) + }; Self { revelation_circuit, @@ -280,20 +374,52 @@ where #[cfg(test)] mod tests { - use super::*; + use std::array; + + use alloy::primitives::U256; + use itertools::Itertools; + use mp2_common::{ + array::ToField, + poseidon::{flatten_poseidon_hash_value, H}, + types::CBuilder, + utils::{FromFields, ToFields}, + C, D, F, + }; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::{ + field::types::Field, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, + plonk::config::Hasher, + }; + use rand::{seq::SliceRandom, thread_rng, Rng}; + use crate::{ - query::public_inputs::QueryPublicInputs, - revelation::tests::{compute_results_from_query_proof, TestPlaceholders}, + ivc::PublicInputs as OriginalTreePublicInputs, + query::{ + computational_hash_ids::AggregationOperation,
public_inputs::{ + PublicInputsQueryCircuits as QueryProofPublicInputs, QueryPublicInputs, + }, + universal_circuit::{ + universal_circuit_inputs::Placeholders, universal_query_gadget::OutputValues, + }, + utils::{QueryBoundSource, QueryBounds}, + }, + revelation::{ + revelation_without_results_tree::{ + RevelationWithoutResultsTreeCircuit, RevelationWithoutResultsTreeWires, + }, + tests::{compute_results_from_query_proof_outputs, TestPlaceholders}, + PublicInputs, NUM_PREPROCESSING_IO, + }, test_utils::{ - random_aggregation_operations, random_aggregation_public_inputs, - random_original_tree_proof, + random_aggregation_operations, random_original_tree_proof, + sample_boundary_rows_for_revelation, }, }; - use alloy::primitives::U256; - use mp2_common::{poseidon::flatten_poseidon_hash_value, utils::ToFields, C, D}; - use mp2_test::circuit::{run_circuit, UserCircuit}; - use plonky2::{field::types::Field, plonk::config::Hasher}; - use rand::{prelude::SliceRandom, thread_rng, Rng}; // L: maximum number of results // S: maximum number of items in each result @@ -305,9 +431,9 @@ mod tests { const PP: usize = 20; // Real number of the placeholders - const NUM_PLACEHOLDERS: usize = 5; + const NUM_PLACEHOLDERS: usize = 6; - const QUERY_PI_LEN: usize = crate::query::pi_len::(); + const QUERY_PI_LEN: usize = QueryProofPublicInputs::::total_len(); impl From<&TestPlaceholders> for RevelationWithoutResultsTreeCircuit { fn from(test_placeholders: &TestPlaceholders) -> Self { @@ -318,13 +444,13 @@ mod tests { } #[derive(Clone, Debug)] - struct TestRevelationWithoutResultsTreeCircuit<'a> { + struct TestRevelationCircuit<'a> { c: RevelationWithoutResultsTreeCircuit, query_proof: &'a [F], original_tree_proof: &'a [F], } - impl UserCircuit for TestRevelationWithoutResultsTreeCircuit<'_> { + impl UserCircuit for TestRevelationCircuit<'_> { // Circuit wires + query proof + original tree proof (IVC proof) type Wires = ( RevelationWithoutResultsTreeWires, @@ -357,21 +483,26 @@ mod tests { ops: &[F; S], test_placeholders: &TestPlaceholders, ) -> Vec { - let [mut proof] = random_aggregation_public_inputs(ops); - - let [count_range, min_query_range, max_query_range, p_hash_range] = [ - QueryPublicInputs::NumMatching, - QueryPublicInputs::MinQuery, - QueryPublicInputs::MaxQuery, - QueryPublicInputs::PlaceholderHash, - ] - .map(QueryProofPublicInputs::::to_range); + let [mut proof] = QueryProofPublicInputs::sample_from_ops(ops); + + let [count_range, min_query_primary, max_query_primary, min_query_secondary, max_query_secondary, p_hash_range, left_row_range, right_row_range] = + [ + QueryPublicInputs::NumMatching, + QueryPublicInputs::MinPrimary, + QueryPublicInputs::MaxPrimary, + QueryPublicInputs::MinSecondary, + QueryPublicInputs::MaxSecondary, + QueryPublicInputs::PlaceholderHash, + QueryPublicInputs::LeftBoundaryRow, + QueryPublicInputs::RightBoundaryRow, + ] + .map(QueryProofPublicInputs::::to_range); // Set the count, minimum, maximum query and the placeholder hash. 
         [
             (count_range, vec![entry_count.to_field()]),
-            (min_query_range, test_placeholders.min_query.to_fields()),
-            (max_query_range, test_placeholders.max_query.to_fields()),
+            (min_query_primary, test_placeholders.min_query.to_fields()),
+            (max_query_primary, test_placeholders.max_query.to_fields()),
             (
                 p_hash_range,
                 test_placeholders.query_placeholder_hash.to_fields(),
@@ -380,11 +511,29 @@ mod tests {
         .into_iter()
         .for_each(|(range, fields)| proof[range].copy_from_slice(&fields));

+        // Set boundary rows to satisfy constraints for completeness
+        let rng = &mut thread_rng();
+        let min_secondary = U256::from_fields(&proof[min_query_secondary]);
+        let max_secondary = U256::from_fields(&proof[max_query_secondary]);
+        let placeholders =
+            Placeholders::new_empty(test_placeholders.min_query, test_placeholders.max_query);
+        let query_bounds = QueryBounds::new(
+            &placeholders,
+            Some(QueryBoundSource::Constant(min_secondary)),
+            Some(QueryBoundSource::Constant(max_secondary)),
+        )
+        .unwrap();
+        let (left_boundary_row, right_boundary_row) =
+            sample_boundary_rows_for_revelation(&query_bounds, rng);
+
+        proof[left_row_range].copy_from_slice(&left_boundary_row.to_fields());
+        proof[right_row_range].copy_from_slice(&right_boundary_row.to_fields());
+
         proof
     }
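Aside: the helper above treats the proof's public inputs as one flat `Vec<F>` and overwrites named sub-ranges in place; `copy_from_slice` panics unless each patch matches its range length, which is what keeps the sampled proof well-formed. The same pattern in miniature, with plain Rust types and invented values for illustration:

    use std::ops::Range;

    /// Overwrite each `range` of `flat` with the corresponding `fields`.
    fn patch_ranges<T: Copy>(flat: &mut [T], patches: Vec<(Range<usize>, Vec<T>)>) {
        patches
            .into_iter()
            .for_each(|(range, fields)| flat[range].copy_from_slice(&fields));
    }

    fn main() {
        let mut flat = vec![0u64; 8];
        // patch [2, 4) and [5, 8) with fixed values
        patch_ranges(&mut flat, vec![(2..4, vec![7, 7]), (5..8, vec![1, 2, 3])]);
        assert_eq!(flat, [0, 0, 7, 7, 0, 1, 2, 3]);
    }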

     /// Utility function for testing the revelation circuit without results tree
-    fn test_revelation_without_results_tree_circuit(ops: &[F; S], entry_count: Option<u32>) {
+    fn test_revelation_batching_circuit(ops: &[F; S], entry_count: Option<u32>) {
         let rng = &mut thread_rng();

         // Generate the testing placeholder data.
@@ -396,11 +545,11 @@ mod tests {
         let query_pi = QueryProofPublicInputs::<_, S>::from_slice(&query_proof);

         // Generate the original tree proof (IVC proof).
-        let original_tree_proof = random_original_tree_proof(&query_pi);
+        let original_tree_proof = random_original_tree_proof(query_pi.tree_hash());
         let original_tree_pi = OriginalTreePublicInputs::from_slice(&original_tree_proof);

         // Construct the test circuit.
-        let test_circuit = TestRevelationWithoutResultsTreeCircuit {
+        let test_circuit = TestRevelationCircuit {
             c: (&test_placeholders).into(),
             query_proof: &query_proof,
             original_tree_proof: &original_tree_proof,
@@ -410,7 +559,6 @@ mod tests {
         let proof = run_circuit::<F, C, D, _>(test_circuit);
         let pi = PublicInputs::<_, L, S, PH>::from_slice(&proof.public_inputs);

-        // Initialize the overflow flag to false.
         let entry_count = query_pi.num_matching_rows();

         // Check the public inputs.
@@ -454,12 +602,16 @@ mod tests {
         // Entry count
         assert_eq!(pi.entry_count(), entry_count);
         // check results
-        let (result, overflow) = compute_results_from_query_proof(&query_pi);
+        let result = compute_results_from_query_proof_outputs(
+            query_pi.num_matching_rows(),
+            OutputValues::<S>::from_fields(query_pi.to_values_raw()),
+            &query_pi.operation_ids(),
+        );
         let mut exp_results = [[U256::ZERO; S]; L];
         exp_results[0] = result;
         assert_eq!(pi.result_values(), exp_results);
         // overflow flag
-        assert_eq!(pi.overflow_flag(), overflow);
+        assert_eq!(pi.overflow_flag(), query_pi.overflow_flag());
         // Query limit
         assert_eq!(pi.query_limit(), F::ZERO);
         // Query offset
@@ -467,57 +619,57 @@ mod tests {
     }

     #[test]
-    fn test_revelation_without_results_tree_simple() {
+    fn test_revelation_batching_simple() {
         // Generate the random operations and set the first operation to SUM
         // (not ID which should not be present in the aggregation).
         let mut ops: [_; S] = random_aggregation_operations();
         ops[0] = AggregationOperation::SumOp.to_field();

-        test_revelation_without_results_tree_circuit(&ops, None);
+        test_revelation_batching_circuit(&ops, None);
     }

     // Test for COUNT operation.
     #[test]
-    fn test_revelation_without_results_tree_for_op_count() {
+    fn test_revelation_batching_for_op_count() {
         // Set the first operation to COUNT.
         let mut ops: [_; S] = random_aggregation_operations();
         ops[0] = AggregationOperation::CountOp.to_field();

-        test_revelation_without_results_tree_circuit(&ops, None);
+        test_revelation_batching_circuit(&ops, None);
     }

     // Test for AVG operation.
     #[test]
-    fn test_revelation_without_results_tree_for_op_avg() {
+    fn test_revelation_batching_for_op_avg() {
         // Set the first operation to AVG.
         let mut ops: [_; S] = random_aggregation_operations();
         ops[0] = AggregationOperation::AvgOp.to_field();

-        test_revelation_without_results_tree_circuit(&ops, None);
+        test_revelation_batching_circuit(&ops, None);
     }

     // Test for AVG operation with zero entry count.
     #[test]
-    fn test_revelation_without_results_tree_for_op_avg_with_no_entries() {
+    fn test_revelation_batching_for_op_avg_with_no_entries() {
         // Set the first operation to AVG.
         let mut ops: [_; S] = random_aggregation_operations();
         ops[0] = AggregationOperation::AvgOp.to_field();

-        test_revelation_without_results_tree_circuit(&ops, Some(0));
+        test_revelation_batching_circuit(&ops, Some(0));
     }

     // Test for no AVG operation with zero entry count.
     #[test]
-    fn test_revelation_without_results_tree_for_no_op_avg_with_no_entries() {
+    fn test_revelation_batching_for_no_op_avg_with_no_entries() {
         // Initialize all the operations to SUM or COUNT (not AVG).
         let mut rng = thread_rng();
-        let ops = std::array::from_fn(|_| {
+        let ops = array::from_fn(|_| {
             [AggregationOperation::SumOp, AggregationOperation::CountOp]
                 .choose(&mut rng)
                 .unwrap()
                 .to_field()
         });

-        test_revelation_without_results_tree_circuit(&ops, Some(0));
+        test_revelation_batching_circuit(&ops, Some(0));
     }
 }
diff --git a/verifiable-db/src/test_utils.rs b/verifiable-db/src/test_utils.rs
index 02e2b5fbb..864c1451f 100644
--- a/verifiable-db/src/test_utils.rs
+++ b/verifiable-db/src/test_utils.rs
@@ -3,18 +3,20 @@ use crate::{
     ivc::public_inputs::H_RANGE as ORIGINAL_TREE_H_RANGE,
     query::{
-        aggregation::{QueryBounds, QueryHashNonExistenceCircuits},
         computational_hash_ids::{
             AggregationOperation, ColumnIDs, Identifiers, Operation, PlaceholderIdentifier,
         },
-        pi_len,
         public_inputs::{
-            PublicInputs as QueryPI, PublicInputs as QueryProofPublicInputs, PublicInputs,
-            QueryPublicInputs,
+            PublicInputsFactory, PublicInputsQueryCircuits as QueryPI, QueryPublicInputs,
         },
-        universal_circuit::universal_circuit_inputs::{
-            BasicOperation, ColumnCell, InputOperand, OutputItem, Placeholders, ResultStructure,
+        row_chunk_gadgets::BoundaryRowData,
+        universal_circuit::{
+            universal_circuit_inputs::{
+                BasicOperation, ColumnCell, InputOperand, OutputItem, Placeholders, ResultStructure,
+            },
+            universal_query_gadget::OutputValues,
         },
+        utils::{QueryBoundSource, QueryBounds, QueryHashNonExistenceCircuits},
     },
     revelation::NUM_PREPROCESSING_IO,
 };
@@ -22,13 +24,13 @@ use alloy::primitives::U256;
 use itertools::Itertools;
 use mp2_common::{
     array::ToField,
-    types::CURVE_TARGET_LEN,
     utils::{Fieldable, ToFields},
     F,
 };
+use mp2_test::utils::{gen_random_field_hash, gen_random_u256};
 use plonky2::{
     field::types::{Field, PrimeField64, Sample},
-    hash::hash_types::NUM_HASH_OUT_ELTS,
+    hash::hash_types::HashOut,
     plonk::config::GenericHashOut,
 };
 use plonky2_ecgfp5::curve::curve::Point;
@@ -51,10 +53,28 @@
 pub const ROW_TREE_MAX_DEPTH: usize = 10;
 pub const INDEX_TREE_MAX_DEPTH: usize = 15;
 pub const NUM_COLUMNS: usize = 4;

+/// Generate a set of values in a given range, ensuring that the (i+1)-th generated value is
+/// greater than or equal to the i-th generated value
+pub(crate) fn gen_values_in_range<R: Rng, const N: usize>(
+    rng: &mut R,
+    lower: U256,
+    upper: U256,
+) -> [U256; N] {
+    assert!(upper >= lower, "{upper} is smaller than {lower}");
+    let mut prev_value = lower;
+    array::from_fn(|_| {
+        let range = (upper - prev_value).checked_add(U256::from(1));
+        let gen_value = match range {
+            Some(range) => prev_value + gen_random_u256(rng) % range,
+            None => gen_random_u256(rng),
+        };
+        prev_value = gen_value;
+        gen_value
+    })
+}
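Aside: `gen_values_in_range` is typically called with the array size inferred from a destructuring pattern, e.g. `let [min, max] = gen_values_in_range(rng, U256::ZERO, U256::MAX);`, which guarantees `min <= max`. The `checked_add` covers the single case where the inclusive width `upper - prev + 1` overflows (only possible when `prev == 0` and `upper == U256::MAX`), falling back to an unrestricted sample. The same logic over `u64`, as a self-contained illustration with invented names:

    use rand::Rng;

    /// Sample N values in `[lower, upper]` such that each value is >= the previous one.
    fn gen_nondecreasing<R: Rng, const N: usize>(rng: &mut R, lower: u64, upper: u64) -> [u64; N] {
        assert!(upper >= lower);
        let mut prev = lower;
        std::array::from_fn(|_| {
            // inclusive width of [prev, upper]; None only when it spans the whole u64 domain
            let v = match (upper - prev).checked_add(1) {
                Some(width) => prev + rng.gen::<u64>() % width,
                None => rng.gen(),
            };
            prev = v;
            v
        })
    }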
+
 /// Generate a random original tree proof for testing.
-pub fn random_original_tree_proof<const S: usize>(
-    query_pi: &QueryProofPublicInputs<F, S>,
-) -> Vec<F> {
+pub fn random_original_tree_proof(tree_hash: HashOut<F>) -> Vec<F> {
     let mut rng = thread_rng();
     let mut proof = (0..NUM_PREPROCESSING_IO)
         .map(|_| rng.gen())
@@ -62,7 +82,7 @@ pub fn random_original_tree_proof(
         .to_fields();

     // Set the tree hash.
-    proof[ORIGINAL_TREE_H_RANGE].copy_from_slice(query_pi.to_hash_raw());
+    proof[ORIGINAL_TREE_H_RANGE].copy_from_slice(&tree_hash.to_fields());

     proof
 }
@@ -84,63 +104,90 @@ pub fn random_aggregation_operations<const S: usize>() -> [F; S] {
     })
 }

-/// Generate S number of proof public input slices by the specified operations for testing.
-/// The each returned proof public inputs could be constructed by
-/// `PublicInputs::from_slice` function.
-pub fn random_aggregation_public_inputs<const N: usize, const S: usize>(
-    ops: &[F; S],
-) -> [Vec<F>; N] {
-    let [ops_range, overflow_range, index_ids_range, c_hash_range, p_hash_range] = [
-        QueryPublicInputs::OpIds,
-        QueryPublicInputs::Overflow,
-        QueryPublicInputs::IndexIds,
-        QueryPublicInputs::ComputationalHash,
-        QueryPublicInputs::PlaceholderHash,
-    ]
-    .map(PublicInputs::<F, S>::to_range);
-
-    let first_value_start = PublicInputs::<F, S>::to_range(QueryPublicInputs::OutputValues).start;
-    let is_first_op_id =
-        ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field();
-
-    // Generate the index ids, computational hash and placeholder hash,
-    // they should be same for a series of public inputs.
-    let mut rng = thread_rng();
-    let index_ids = (0..2).map(|_| rng.gen()).collect::<Vec<_>>().to_fields();
-    let [computational_hash, placeholder_hash]: [Vec<_>; 2] = array::from_fn(|_| {
-        (0..NUM_HASH_OUT_ELTS)
-            .map(|_| rng.gen())
-            .collect::<Vec<_>>()
-            .to_fields()
-    });
-
-    array::from_fn(|_| {
-        let mut pi = (0..pi_len::<S>())
-            .map(|_| rng.gen())
-            .collect::<Vec<_>>()
-            .to_fields();
+impl<const S: usize, const UNIVERSAL_CIRCUIT: bool>
+    PublicInputsFactory<'_, F, S, UNIVERSAL_CIRCUIT>
+{
+    pub(crate) fn sample_from_ops<const NUM_INPUTS: usize>(ops: &[F; S]) -> [Vec<F>; NUM_INPUTS]
+    where
+        [(); S - 1]:,
+    {
+        let rng = &mut thread_rng();

-        // Copy the specified operations to the proofs.
-        pi[ops_range.clone()].copy_from_slice(ops);
+        let tree_hash = gen_random_field_hash();
+        let computational_hash = gen_random_field_hash();
+        let placeholder_hash = gen_random_field_hash();
+        let [min_primary, max_primary] = gen_values_in_range(rng, U256::ZERO, U256::MAX);
+        let [min_secondary, max_secondary] = gen_values_in_range(rng, U256::ZERO, U256::MAX);

-        // Set the overflow flag to a random boolean.
-        let overflow = F::from_bool(rng.gen());
-        pi[overflow_range.clone()].copy_from_slice(&[overflow]);
+        let query_bounds = {
+            let placeholders = Placeholders::new_empty(min_primary, max_primary);
+            QueryBounds::new(
+                &placeholders,
+                Some(QueryBoundSource::Constant(min_secondary)),
+                Some(QueryBoundSource::Constant(max_secondary)),
+            )
+            .unwrap()
+        };

-        // Set the index ids, computational hash and placeholder hash,
-        pi[index_ids_range.clone()].copy_from_slice(&index_ids);
-        pi[c_hash_range.clone()].copy_from_slice(&computational_hash);
-        pi[p_hash_range.clone()].copy_from_slice(&placeholder_hash);
+        let is_first_op_id =
+            ops[0] == Identifiers::AggregationOperations(AggregationOperation::IdOp).to_field();

-        // If the first operation is ID, set the value to a random point.
-        if is_first_op_id {
-            let first_value = Point::sample(&mut rng).to_weierstrass().to_fields();
-            pi[first_value_start..first_value_start + CURVE_TARGET_LEN]
-                .copy_from_slice(&first_value);
-        }
+        let mut previous_row: Option<BoundaryRowData> = None;
+        array::from_fn(|_| {
+            // generate output values
+            let output_values = if is_first_op_id {
+                // generate random curve point
+                OutputValues::<S>::new_outputs_no_aggregation(&Point::sample(rng))
+            } else {
+                let values = (0..S).map(|_| gen_random_u256(rng)).collect_vec();
+                OutputValues::<S>::new_aggregation_outputs(&values)
+            };
+            // generate random count and overflow flag
+            let count = F::from_canonical_u32(rng.gen());
+            let overflow = F::from_bool(rng.gen());
+            // generate boundary rows
+            let left_boundary_row = if let Some(row) = &previous_row {
+                row.sample_consecutive_row(rng, &query_bounds)
+            } else {
+                BoundaryRowData::sample(rng, &query_bounds)
+            };
+            let right_boundary_row = BoundaryRowData::sample(rng, &query_bounds);
+            assert!(
+                left_boundary_row.index_node_info.predecessor_info.value >= min_primary
+                    && left_boundary_row.index_node_info.predecessor_info.value <= max_primary
+            );
+            assert!(
+                left_boundary_row.index_node_info.successor_info.value >= min_primary
+                    && left_boundary_row.index_node_info.successor_info.value <= max_primary
+            );
+            assert!(
+                right_boundary_row.index_node_info.predecessor_info.value >= min_primary
+                    && right_boundary_row.index_node_info.predecessor_info.value <= max_primary
+            );
+            assert!(
+                right_boundary_row.index_node_info.successor_info.value >= min_primary
+                    && right_boundary_row.index_node_info.successor_info.value <= max_primary
+            );
+            previous_row = Some(right_boundary_row.clone());

-        pi
-    })
+            PublicInputsFactory::<F, S, UNIVERSAL_CIRCUIT>::new(
+                &tree_hash.to_fields(),
+                &output_values.to_fields(),
+                &[count],
+                ops,
+                &left_boundary_row.to_fields(),
+                &right_boundary_row.to_fields(),
+                &min_primary.to_fields(),
+                &max_primary.to_fields(),
+                &min_secondary.to_fields(),
+                &max_secondary.to_fields(),
+                &[overflow],
+                &computational_hash.to_fields(),
+                &placeholder_hash.to_fields(),
+            )
+            .to_vec()
+        })
+    }
 }
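Aside: `sample_from_ops` threads `previous_row` through the `array::from_fn` closure so that each sampled proof's left boundary row is consecutive with the previous proof's right boundary row (via `sample_consecutive_row`), presumably mirroring how consecutive row chunks are stitched together during aggregation. The stateful-closure idiom in isolation, with plain integers standing in for boundary rows:

    /// Build an array where each element is derived from its predecessor.
    fn chained_samples<const N: usize>(mut step: impl FnMut(Option<u64>) -> u64) -> [u64; N] {
        let mut previous: Option<u64> = None;
        std::array::from_fn(|_| {
            let next = step(previous);
            previous = Some(next);
            next
        })
    }

    // e.g. chained_samples::<4>(|prev| prev.map_or(0, |p| p + 1)) == [0, 1, 2, 3]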

 /// Revelation related data used for testing
@@ -249,36 +296,64 @@ impl TestRevelationData {
         let computational_hash = non_existence_circuits.computational_hash();
         let placeholder_hash = non_existence_circuits.placeholder_hash();

-        let [mut query_pi_raw] = random_aggregation_public_inputs::<1, MAX_NUM_ITEMS_PER_OUTPUT>(
-            &ops_ids.try_into().unwrap(),
-        );
-        let [min_query_range, max_query_range, p_hash_range, c_hash_range] = [
-            QueryPublicInputs::MinQuery,
-            QueryPublicInputs::MaxQuery,
-            QueryPublicInputs::PlaceholderHash,
-            QueryPublicInputs::ComputationalHash,
-        ]
-        .map(QueryPI::<F, MAX_NUM_ITEMS_PER_OUTPUT>::to_range);
+        let [mut query_pi_raw] =
+            QueryPI::<F, MAX_NUM_ITEMS_PER_OUTPUT>::sample_from_ops(&ops_ids.try_into().unwrap());
+        let [min_query_primary, max_query_primary, min_query_secondary, max_query_secondary, p_hash_range, c_hash_range, left_row_range, right_row_range] =
+            [
+                QueryPublicInputs::MinPrimary,
+                QueryPublicInputs::MaxPrimary,
+                QueryPublicInputs::MinSecondary,
+                QueryPublicInputs::MaxSecondary,
+                QueryPublicInputs::PlaceholderHash,
+                QueryPublicInputs::ComputationalHash,
+                QueryPublicInputs::LeftBoundaryRow,
+                QueryPublicInputs::RightBoundaryRow,
+            ]
+            .map(QueryPI::<F, MAX_NUM_ITEMS_PER_OUTPUT>::to_range);
+
+        // sample left boundary row and right boundary row to satisfy revelation circuit constraints
+        let (left_boundary_row, right_boundary_row) =
+            sample_boundary_rows_for_revelation(&query_bounds, rng);

         // Set the minimum, maximum query, placeholder hash and computational hash to expected values.
         [
             (
-                min_query_range,
+                min_query_primary,
                 query_bounds.min_query_primary().to_fields(),
             ),
             (
-                max_query_range,
+                max_query_primary,
                 query_bounds.max_query_primary().to_fields(),
             ),
+            (
+                min_query_secondary,
+                query_bounds.min_query_secondary().value().to_fields(),
+            ),
+            (
+                max_query_secondary,
+                query_bounds.max_query_secondary().value().to_fields(),
+            ),
             (p_hash_range, placeholder_hash.to_vec()),
             (c_hash_range, computational_hash.to_vec()),
+            (left_row_range, left_boundary_row.to_fields()),
+            (right_row_range, right_boundary_row.to_fields()),
         ]
         .into_iter()
         .for_each(|(range, fields)| query_pi_raw[range].copy_from_slice(&fields));

         let query_pi = QueryPI::<F, MAX_NUM_ITEMS_PER_OUTPUT>::from_slice(&query_pi_raw);
+        assert_eq!(query_pi.min_primary(), query_bounds.min_query_primary());
+        assert_eq!(query_pi.max_primary(), query_bounds.max_query_primary());
+        assert_eq!(
+            query_pi.min_secondary(),
+            *query_bounds.min_query_secondary().value(),
+        );
+        assert_eq!(
+            query_pi.max_secondary(),
+            *query_bounds.max_query_secondary().value(),
+        );

         // generate preprocessing proof public inputs
-        let preprocessing_pi_raw = random_original_tree_proof(&query_pi);
+        let preprocessing_pi_raw = random_original_tree_proof(query_pi.tree_hash());

         Self {
             query_bounds,
@@ -318,3 +393,56 @@ impl TestRevelationData {
         &self.query_pi_raw
     }
 }
+
+pub(crate) fn sample_boundary_rows_for_revelation<R: Rng>(
+    query_bounds: &QueryBounds,
+    rng: &mut R,
+) -> (BoundaryRowData, BoundaryRowData) {
+    let min_secondary = *query_bounds.min_query_secondary().value();
+    let max_secondary = *query_bounds.max_query_secondary().value();
+    let mut left_boundary_row = BoundaryRowData::sample(rng, query_bounds);
+    // for the predecessor of `left_boundary_row` in the index tree, we need to either mark it as
+    // non-existent or make its value out of range
+    if rng.gen() || query_bounds.min_query_primary() == U256::ZERO {
+        left_boundary_row.index_node_info.predecessor_info.is_found = false;
+    } else {
+        let [predecessor_value] = gen_values_in_range(
+            rng,
+            U256::ZERO,
+            query_bounds.min_query_primary() - U256::from(1),
+        );
+        left_boundary_row.index_node_info.predecessor_info.value = predecessor_value;
+    }
+    // for the predecessor of `left_boundary_row` in the rows tree, we need to either mark it as
+    // non-existent or make its value out of range
+    if rng.gen() || min_secondary == U256::ZERO {
+        left_boundary_row.row_node_info.predecessor_info.is_found = false;
+    } else {
+        let [predecessor_value] =
+            gen_values_in_range(rng, U256::ZERO, min_secondary - U256::from(1));
+        left_boundary_row.row_node_info.predecessor_info.value = predecessor_value;
+    }
+    let mut right_boundary_row = BoundaryRowData::sample(rng, query_bounds);
+    // for the successor of `right_boundary_row` in the index tree, we need to either mark it as
+    // non-existent or make its value out of range
+    if rng.gen() || query_bounds.max_query_primary() == U256::MAX {
+        right_boundary_row.index_node_info.successor_info.is_found = false;
+    } else {
+        let [successor_value] = gen_values_in_range(
+            rng,
+            query_bounds.max_query_primary() + U256::from(1),
+            U256::MAX,
+        );
+        right_boundary_row.index_node_info.successor_info.value = successor_value;
+    }
+    // for the successor of `right_boundary_row` in the rows tree, we need to either mark it as
+    // non-existent or make its value out of range
+    if rng.gen() || max_secondary == U256::MAX {
+        right_boundary_row.row_node_info.successor_info.is_found = false;
+    } else {
+        let [successor_value] = gen_values_in_range(rng, max_secondary + U256::from(1), U256::MAX);
+        right_boundary_row.row_node_info.successor_info.value = successor_value;
+    }
+
+    (left_boundary_row, right_boundary_row)
+}
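Aside: each branch above has a domain-edge guard; when a bound already sits at `U256::ZERO` or `U256::MAX`, no strictly out-of-range value exists, so the helper must fall back to marking the neighbour as non-existent. A distilled sketch of the "below the range" case, as a hypothetical helper reusing `gen_values_in_range` from this file:

    use alloy::primitives::U256;
    use rand::Rng;

    /// Returns `None` to model a non-existent predecessor, or a value strictly
    /// below `min_bound`; `None` is forced when no such value exists.
    pub(crate) fn out_of_range_below<R: Rng>(rng: &mut R, min_bound: U256) -> Option<U256> {
        if rng.gen() || min_bound == U256::ZERO {
            None
        } else {
            let [v] = gen_values_in_range(rng, U256::ZERO, min_bound - U256::from(1));
            Some(v)
        }
    }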