diff --git a/cli/Cargo.lock b/cli/Cargo.lock index ca78dec5..78b28aef 100644 --- a/cli/Cargo.lock +++ b/cli/Cargo.lock @@ -199,6 +199,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atoi_simd" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" + [[package]] name = "autocfg" version = "1.1.0" @@ -842,9 +848,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" dependencies = [ "either", ] @@ -877,13 +883,10 @@ dependencies = [ name = "lace" version = "0.6.0" dependencies = [ - "bincode", "ctrlc", "dirs", - "flate2", "indexmap", "indicatif", - "itertools", "lace_cc", "lace_codebook", "lace_consts", @@ -892,7 +895,6 @@ dependencies = [ "lace_metadata", "lace_stats", "lace_utils", - "log", "maplit", "num", "polars", @@ -900,11 +902,9 @@ dependencies = [ "rand_distr", "rand_xoshiro", "rayon", - "regex", "serde", "serde_json", "serde_yaml", - "special", "thiserror", ] @@ -927,7 +927,6 @@ name = "lace_cc" version = "0.5.0" dependencies = [ "enum_dispatch", - "indicatif", "itertools", "lace_codebook", "lace_consts", @@ -947,15 +946,11 @@ dependencies = [ name = "lace_codebook" version = "0.5.0" dependencies = [ - "flate2", "lace_consts", "lace_data", "lace_stats", "lace_utils", - "maplit", "polars", - "rand", - "rayon", "serde", "serde_yaml", "thiserror", @@ -973,7 +968,6 @@ name = "lace_data" version = "0.2.0" dependencies = [ "lace_utils", - "regex", "serde", "thiserror", ] @@ -995,21 +989,18 @@ name = "lace_metadata" version = "0.5.0" dependencies = [ "bincode", - "dirs", "hex", "lace_cc", "lace_codebook", "lace_data", "lace_stats", "log", - "once_cell", "rand_xoshiro", "rayon", "serde", 
"serde_json", "serde_yaml", "thiserror", - "toml", ] [[package]] @@ -1021,8 +1012,6 @@ dependencies = [ "lace_data", "lace_utils", "rand", - "rand_xoshiro", - "regex", "serde", "special", "thiserror", @@ -1041,15 +1030,6 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" -[[package]] -name = "lexical" -version = "6.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7aefb36fd43fef7003334742cbf77b243fcd36418a1d1bdd480d613a67968f6" -dependencies = [ - "lexical-core", -] - [[package]] name = "lexical-core" version = "0.8.5" @@ -1442,25 +1422,6 @@ dependencies = [ "futures", ] -[[package]] -name = "parquet2" -version = "0.17.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "579fe5745f02cef3d5f236bfed216fd4693e49e4e920a13475c6132233283bce" -dependencies = [ - "async-stream", - "brotli", - "flate2", - "futures", - "lz4", - "parquet-format-safe", - "seq-macro", - "snap", - "streaming-decompression", - "xxhash-rust", - "zstd 0.12.4", -] - [[package]] name = "paste" version = "1.0.14" @@ -1526,9 +1487,9 @@ dependencies = [ [[package]] name = "polars" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40db657cc67a8dd9fe4b40db5b73027f5f224623545597e1930cbbb9c05b1de5" +checksum = "938048fcda6a8e2ace6eb168bee1b415a92423ce51e418b853bf08fc40349b6b" dependencies = [ "getrandom", "polars-core", @@ -1541,41 +1502,55 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1e50c63db77f846ac5119477422f0156f0a1826ceaae7d921f9a6d5ea5f7ca3" +checksum = "ce68a02f698ff7787c261aea1b4c040a8fe183a8fb200e2436d7f35d95a1b86f" dependencies = [ "ahash", "arrow-format", - "base64", + "atoi_simd", "bytemuck", "chrono", "dyn-clone", "either", 
"ethnum", - "fallible-streaming-iterator", + "fast-float", "foreign_vec", "futures", "getrandom", "hashbrown 0.14.3", - "lexical-core", + "itoa", "lz4", "multiversion", "num-traits", - "parquet2", "polars-error", - "rustc_version", + "polars-utils", + "ryu", "simdutf8", "streaming-iterator", "strength_reduce", - "zstd 0.13.0", + "version_check", + "zstd", +] + +[[package]] +name = "polars-compute" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b14fbc5f141b29b656a4cec4802632e5bff10bf801c6809c6bbfbd4078a044dd" +dependencies = [ + "bytemuck", + "num-traits", + "polars-arrow", + "polars-utils", + "version_check", ] [[package]] name = "polars-core" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdfb622b8ca81b4614c64d95e7590d6e0571d7d398b5ad595c1abc4412abe714" +checksum = "d0f5efe734b6cbe5f97ea769be8360df5324fade396f1f3f5ad7fe9360ca4a23" dependencies = [ "ahash", "bitflags 2.4.2", @@ -1587,6 +1562,7 @@ dependencies = [ "num-traits", "once_cell", "polars-arrow", + "polars-compute", "polars-error", "polars-row", "polars-utils", @@ -1602,12 +1578,11 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6480520ebde0b20935b600483b865513891e36c04942cebdd19e4f338257b4" +checksum = "6396de788f99ebfc9968e7b6f523e23000506cde4ba6dfc62ae4ce949002a886" dependencies = [ "arrow-format", - "parquet2", "regex", "simdutf8", "thiserror", @@ -1615,20 +1590,19 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "666466a3b151047c76d99b4e4e5f5438895ef97848008cf49b06df8e3d2d395a" +checksum = "7d0458efe8946f4718fd352f230c0db5a37926bd0d2bd25af79dc24746abaaea" dependencies = [ "ahash", "async-trait", + "atoi_simd", "bytes", "fast-float", 
"flate2", "futures", "home", "itoa", - "lexical", - "lexical-core", "memchr", "memmap2", "num-traits", @@ -1638,6 +1612,7 @@ dependencies = [ "polars-core", "polars-error", "polars-json", + "polars-parquet", "polars-utils", "rayon", "regex", @@ -1648,13 +1623,14 @@ dependencies = [ "smartstring", "tokio", "tokio-util", + "zstd", ] [[package]] name = "polars-json" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24451d2647a9bd51283cc946509c23bac27130565daa5103a156c8507b85b5a3" +checksum = "ea47d46b7a98fa683ef235ad48b783abf61734828e754096cfbdc77404fff9b3" dependencies = [ "ahash", "chrono", @@ -1673,9 +1649,9 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e1c2da1ca20106f80d9510090344e7311fd1dcfd6e6b65031e10606c0958c7" +checksum = "9d7105b40905bb38e8fc4a7fd736594b7491baa12fad3ac492969ca221a1b5d5" dependencies = [ "ahash", "bitflags 2.4.2", @@ -1697,9 +1673,9 @@ dependencies = [ [[package]] name = "polars-ops" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fe2d37a6a3ef358499d43aecee80740e62dd44e6cfe7a9c4aa0b8db88de8292" +checksum = "2e09afc456ab11e75e5dcb43e00a01c71f3a46a2781e450054acb6bb096ca78e" dependencies = [ "ahash", "argminmax", @@ -1710,6 +1686,7 @@ dependencies = [ "memchr", "num-traits", "polars-arrow", + "polars-compute", "polars-core", "polars-error", "polars-utils", @@ -1719,11 +1696,37 @@ dependencies = [ "version_check", ] +[[package]] +name = "polars-parquet" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ba24d67b1f64ab85143033dd46fa090b13c0f74acdf91b0780c16aecf005e3d" +dependencies = [ + "ahash", + "async-stream", + "base64", + "brotli", + "ethnum", + "flate2", + "futures", + "lz4", + "num-traits", + "parquet-format-safe", + 
"polars-arrow", + "polars-error", + "polars-utils", + "seq-macro", + "simdutf8", + "snap", + "streaming-decompression", + "zstd", +] + [[package]] name = "polars-pipe" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6aa050d529be01617f54bc60658149da76f97dbea9fdac3c9d60b811f64a2ba" +checksum = "d9b7ead073cc3917027d77b59861a9f071db47125de9314f8907db1a0a3e4100" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -1731,6 +1734,7 @@ dependencies = [ "hashbrown 0.14.3", "num-traits", "polars-arrow", + "polars-compute", "polars-core", "polars-io", "polars-ops", @@ -1744,9 +1748,9 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c47e5d62d8f612aab61a6331d04c5c95c9ff301106d8b91131c8833b4ef3def6" +checksum = "384a175624d050c31c473ee11df9d7af5d729ae626375e522158cfb3d150acd0" dependencies = [ "ahash", "bytemuck", @@ -1755,7 +1759,9 @@ dependencies = [ "polars-arrow", "polars-core", "polars-io", + "polars-json", "polars-ops", + "polars-parquet", "polars-time", "polars-utils", "rayon", @@ -1767,9 +1773,9 @@ dependencies = [ [[package]] name = "polars-row" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f05d6544f7d6065fcaa93bc69aac0532ce09aab4f81ec03c9a78dd901bb0c05b" +checksum = "32322f7acbb83db3e9c7697dc821be73d06238da89c817dcc8bc1549a5e9c72f" dependencies = [ "polars-arrow", "polars-error", @@ -1778,9 +1784,9 @@ dependencies = [ [[package]] name = "polars-sql" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77f65f9c8bfe7f0b2c08c38c79b92ec4ddaf213fc424d94a6272ed7b2d83987f" +checksum = "9f0b4c6ddffdfd0453e84bc3918572c633014d661d166654399cf93752aa95b5" dependencies = [ "polars-arrow", "polars-core", @@ -1795,9 +1801,9 @@ dependencies = [ 
[[package]] name = "polars-time" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3763af36aeeb85ef083f11c43bc28c5b6222e2aae039c5118d916bc855f2b5b9" +checksum = "dee2649fc96bd1b6584e0e4a4b3ca7d22ed3d117a990e63ad438ecb26f7544d0" dependencies = [ "atoi", "chrono", @@ -1814,13 +1820,14 @@ dependencies = [ [[package]] name = "polars-utils" -version = "0.34.2" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55d2c038ff67e4eb6019682c3f66d83f744e285de9c28e816109a61bace824cd" +checksum = "b174ca4a77ad47d7b91a0460aaae65bbf874c8bfbaaa5308675dadef3976bbda" dependencies = [ "ahash", "bytemuck", "hashbrown 0.14.3", + "indexmap", "num-traits", "once_cell", "polars-error", @@ -1964,6 +1971,26 @@ dependencies = [ "thiserror", ] +[[package]] +name = "ref-cast" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4846d4c50d1721b1a3bef8af76924eef20d5e723647333798c1b519b3a9473f" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fddb4f8d99b0a2ebafc65a87a69a7b9875e4b1ae1f00db265d300ef7f28bccc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "regex" version = "1.10.2" @@ -1999,15 +2026,6 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - [[package]] name = "rustix" version = "0.38.30" @@ -2029,9 +2047,9 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "rv" -version = 
"0.16.2" +version = "0.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ad4c180111696893363158f59ee1b15500947365270b7460a432fc207f926bf" +checksum = "35f602941aca67593b30eea71a0b372e50e3ad63e7aa6b98b2ea18ff74ba9cf8" dependencies = [ "doc-comment", "lru", @@ -2059,12 +2077,6 @@ dependencies = [ "bytemuck", ] -[[package]] -name = "semver" -version = "1.0.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" - [[package]] name = "seq-macro" version = "0.3.5" @@ -2102,15 +2114,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_spanned" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" -dependencies = [ - "serde", -] - [[package]] name = "serde_yaml" version = "0.9.30" @@ -2139,15 +2142,16 @@ dependencies = [ [[package]] name = "simd-json" -version = "0.12.0" +version = "0.13.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f07a84c7456b901b8dd2c1d44caca8b0fd2c2616206ee5acc9d9da61e8d9ec" +checksum = "2faf8f101b9bc484337a6a6b0409cf76c139f2fb70a9e3aee6b6774be7bfbf76" dependencies = [ "ahash", "getrandom", "halfbrown", "lexical-core", "once_cell", + "ref-cast", "serde", "serde_json", "simdutf8", @@ -2207,9 +2211,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.38.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0272b7bb0a225320170c99901b4b5fb3a4384e255a7f2cc228f61e2ba3893e75" +checksum = "743b4dc2cbde11890ccb254a8fc9d537fa41b36da00de2a1c5e9848c9bc42bd7" dependencies = [ "log", ] @@ -2284,16 +2288,16 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.29.11" +version = "0.30.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"cd727fc423c2060f6c92d9534cef765c65a6ed3f428a03d7def74a8c4348e666" +checksum = "1fb4f3438c8f6389c864e61221cbc97e9bca98b4daf39a5beb7bea660f528bb2" dependencies = [ "cfg-if", "core-foundation-sys", "libc", "ntapi", "once_cell", - "winapi", + "windows", ] [[package]] @@ -2382,40 +2386,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "toml" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" -dependencies = [ - "serde", -] - -[[package]] -name = "toml_edit" -version = "0.19.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" -dependencies = [ - "indexmap", - "serde", - "serde_spanned", - "toml_datetime", - "winnow", -] - [[package]] name = "typenum" version = "1.17.0" @@ -2448,9 +2418,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "value-trait" -version = "0.6.1" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a5b6c8ceb01263b969cac48d4a6705134d490ded13d889e52c0cfc80c6945e" +checksum = "dad8db98c1e677797df21ba03fca7d3bf9bec3ca38db930954e4fe6e1ea27eb4" dependencies = [ "float-cmp", "halfbrown", @@ -2565,6 +2535,16 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" 
+dependencies = [ + "windows-core", + "windows-targets 0.52.0", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -2706,15 +2686,6 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" -[[package]] -name = "winnow" -version = "0.5.34" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7cf47b659b318dccbd69cc4797a39ae128f533dce7902a1096044d1967b9c16" -dependencies = [ - "memchr", -] - [[package]] name = "xxhash-rust" version = "0.8.8" @@ -2741,32 +2712,13 @@ dependencies = [ "syn 2.0.48", ] -[[package]] -name = "zstd" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" -dependencies = [ - "zstd-safe 6.0.6", -] - [[package]] name = "zstd" version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" dependencies = [ - "zstd-safe 7.0.0", -] - -[[package]] -name = "zstd-safe" -version = "6.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" -dependencies = [ - "libc", - "zstd-sys", + "zstd-safe", ] [[package]] diff --git a/lace/Cargo.lock b/lace/Cargo.lock index 29d0abac..1112966e 100644 --- a/lace/Cargo.lock +++ b/lace/Cargo.lock @@ -2412,9 +2412,9 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "rv" -version = "0.16.2" +version = "0.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ad4c180111696893363158f59ee1b15500947365270b7460a432fc207f926bf" +checksum = "35f602941aca67593b30eea71a0b372e50e3ad63e7aa6b98b2ea18ff74ba9cf8" dependencies = [ "doc-comment", "lru", diff --git a/lace/lace_consts/Cargo.toml 
b/lace/lace_consts/Cargo.toml index 2bc7773a..8134be48 100644 --- a/lace/lace_consts/Cargo.toml +++ b/lace/lace_consts/Cargo.toml @@ -9,4 +9,4 @@ repository = "https://github.com/promised-ai/lace" description = "Default constants for Lace" [dependencies] -rv = { version = "0.16.2", features = ["serde1", "arraydist"] } +rv = { version = "0.16.3", features = ["serde1", "arraydist"] } diff --git a/lace/src/interface/mod.rs b/lace/src/interface/mod.rs index b2303bd6..71939d75 100644 --- a/lace/src/interface/mod.rs +++ b/lace/src/interface/mod.rs @@ -15,7 +15,7 @@ pub use oracle::utils; pub use oracle::{ ConditionalEntropyType, DatalessOracle, MiComponents, MiType, Oracle, - OracleT, RowSimilarityVariant, + OracleT, RowSimilarityVariant, Variability, }; pub use given::Given; diff --git a/lace/src/interface/oracle/error.rs b/lace/src/interface/oracle/error.rs index e4538577..e2e17bc7 100644 --- a/lace/src/interface/oracle/error.rs +++ b/lace/src/interface/oracle/error.rs @@ -179,6 +179,17 @@ pub enum PredictError { GivenError(#[from] GivenError), } +/// Describes errors that can occur from bad inputs to `Oracle::variability` +#[derive(Debug, Clone, PartialEq, Error)] +pub enum VariabilityError { + /// The target column index is out of bounds + #[error("Target index error in variability query: {0}")] + IndexError(#[from] IndexError), + /// The Given is invalid + #[error("Invalid variability 'given' argument: {0}")] + GivenError(#[from] GivenError), +} + /// Describes errors that arise from invalid predict uncertainty arguments #[derive(Debug, Clone, PartialEq, Error)] pub enum PredictUncertaintyError { @@ -192,7 +203,7 @@ pub enum PredictUncertaintyError { /// Describes errors from incompatible `col_max_logp` caches #[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum ColumnMaxiumLogPError { +pub enum ColumnMaximumLogPError { /// The state indices used to compute the cache do not match those passed to the function. 
#[error("The state indices used to compute the cache do not match those passed to the function.")] InvalidStateIndices, @@ -247,7 +258,7 @@ pub enum LogpError { #[error("Invalid logp 'given' argument: {0}")] GivenError(#[from] GivenError), #[error("Invalid `col_max_logps` argument: {0}")] - ColumnMaxiumLogPError(#[from] ColumnMaxiumLogPError), + ColumnMaximumLogPError(#[from] ColumnMaximumLogPError), } /// Describes errors from bad inputs to Oracle::simulate diff --git a/lace/src/interface/oracle/mod.rs b/lace/src/interface/oracle/mod.rs index 473bf335..b9116ac7 100644 --- a/lace/src/interface/oracle/mod.rs +++ b/lace/src/interface/oracle/mod.rs @@ -5,7 +5,7 @@ pub mod utils; mod validation; pub use dataless::DatalessOracle; -pub use traits::OracleT; +pub use traits::{OracleT, Variability}; use std::path::Path; diff --git a/lace/src/interface/oracle/traits.rs b/lace/src/interface/oracle/traits.rs index 4ddab9fc..11e5772b 100644 --- a/lace/src/interface/oracle/traits.rs +++ b/lace/src/interface/oracle/traits.rs @@ -16,6 +16,7 @@ use lace_stats::rv::traits::Rv; use lace_stats::SampleError; use rand::Rng; use rayon::prelude::*; +use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; macro_rules! col_indices_ok { @@ -41,6 +42,25 @@ macro_rules! 
state_indices_ok { }} } +/// Represents different formalizations of variability in distributions +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Variability { + /// The variance of a univariate distribution + Variance(f64), + /// The entropy of a distribution + Entropy(f64), +} + +impl From<Variability> for f64 { + fn from(value: Variability) -> Self { + match value { + Variability::Variance(x) => x, + Variability::Entropy(x) => x, + } + } +} + pub trait OracleT: CanOracle { /// Returns the diagnostics for each state fn state_diagnostics(&self) -> Vec { @@ -2046,6 +2066,125 @@ pub trait OracleT: CanOracle { } } + /// Compute the variability of a conditional distribution + /// + /// # Notes + /// - Returns variance for Continuous and Count columns + /// - Returns Entropy for Categorical columns + /// + /// # Arguments + /// - col_ix: the index of the column for which to compute the variability + /// - given: optional observations by which to constrain the prediction + /// - state_ixs_opt: Optional vector of state indices from which to compute, + /// if None, use all states. 
+ fn variability( + &self, + col_ix: Ix, + given: &Given, + state_ixs_opt: Option<&[usize]>, + ) -> Result { + use crate::stats::rv::traits::{Entropy, Variance}; + use crate::stats::MixtureType; + + let states: Vec<&State> = if let Some(state_ixs) = state_ixs_opt { + state_ixs.iter().map(|&ix| &self.states()[ix]).collect() + } else { + self.states().iter().collect() + }; + + let given = + given.clone().canonical(self.codebook()).map_err(|err| { + error::VariabilityError::GivenError( + error::GivenError::IndexError(err), + ) + })?; + + let col_ix = col_ix.col_ix(self.codebook())?; + + // Get the mixture weights for each state + let mut mixture_types: Vec = states + .iter() + .map(|state| { + let view_ix = state.asgn.asgn[col_ix]; + let weights = + &utils::given_weights(&[state], &[col_ix], &given)[0]; + + // combine the state weights with the given weights + let mut mm_weights: Vec = state.views[view_ix] + .weights + .iter() + .zip(weights[&view_ix].iter()) + .map(|(&w1, &w2)| w1 + w2) + .collect(); + + let z: f64 = logsumexp(&mm_weights); + mm_weights.iter_mut().for_each(|w| *w = (*w - z).exp()); + + state.views[view_ix].ftrs[&col_ix].to_mixture(mm_weights) + }) + .collect(); + + enum MType { + Gaussian, + Categorical, + Count, + Unsupported, + } + + let mtype = match mixture_types[0] { + MixtureType::Gaussian(_) => MType::Gaussian, + MixtureType::Poisson(_) => MType::Count, + MixtureType::Categorical(_) => MType::Categorical, + _ => MType::Unsupported, + }; + + match mtype { + MType::Gaussian => { + let mms: Vec<_> = mixture_types + .drain(..) + .map(|mt| { + if let MixtureType::Gaussian(mm) = mt { + mm + } else { + panic!("Expected Gaussian Mixture Type") + } + }) + .collect(); + let mm = Mixture::combine(mms); + Ok(Variability::Variance(mm.variance().unwrap())) + } + MType::Count => { + let mms: Vec<_> = mixture_types + .drain(..) 
+ .map(|mt| { + if let MixtureType::Poisson(mm) = mt { + mm + } else { + panic!("Expected Poisson Mixture Type") + } + }) + .collect(); + let mm = Mixture::combine(mms); + Ok(Variability::Variance(mm.variance().unwrap())) + } + MType::Categorical => { + let mms: Vec<_> = mixture_types + .drain(..) + .map(|mt| { + if let MixtureType::Categorical(mm) = mt { + mm + } else { + panic!("Expected Categorical Mixture Type") + } + }) + .collect(); + let mm = Mixture::combine(mms); + Ok(Variability::Entropy(mm.entropy())) + } + _ => panic!("Unsupported MType"), + } + } + /// Compute the error between the observed data in a feature and the feature /// model. /// diff --git a/lace/src/prelude.rs b/lace/src/prelude.rs index acc597b7..e6ea31b2 100644 --- a/lace/src/prelude.rs +++ b/lace/src/prelude.rs @@ -6,6 +6,8 @@ pub use crate::{ RowSimilarityVariant, SupportExtension, Value, WriteMode, }; +pub use crate::interface::Variability; + pub use crate::data::DataSource; pub use lace_cc::{ diff --git a/pylace/Cargo.lock b/pylace/Cargo.lock index 04da8564..3892fd60 100644 --- a/pylace/Cargo.lock +++ b/pylace/Cargo.lock @@ -1562,9 +1562,9 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "rv" -version = "0.16.2" +version = "0.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ad4c180111696893363158f59ee1b15500947365270b7460a432fc207f926bf" +checksum = "35f602941aca67593b30eea71a0b372e50e3ad63e7aa6b98b2ea18ff74ba9cf8" dependencies = [ "doc-comment", "lru", diff --git a/pylace/README.md b/pylace/README.md index 970d8470..88c2b58b 100644 --- a/pylace/README.md +++ b/pylace/README.md @@ -44,3 +44,23 @@ engine.update(10_000) engine.predict('Class_of_Orbit', given={'Period_minutes': 1436.0}) # ('GEO', 0.13583714831550336) ``` + +## Tests + +To run tests, use `pytest` + +```console +$ pytest -x +``` + +To run doctests: + +```console +$ python tests/test_docs.py +``` + +To prevent plotly from displaying figures while the doctests run: + 
+```console +$ LACE_DOCTEST_NOPLOT=1 python tests/test_docs.py +``` diff --git a/pylace/lace/analysis.py b/pylace/lace/analysis.py index 4d955660..4734b90f 100644 --- a/pylace/lace/analysis.py +++ b/pylace/lace/analysis.py @@ -302,7 +302,7 @@ def held_out_neglogp( │ ["Apogee_km"] ┆ 5.106627 ┆ 1 │ │ ["Apogee_km", "Eccentricity"] ┆ 2.951662 ┆ 2 │ │ ["Apogee_km", "Country_of_Operat… ┆ 2.951254 ┆ 3 │ - │ … ┆ … ┆ … │ + │ ["Apogee_km", "Country_of_Operat… ┆ 2.952801 ┆ 4 │ │ ["Apogee_km", "Country_of_Contra… ┆ 2.956224 ┆ 5 │ │ ["Apogee_km", "Country_of_Contra… ┆ 2.96479 ┆ 6 │ │ ["Apogee_km", "Country_of_Contra… ┆ 2.992173 ┆ 7 │ @@ -415,7 +415,7 @@ def held_out_inconsistency( │ ["Apogee_km"] ┆ 1.290609 ┆ 1 │ │ ["Apogee_km", "Eccentricity"] ┆ 0.74598 ┆ 2 │ │ ["Apogee_km", "Country_of_Operat… ┆ 0.745877 ┆ 3 │ - │ … ┆ … ┆ … │ + │ ["Apogee_km", "Country_of_Operat… ┆ 0.746268 ┆ 4 │ │ ["Apogee_km", "Country_of_Contra… ┆ 0.747133 ┆ 5 │ │ ["Apogee_km", "Country_of_Contra… ┆ 0.749297 ┆ 6 │ │ ["Apogee_km", "Country_of_Contra… ┆ 0.756218 ┆ 7 │ @@ -525,7 +525,7 @@ def held_out_uncertainty( │ ["Expected_Lifetime"] ┆ 0.437647 ┆ 1 │ │ ["Apogee_km", "Eccentricity"] ┆ 0.05561 ┆ 2 │ │ ["Apogee_km", "Country_of_Operat… ┆ 0.055283 ┆ 3 │ - │ … ┆ … ┆ … │ + │ ["Apogee_km", "Country_of_Operat… ┆ 0.056185 ┆ 4 │ │ ["Apogee_km", "Country_of_Operat… ┆ 0.057624 ┆ 5 │ │ ["Apogee_km", "Country_of_Contra… ┆ 0.0595 ┆ 6 │ │ ["Apogee_km", "Country_of_Contra… ┆ 0.077359 ┆ 7 │ @@ -945,15 +945,15 @@ def explain_prediction( │ --- ┆ --- │ │ str ┆ f64 │ ╞══════════════════════════════╪═════════════╡ - │ Country_of_Operator ┆ 3.5216e-16 │ - │ Users ┆ -3.1668e-14 │ - │ Purpose ┆ -9.5636e-14 │ - │ Class_of_Orbit ┆ -1.8263e-15 │ + │ Country_of_Operator ┆ 2.4617e-16 │ + │ Users ┆ -2.1412e-15 │ + │ Purpose ┆ -8.0193e-15 │ + │ Class_of_Orbit ┆ -2.2727e-15 │ │ … ┆ … │ - │ Launch_Site ┆ -2.8416e-15 │ - │ Launch_Vehicle ┆ 1.0704e-14 │ - │ Source_Used_for_Orbital_Data ┆ -3.9301e-15 │ - │ Inclination_radians ┆ -9.6259e-15 │ 
+ │ Launch_Site ┆ -5.8214e-16 │ + │ Launch_Vehicle ┆ -9.6101e-16 │ + │ Source_Used_for_Orbital_Data ┆ -9.1997e-15 │ + │ Inclination_radians ┆ -1.5407e-15 │ └──────────────────────────────┴─────────────┘ Get the importances using the 'ablative-dist' method, which measures how @@ -975,7 +975,7 @@ def explain_prediction( │ Country_of_Operator ┆ -0.000109 │ │ Users ┆ 0.081289 │ │ Purpose ┆ 0.18938 │ - │ Class_of_Orbit ┆ 0.000133 │ + │ Class_of_Orbit ┆ 0.000119 │ │ … ┆ … │ │ Launch_Site ┆ 0.003411 │ │ Launch_Vehicle ┆ -0.018817 │ @@ -994,9 +994,3 @@ def explain_prediction( raise ValueError( f"Invalid method `{method}`, valid methods are {PRED_EXPLAIN_METHODS}" ) - - -if __name__ == "__main__": - import doctest - - doctest.testmod() diff --git a/pylace/lace/engine.py b/pylace/lace/engine.py index 88c64a02..ffee2d7f 100644 --- a/pylace/lace/engine.py +++ b/pylace/lace/engine.py @@ -1234,7 +1234,8 @@ def logp( ... xaxis_title='Period_minutes', ... yaxis_title='f(Period)', ... ) \ - ... .show() + ... .show() # doctest: +ELLIPSIS + {...} """ srs = ( self.engine.logp_scaled(values, given, state_ixs) @@ -1746,6 +1747,57 @@ def predict( """ return self.engine.predict(target, given, state_ixs, with_uncertainty) + def variability( + self, + target: Union[str, int], + given: Optional[Dict[Union[str, int], object]] = None, + state_ixs: Optional[List[int]] = None, + ): + """ + Return the variability of a conditional distribution. + + "Variability" is variance for target types with defined mean and + variance and is entropy otherwise. + + Parameters + ---------- + target: column index + The column for which to return the variability + given: Dict[column index, value], optional + Column -> Value dictionary describing observations. Note that + columns can either be indices (int) or names (str) + state_ixs: List[int], optional + An optional list specifying which states should be used in the + computation. If `None` (default), use all states. 
+ + Returns + ------- + float + The variance or entropy (for categorical targets) + + Examples + -------- + Compute the variance of the Period_minutes column unconditioned + + >>> from lace.examples import Satellites + >>> sats = Satellites() + >>> sats.variability("Period_minutes") + 691324.3941953736 + + Compute the variance of Period_minutes for geosynchronous satellites + + >>> sats.variability("Period_minutes", given={"Class_of_Orbit": "GEO"}) + 136818.61181890886 + + Compute the entropy of Class_of_Orbit + + >>> sats.variability("Class_of_Orbit") + 0.9362550555890782 + >>> sats.variability("Class_of_Orbit", given={"Period_minutes": 1440.0}) + 0.01569677151657056 + """ + return self.engine.variability(target, given, state_ixs) + def impute( self, col: Union[str, int], @@ -2143,7 +2195,7 @@ def pairwise_fn(self, fn_name, indices: Optional[list] = None, **kwargs): │ wolf ┆ rat ┆ 0.71689 │ │ wolf ┆ otter ┆ 0.492262 │ │ rat ┆ wolf ┆ 0.71689 │ - │ … ┆ … ┆ … │ + │ rat ┆ rat ┆ 1.0 │ │ rat ┆ otter ┆ 0.613095 │ │ otter ┆ wolf ┆ 0.492262 │ │ otter ┆ rat ┆ 0.613095 │ @@ -2167,7 +2219,7 @@ def pairwise_fn(self, fn_name, indices: Optional[list] = None, **kwargs): │ wolf ┆ rat ┆ 0.642647 │ │ wolf ┆ otter ┆ 0.302206 │ │ rat ┆ wolf ┆ 0.642647 │ - │ … ┆ … ┆ … │ + │ rat ┆ rat ┆ 1.0 │ │ rat ┆ otter ┆ 0.491176 │ │ otter ┆ wolf ┆ 0.302206 │ │ otter ┆ rat ┆ 0.491176 │ @@ -2251,7 +2303,8 @@ def clustermap( >>> animals = Animals() >>> animals.clustermap( ... "depprob", zmin=0, zmax=1, color_continuous_scale="greys" - ... ).figure.show() + ... ).figure.show() # doctest:+ELLIPSIS + {...} Use the ``fn_kwargs`` keyword argument to pass keyword arguments to the target function. @@ -2262,7 +2315,8 @@ def clustermap( ... zmax=1, ... color_continuous_scale="greys", ... fn_kwargs={"wrt": ["swims"]}, - ... ).figure.show() + ... 
).figure.show() # doctest:+ELLIPSIS + {...} """ if fn_kwargs is None: fn_kwargs = {} @@ -2282,9 +2336,3 @@ def clustermap( return ClusterMap(df, linkage, fig) else: return ClusterMap(df, linkage) - - -if __name__ == "__main__": - import doctest - - doctest.testmod() diff --git a/pylace/lace/plot.py b/pylace/lace/plot.py index 0bc431e8..ef93e87d 100644 --- a/pylace/lace/plot.py +++ b/pylace/lace/plot.py @@ -42,7 +42,8 @@ def diagnostics( >>> from lace.examples import Satellites >>> from lace.plot import diagnostics - >>> diagnostics(Satellites(), log_x=True).show() + >>> diagnostics(Satellites(), log_x=True).show() # doctest:+ELLIPSIS + {...} """ diag = engine.diagnostics(name) step = np.arange(diag.shape[0]) @@ -122,7 +123,8 @@ def prediction_uncertainty( ... "Period_minutes", ... given={"Class_of_Orbit": "GEO"}, ... ) - >>> fig.show() + >>> fig.show() # doctest:+ELLIPSIS + {...} Narrow down the range for visualization @@ -134,7 +136,8 @@ def prediction_uncertainty( ... given={"Class_of_Orbit": "GEO"}, ... xs=pl.Series("Period_minutes", np.linspace(1350, 1500, 500)), ... ) - >>> fig.show() + >>> fig.show() # doctest:+ELLIPSIS + {...} Visualize uncertainty for a categorical target @@ -143,7 +146,8 @@ def prediction_uncertainty( ... "Class_of_Orbit", ... given={"Period_minutes": 1326.0}, ... ) - >>> fig.show() + >>> fig.show() # doctest:+ELLIPSIS + {...} """ pred, unc = engine.predict(target, given=given) @@ -579,7 +583,8 @@ def prediction_explanation( ... given, ... method='ablative-dist' ... 
) - >>> fig.show() + >>> fig.show() # doctest:+ELLIPSIS + {...} """ if method is None: method = "ablative-err" @@ -622,9 +627,3 @@ def prediction_explanation( ) return srs, fig - - -if __name__ == "__main__": - import doctest - - doctest.testmod() diff --git a/pylace/src/lib.rs b/pylace/src/lib.rs index 5dacf03e..0c07c8e9 100644 --- a/pylace/src/lib.rs +++ b/pylace/src/lib.rs @@ -953,6 +953,22 @@ impl CoreEngine { } } + #[pyo3(signature=(target, given=None, state_ixs=None))] + fn variability( + &self, + target: &PyAny, + given: Option<&PyDict>, + state_ixs: Option>, + ) -> PyResult { + let col_ix = value_to_index(target, &self.col_indexer)?; + let given = dict_to_given(given, &self.engine, &self.col_indexer)?; + let val = self + .engine + .variability(col_ix, &given, state_ixs.as_deref()) + .map_err(|err| PyErr::new::(format!("{err}")))?; + Ok(val.into()) + } + /// Forward the Markov chains /// /// Parameters diff --git a/pylace/tests/test_docs.py b/pylace/tests/test_docs.py new file mode 100644 index 00000000..2a5b4d83 --- /dev/null +++ b/pylace/tests/test_docs.py @@ -0,0 +1,37 @@ +import doctest +import os +from importlib import import_module + +import polars as pl +from plotly import io + +NOPLOT = os.environ.get("LACE_DOCTEST_NOPLOT", "0") == "1" + + +if NOPLOT: + io.renderers.default = "json" + + +pl.Config(tbl_rows=8) + + +class _Context(dict): + def clear(self): + pass + + def copy(self): + return self + + +def runtest(mod): + module = import_module(mod) + extraglobs = _Context(module.__dict__.copy()) + doctest.testmod(module, extraglobs=extraglobs) + + +if __name__ == "__main__": + runtest("lace.engine") + runtest("lace.analysis") + + if not NOPLOT: + runtest("lace.plot")