Skip to content

Commit 00589f2

Browse files
committed
forward_hessian_nograd_sparse
1 parent ffd3802 commit 00589f2

File tree

2 files changed

+70
-15
lines changed

2 files changed

+70
-15
lines changed

python/finitediff-py/src/lib.rs

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ use numpy::{ndarray::Array1, IntoPyArray, PyArray1, PyArray2};
44
use pyo3::{
55
exceptions::PyTypeError,
66
prelude::*,
7-
types::{PyCFunction, PyDict, PyTuple},
7+
types::{PyCFunction, PyDict, PyList, PyTuple},
8+
PyNativeType, PyTypeInfo,
89
};
910

10-
fn process_args(args: &PyTuple, idx: usize) -> PyResult<Array1<f64>> {
11+
fn process_arg<T: PyTypeInfo + PyNativeType>(args: &PyTuple, idx: usize) -> PyResult<&T> {
1112
Ok(args
1213
.get_item(idx)
1314
.map_err(|_| PyErr::new::<PyTypeError, _>("Insufficient number of arguments"))?
14-
.downcast::<PyArray1<f64>>()?
15-
.to_owned_array())
15+
.downcast::<T>()?)
1616
}
1717

1818
macro_rules! not_callable {
@@ -37,7 +37,9 @@ fn forward_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
3737
let out = (ndarr::forward_diff(&|x: &Array1<f64>| -> Result<f64, Error> {
3838
let x = PyArray1::from_array(py, x);
3939
Ok(f.call(py, (x,), None)?.extract(py)?)
40-
}))(&process_args(args, 0)?)?;
40+
}))(
41+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
42+
)?;
4143
Ok(out.into_pyarray(py).into())
4244
})
4345
},
@@ -60,7 +62,9 @@ fn central_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
6062
let out = (ndarr::central_diff(&|x: &Array1<f64>| -> Result<f64, Error> {
6163
let x = PyArray1::from_array(py, x);
6264
Ok(f.call(py, (x,), None)?.extract(py)?)
63-
}))(&process_args(args, 0)?)?;
65+
}))(
66+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
67+
)?;
6468
Ok(out.into_pyarray(py).into())
6569
})
6670
},
@@ -87,7 +91,9 @@ fn forward_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunc
8791
.extract::<&PyArray1<f64>>(py)?
8892
.to_owned_array())
8993
},
90-
))(&process_args(args, 0)?)?;
94+
))(
95+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
96+
)?;
9197
Ok(out.into_pyarray(py).into())
9298
})
9399
},
@@ -114,7 +120,9 @@ fn central_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunc
114120
.extract::<&PyArray1<f64>>(py)?
115121
.to_owned_array())
116122
},
117-
))(&process_args(args, 0)?)?;
123+
))(
124+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
125+
)?;
118126
Ok(out.into_pyarray(py).into())
119127
})
120128
},
@@ -143,7 +151,8 @@ fn forward_jacobian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'p
143151
.extract::<&PyArray1<f64>>(py)?
144152
.to_owned_array())
145153
}))(
146-
&process_args(args, 0)?, &process_args(args, 1)?
154+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
155+
&process_arg::<PyArray1<f64>>(args, 1)?.to_owned_array(),
147156
)?;
148157
Ok(out.into_pyarray(py).into())
149158
})
@@ -173,7 +182,8 @@ fn central_jacobian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'p
173182
.extract::<&PyArray1<f64>>(py)?
174183
.to_owned_array())
175184
}))(
176-
&process_args(args, 0)?, &process_args(args, 1)?
185+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
186+
&process_arg::<PyArray1<f64>>(args, 1)?.to_owned_array(),
177187
)?;
178188
Ok(out.into_pyarray(py).into())
179189
})
@@ -201,7 +211,9 @@ fn forward_hessian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunct
201211
.extract::<&PyArray1<f64>>(py)?
202212
.to_owned_array())
203213
},
204-
))(&process_args(args, 0)?)?;
214+
))(
215+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
216+
)?;
205217
Ok(out.into_pyarray(py).into())
206218
})
207219
},
@@ -228,7 +240,9 @@ fn central_hessian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunct
228240
.extract::<&PyArray1<f64>>(py)?
229241
.to_owned_array())
230242
},
231-
))(&process_args(args, 0)?)?;
243+
))(
244+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
245+
)?;
232246
Ok(out.into_pyarray(py).into())
233247
})
234248
},
@@ -257,7 +271,8 @@ fn forward_hessian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py
257271
.extract::<&PyArray1<f64>>(py)?
258272
.to_owned_array())
259273
}))(
260-
&process_args(args, 0)?, &process_args(args, 1)?
274+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
275+
&process_arg::<PyArray1<f64>>(args, 1)?.to_owned_array(),
261276
)?;
262277
Ok(out.into_pyarray(py).into())
263278
})
@@ -287,7 +302,8 @@ fn central_hessian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py
287302
.extract::<&PyArray1<f64>>(py)?
288303
.to_owned_array())
289304
}))(
290-
&process_args(args, 0)?, &process_args(args, 1)?
305+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
306+
&process_arg::<PyArray1<f64>>(args, 1)?.to_owned_array(),
291307
)?;
292308
Ok(out.into_pyarray(py).into())
293309
})
@@ -313,7 +329,37 @@ fn forward_hessian_nograd<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py P
313329
let x = PyArray1::from_array(py, x);
314330
Ok(f.call(py, (x,), None)?.extract::<f64>(py)?)
315331
},
316-
))(&process_args(args, 0)?)?;
332+
))(
333+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array()
334+
)?;
335+
Ok(out.into_pyarray(py).into())
336+
})
337+
},
338+
)
339+
} else {
340+
not_callable!(py, f)
341+
}
342+
}
343+
344+
/// Forward Hessian nograd sparse
345+
#[pyfunction]
346+
fn forward_hessian_nograd_sparse<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
347+
if f.as_ref(py).is_callable() {
348+
PyCFunction::new_closure(
349+
py,
350+
None,
351+
None,
352+
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
353+
Python::with_gil(|py| {
354+
let out = (ndarr::forward_hessian_nograd_sparse(
355+
&|x: &Array1<f64>| -> Result<f64, Error> {
356+
let x = PyArray1::from_array(py, x);
357+
Ok(f.call(py, (x,), None)?.extract::<f64>(py)?)
358+
},
359+
))(
360+
&process_arg::<PyArray1<f64>>(args, 0)?.to_owned_array(),
361+
process_arg::<PyList>(args, 1)?.extract()?,
362+
)?;
317363
Ok(out.into_pyarray(py).into())
318364
})
319365
},
@@ -337,5 +383,6 @@ fn finitediff(_py: Python, m: &PyModule) -> PyResult<()> {
337383
m.add_function(wrap_pyfunction!(forward_hessian_vec_prod, m)?)?;
338384
m.add_function(wrap_pyfunction!(central_hessian_vec_prod, m)?)?;
339385
m.add_function(wrap_pyfunction!(forward_hessian_nograd, m)?)?;
386+
m.add_function(wrap_pyfunction!(forward_hessian_nograd_sparse, m)?)?;
340387
Ok(())
341388
}

python/finitediff-py/test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
forward_hessian_vec_prod,
1111
central_hessian_vec_prod,
1212
forward_hessian_nograd,
13+
forward_hessian_nograd_sparse,
1314
)
1415
import numpy as np
1516

@@ -108,6 +109,13 @@ def g(x):
108109
h = forward_hessian_nograd(f)
109110
x = np.array([1.0, 1.0, 1.0, 1.0])
110111
print(h(x))
112+
113+
h = forward_hessian_nograd_sparse(f)
114+
x = np.array([1.0, 1.0, 1.0, 1.0])
115+
indices = [[1, 1], [2, 3], [3, 3]]
116+
print(h(x, indices))
117+
118+
111119
# class NotCallable:
112120
# pass
113121

0 commit comments

Comments
 (0)