From f8200170af125c15a960e19e2ebe01d91469034d Mon Sep 17 00:00:00 2001 From: Michael T Schmidt Date: Tue, 13 Feb 2024 10:35:56 -0700 Subject: [PATCH] feat(python): Added support for pickling lace.Engine (#184) * feat(python): Added support for pickling lace.Engine * fix(pylace): Removed unnecessary expect. --- pylace/Cargo.lock | 10 +++++---- pylace/Cargo.toml | 2 ++ pylace/src/lib.rs | 43 +++++++++++++++++++++++++++++++++---- pylace/src/utils.rs | 3 ++- pylace/tests/test_pickle.py | 14 ++++++++++++ 5 files changed, 63 insertions(+), 9 deletions(-) create mode 100644 pylace/tests/test_pickle.py diff --git a/pylace/Cargo.lock b/pylace/Cargo.lock index 2f8e8ed9..c0154326 100644 --- a/pylace/Cargo.lock +++ b/pylace/Cargo.lock @@ -1347,6 +1347,7 @@ checksum = "fe7765e19fb2ba6fd4373b8d90399f5321683ea7c11b598c6bbaa3a72e9c83b8" name = "pylace" version = "0.7.0" dependencies = [ + "bincode", "lace", "lace_utils", "polars", @@ -1354,6 +1355,7 @@ dependencies = [ "pyo3", "rand", "rand_xoshiro", + "serde", "serde_json", "serde_yaml", ] @@ -1601,18 +1603,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.195" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.195" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", diff --git a/pylace/Cargo.toml b/pylace/Cargo.toml index cf79975c..cd49a908 100644 --- a/pylace/Cargo.toml +++ b/pylace/Cargo.toml @@ -18,6 +18,8 @@ serde_json = "1.0.91" serde_yaml = "0.9.17" polars = "0.36" polars-arrow = "0.36.2" +serde = { version = "1.0.196", features = ["derive"] } +bincode = "1.3.3" [package.metadata.maturin] name = "lace.core" diff --git a/pylace/src/lib.rs b/pylace/src/lib.rs index fe65a227..3b010c6e 100644 --- a/pylace/src/lib.rs +++ b/pylace/src/lib.rs @@ -15,19 +15,21 @@ use lace::metadata::SerializedType; use lace::prelude::ColMetadataList; use lace::{EngineUpdateConfig, FType, HasStates, OracleT}; use polars::prelude::{DataFrame, NamedFrom, Series}; +use pyo3::create_exception; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; -use pyo3::types::{PyDict, PyList, PyType}; -use pyo3::{create_exception, prelude::*}; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyDict, PyList, PyType}; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; +use serde::{Deserialize, Serialize}; use metadata::{Codebook, CodebookBuilder}; use crate::update_handler::PyUpdateHandler; use crate::utils::*; -#[derive(Clone)] -#[pyclass(subclass)] +#[derive(Clone, Serialize, Deserialize)] +#[pyclass(subclass, module = "lace.core")] struct CoreEngine { engine: lace::Engine, col_indexer: Indexer, @@ -1320,6 +1322,39 @@ impl CoreEngine { }) }) } + + pub fn __setstate__( + &mut self, + py: Python, + state: PyObject, + ) -> PyResult<()> { + let s = state.extract::<&PyBytes>(py)?; + *self = bincode::deserialize(s.as_bytes()).map_err(|e| { + PyValueError::new_err(format!("Cannot Deserialize CoreEngine: {e}")) + })?; + Ok(()) + } + + pub fn __getstate__(&self, py: Python) -> PyResult { + Ok(PyBytes::new( + py, + &bincode::serialize(&self).map_err(|e| { + PyValueError::new_err(format!( + "Cannot Serialize CoreEngine: {e}" + )) + })?, + ) + .to_object(py)) + } + + pub fn __getnewargs__(&self) -> PyResult<(PyDataFrame,)> { + Ok((PyDataFrame( + polars::df! { + "ID" => [0], + } + .map_err(|e| PyValueError::new_err(format!("Polars error: {e}")))?, + ),)) + } } #[pyfunction] diff --git a/pylace/src/utils.rs b/pylace/src/utils.rs index 55e3bc1f..033dcd94 100644 --- a/pylace/src/utils.rs +++ b/pylace/src/utils.rs @@ -15,6 +15,7 @@ use pyo3::prelude::*; use pyo3::types::{ PyAny, PyBool, PyDict, PyInt, PyList, PySlice, PyString, PyTuple, }; +use serde::{Deserialize, Serialize}; use crate::df::{PyDataFrame, PySeries}; @@ -381,7 +382,7 @@ pub(crate) fn str_to_mitype(mi_type: &str) -> PyResult { } } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] pub(crate) struct Indexer { pub to_ix: HashMap, pub to_name: HashMap, diff --git a/pylace/tests/test_pickle.py b/pylace/tests/test_pickle.py new file mode 100644 index 00000000..b23d941f --- /dev/null +++ b/pylace/tests/test_pickle.py @@ -0,0 +1,14 @@ +import pickle + +from lace import examples + + +def test_pickle_engine(): + engine = examples.Animals().engine + s = pickle.dumps(engine) + engine_b = pickle.loads(s) + + sim_a = engine.simulate(["swims", "flys"], n=10) + sim_b = engine_b.simulate(["swims", "flys"], n=10) + + assert sim_a.equals(sim_b)