Skip to content

Commit

Permalink
forward_hessian_nograd_sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Mar 9, 2024
1 parent ffd3802 commit 00589f2
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 15 deletions.
77 changes: 62 additions & 15 deletions python/finitediff-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ use numpy::{ndarray::Array1, IntoPyArray, PyArray1, PyArray2};
use pyo3::{
exceptions::PyTypeError,
prelude::*,
types::{PyCFunction, PyDict, PyTuple},
types::{PyCFunction, PyDict, PyList, PyTuple},
PyNativeType, PyTypeInfo,
};

fn process_args(args: &PyTuple, idx: usize) -> PyResult<Array1<f64>> {
fn process_arg<T: PyTypeInfo + PyNativeType>(args: &PyTuple, idx: usize) -> PyResult<&T> {
Ok(args
.get_item(idx)
.map_err(|_| PyErr::new::<PyTypeError, _>("Insufficient number of arguments"))?
.downcast::<PyArray1<f64>>()?
.to_owned_array())
.downcast::<T>()?)
}

macro_rules! not_callable {
Expand All @@ -37,7 +37,9 @@ fn forward_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
let out = (ndarr::forward_diff(&|x: &Array1<f64>| -> Result<f64, Error> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?.extract(py)?)
}))(&process_args(args, 0)?)?;
}))(
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
)?;
Ok(out.into_pyarray(py).into())
})
},
Expand All @@ -60,7 +62,9 @@ fn central_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
let out = (ndarr::central_diff(&|x: &Array1<f64>| -> Result<f64, Error> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?.extract(py)?)
}))(&process_args(args, 0)?)?;
}))(
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
)?;
Ok(out.into_pyarray(py).into())
})
},
Expand All @@ -87,7 +91,9 @@ fn forward_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunc
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
},
))(&process_args(args, 0)?)?;
))(
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
)?;
Ok(out.into_pyarray(py).into())
})
},
Expand All @@ -114,7 +120,9 @@ fn central_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunc
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
},
))(&process_args(args, 0)?)?;
))(
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
)?;
Ok(out.into_pyarray(py).into())
})
},
Expand Down Expand Up @@ -143,7 +151,8 @@ fn forward_jacobian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'p
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
}))(
&process_args(args, 0)?, &process_args(args, 1)?
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
&process_arg::<PyArray1<f64>>(args, 1)?.to_owned_array(),
)?;
Ok(out.into_pyarray(py).into())
})
Expand Down Expand Up @@ -173,7 +182,8 @@ fn central_jacobian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'p
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
}))(
&process_args(args, 0)?, &process_args(args, 1)?
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
&process_arg::<PyArray1<f64>>(args, 1)?.to_owned_array(),
)?;
Ok(out.into_pyarray(py).into())
})
Expand Down Expand Up @@ -201,7 +211,9 @@ fn forward_hessian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunct
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
},
))(&process_args(args, 0)?)?;
))(
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
)?;
Ok(out.into_pyarray(py).into())
})
},
Expand All @@ -228,7 +240,9 @@ fn central_hessian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunct
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
},
))(&process_args(args, 0)?)?;
))(
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
)?;
Ok(out.into_pyarray(py).into())
})
},
Expand Down Expand Up @@ -257,7 +271,8 @@ fn forward_hessian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
}))(
&process_args(args, 0)?, &process_args(args, 1)?
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
&process_arg::<PyArray1<f64>>(args, 1)?.to_owned_array(),
)?;
Ok(out.into_pyarray(py).into())
})
Expand Down Expand Up @@ -287,7 +302,8 @@ fn central_hessian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
}))(
&process_args(args, 0)?, &process_args(args, 1)?
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
&process_arg::<PyArray1<f64>>(args, 1)?.to_owned_array(),
)?;
Ok(out.into_pyarray(py).into())
})
Expand All @@ -313,7 +329,37 @@ fn forward_hessian_nograd<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py P
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?.extract::<f64>(py)?)
},
))(&process_args(args, 0)?)?;
))(
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
not_callable!(py, f)
}
}

/// Forward Hessian nograd sparse
#[pyfunction]
fn forward_hessian_nograd_sparse<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::forward_hessian_nograd_sparse(
&|x: &Array1<f64>| -> Result<f64, Error> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?.extract::<f64>(py)?)
},
))(
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
process_arg::<PyList>(args, 1)?.extract()?,
)?;
Ok(out.into_pyarray(py).into())
})
},
Expand All @@ -337,5 +383,6 @@ fn finitediff(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(forward_hessian_vec_prod, m)?)?;
m.add_function(wrap_pyfunction!(central_hessian_vec_prod, m)?)?;
m.add_function(wrap_pyfunction!(forward_hessian_nograd, m)?)?;
m.add_function(wrap_pyfunction!(forward_hessian_nograd_sparse, m)?)?;
Ok(())
}
8 changes: 8 additions & 0 deletions python/finitediff-py/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
forward_hessian_vec_prod,
central_hessian_vec_prod,
forward_hessian_nograd,
forward_hessian_nograd_sparse,
)
import numpy as np

Expand Down Expand Up @@ -108,6 +109,13 @@ def g(x):
h = forward_hessian_nograd(f)
x = np.array([1.0, 1.0, 1.0, 1.0])
print(h(x))

h = forward_hessian_nograd_sparse(f)
x = np.array([1.0, 1.0, 1.0, 1.0])
indices = [[1, 1], [2, 3], [3, 3]]
print(h(x, indices))


# class NotCallable:
# pass

Expand Down

0 comments on commit 00589f2

Please sign in to comment.