@@ -4,15 +4,15 @@ use numpy::{ndarray::Array1, IntoPyArray, PyArray1, PyArray2};
4
4
use pyo3:: {
5
5
exceptions:: PyTypeError ,
6
6
prelude:: * ,
7
- types:: { PyCFunction , PyDict , PyTuple } ,
7
+ types:: { PyCFunction , PyDict , PyList , PyTuple } ,
8
+ PyNativeType , PyTypeInfo ,
8
9
} ;
9
10
10
- fn process_args ( args : & PyTuple , idx : usize ) -> PyResult < Array1 < f64 > > {
11
+ fn process_arg < T : PyTypeInfo + PyNativeType > ( args : & PyTuple , idx : usize ) -> PyResult < & T > {
11
12
Ok ( args
12
13
. get_item ( idx)
13
14
. map_err ( |_| PyErr :: new :: < PyTypeError , _ > ( "Insufficient number of arguments" ) ) ?
14
- . downcast :: < PyArray1 < f64 > > ( ) ?
15
- . to_owned_array ( ) )
15
+ . downcast :: < T > ( ) ?)
16
16
}
17
17
18
18
macro_rules! not_callable {
@@ -37,7 +37,9 @@ fn forward_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
37
37
let out = ( ndarr:: forward_diff ( & |x : & Array1 < f64 > | -> Result < f64 , Error > {
38
38
let x = PyArray1 :: from_array ( py, x) ;
39
39
Ok ( f. call ( py, ( x, ) , None ) ?. extract ( py) ?)
40
- } ) ) ( & process_args ( args, 0 ) ?) ?;
40
+ } ) ) (
41
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( )
42
+ ) ?;
41
43
Ok ( out. into_pyarray ( py) . into ( ) )
42
44
} )
43
45
} ,
@@ -60,7 +62,9 @@ fn central_diff<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction
60
62
let out = ( ndarr:: central_diff ( & |x : & Array1 < f64 > | -> Result < f64 , Error > {
61
63
let x = PyArray1 :: from_array ( py, x) ;
62
64
Ok ( f. call ( py, ( x, ) , None ) ?. extract ( py) ?)
63
- } ) ) ( & process_args ( args, 0 ) ?) ?;
65
+ } ) ) (
66
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( )
67
+ ) ?;
64
68
Ok ( out. into_pyarray ( py) . into ( ) )
65
69
} )
66
70
} ,
@@ -87,7 +91,9 @@ fn forward_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunc
87
91
. extract :: < & PyArray1 < f64 > > ( py) ?
88
92
. to_owned_array ( ) )
89
93
} ,
90
- ) ) ( & process_args ( args, 0 ) ?) ?;
94
+ ) ) (
95
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( )
96
+ ) ?;
91
97
Ok ( out. into_pyarray ( py) . into ( ) )
92
98
} )
93
99
} ,
@@ -114,7 +120,9 @@ fn central_jacobian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunc
114
120
. extract :: < & PyArray1 < f64 > > ( py) ?
115
121
. to_owned_array ( ) )
116
122
} ,
117
- ) ) ( & process_args ( args, 0 ) ?) ?;
123
+ ) ) (
124
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( )
125
+ ) ?;
118
126
Ok ( out. into_pyarray ( py) . into ( ) )
119
127
} )
120
128
} ,
@@ -143,7 +151,8 @@ fn forward_jacobian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'p
143
151
. extract :: < & PyArray1 < f64 > > ( py) ?
144
152
. to_owned_array ( ) )
145
153
} ) ) (
146
- & process_args ( args, 0 ) ?, & process_args ( args, 1 ) ?
154
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( ) ,
155
+ & process_arg :: < PyArray1 < f64 > > ( args, 1 ) ?. to_owned_array ( ) ,
147
156
) ?;
148
157
Ok ( out. into_pyarray ( py) . into ( ) )
149
158
} )
@@ -173,7 +182,8 @@ fn central_jacobian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'p
173
182
. extract :: < & PyArray1 < f64 > > ( py) ?
174
183
. to_owned_array ( ) )
175
184
} ) ) (
176
- & process_args ( args, 0 ) ?, & process_args ( args, 1 ) ?
185
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( ) ,
186
+ & process_arg :: < PyArray1 < f64 > > ( args, 1 ) ?. to_owned_array ( ) ,
177
187
) ?;
178
188
Ok ( out. into_pyarray ( py) . into ( ) )
179
189
} )
@@ -201,7 +211,9 @@ fn forward_hessian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunct
201
211
. extract :: < & PyArray1 < f64 > > ( py) ?
202
212
. to_owned_array ( ) )
203
213
} ,
204
- ) ) ( & process_args ( args, 0 ) ?) ?;
214
+ ) ) (
215
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( )
216
+ ) ?;
205
217
Ok ( out. into_pyarray ( py) . into ( ) )
206
218
} )
207
219
} ,
@@ -228,7 +240,9 @@ fn central_hessian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunct
228
240
. extract :: < & PyArray1 < f64 > > ( py) ?
229
241
. to_owned_array ( ) )
230
242
} ,
231
- ) ) ( & process_args ( args, 0 ) ?) ?;
243
+ ) ) (
244
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( )
245
+ ) ?;
232
246
Ok ( out. into_pyarray ( py) . into ( ) )
233
247
} )
234
248
} ,
@@ -257,7 +271,8 @@ fn forward_hessian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py
257
271
. extract :: < & PyArray1 < f64 > > ( py) ?
258
272
. to_owned_array ( ) )
259
273
} ) ) (
260
- & process_args ( args, 0 ) ?, & process_args ( args, 1 ) ?
274
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( ) ,
275
+ & process_arg :: < PyArray1 < f64 > > ( args, 1 ) ?. to_owned_array ( ) ,
261
276
) ?;
262
277
Ok ( out. into_pyarray ( py) . into ( ) )
263
278
} )
@@ -287,7 +302,8 @@ fn central_hessian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py
287
302
. extract :: < & PyArray1 < f64 > > ( py) ?
288
303
. to_owned_array ( ) )
289
304
} ) ) (
290
- & process_args ( args, 0 ) ?, & process_args ( args, 1 ) ?
305
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( ) ,
306
+ & process_arg :: < PyArray1 < f64 > > ( args, 1 ) ?. to_owned_array ( ) ,
291
307
) ?;
292
308
Ok ( out. into_pyarray ( py) . into ( ) )
293
309
} )
@@ -313,7 +329,37 @@ fn forward_hessian_nograd<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py P
313
329
let x = PyArray1 :: from_array ( py, x) ;
314
330
Ok ( f. call ( py, ( x, ) , None ) ?. extract :: < f64 > ( py) ?)
315
331
} ,
316
- ) ) ( & process_args ( args, 0 ) ?) ?;
332
+ ) ) (
333
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( )
334
+ ) ?;
335
+ Ok ( out. into_pyarray ( py) . into ( ) )
336
+ } )
337
+ } ,
338
+ )
339
+ } else {
340
+ not_callable ! ( py, f)
341
+ }
342
+ }
343
+
344
+ /// Forward Hessian nograd sparse
345
+ #[ pyfunction]
346
+ fn forward_hessian_nograd_sparse < ' py > ( py : Python < ' py > , f : Py < PyAny > ) -> PyResult < & ' py PyCFunction > {
347
+ if f. as_ref ( py) . is_callable ( ) {
348
+ PyCFunction :: new_closure (
349
+ py,
350
+ None ,
351
+ None ,
352
+ move |args : & PyTuple , _kwargs : Option < & PyDict > | -> PyResult < Py < PyArray2 < f64 > > > {
353
+ Python :: with_gil ( |py| {
354
+ let out = ( ndarr:: forward_hessian_nograd_sparse (
355
+ & |x : & Array1 < f64 > | -> Result < f64 , Error > {
356
+ let x = PyArray1 :: from_array ( py, x) ;
357
+ Ok ( f. call ( py, ( x, ) , None ) ?. extract :: < f64 > ( py) ?)
358
+ } ,
359
+ ) ) (
360
+ & process_arg :: < PyArray1 < f64 > > ( args, 0 ) ?. to_owned_array ( ) ,
361
+ process_arg :: < PyList > ( args, 1 ) ?. extract ( ) ?,
362
+ ) ?;
317
363
Ok ( out. into_pyarray ( py) . into ( ) )
318
364
} )
319
365
} ,
@@ -337,5 +383,6 @@ fn finitediff(_py: Python, m: &PyModule) -> PyResult<()> {
337
383
m. add_function ( wrap_pyfunction ! ( forward_hessian_vec_prod, m) ?) ?;
338
384
m. add_function ( wrap_pyfunction ! ( central_hessian_vec_prod, m) ?) ?;
339
385
m. add_function ( wrap_pyfunction ! ( forward_hessian_nograd, m) ?) ?;
386
+ m. add_function ( wrap_pyfunction ! ( forward_hessian_nograd_sparse, m) ?) ?;
340
387
Ok ( ( ) )
341
388
}
0 commit comments