|
64 | 64 | """ |
65 | 65 |
|
66 | 66 | from abc import ABC, abstractmethod |
67 | | -from typing import Callable, Optional |
| 67 | +from typing import Callable |
68 | 68 |
|
69 | 69 | import numpy as np |
70 | 70 |
|
71 | | -from typing import Union |
72 | 71 |
|
73 | 72 | class AmplitudeFunction(ABC): |
74 | 73 | """Abstract Base class of the amplitude function. """ |
@@ -219,125 +218,3 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray, |
219 | 218 | # return: shape (time, func, par) |
220 | 219 |
|
221 | 220 | return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps) |
222 | | - |
223 | | - |
224 | | -############################################################################### |
225 | | - |
226 | | -try: |
227 | | - import jax.numpy as jnp |
228 | | - from jax import jit,vmap,jacfwd |
229 | | - _HAS_JAX = True |
230 | | -except ImportError: |
231 | | - from unittest import mock |
232 | | - jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock() |
233 | | - jnp = mock.Mock() |
234 | | - _HAS_JAX = False |
235 | | - |
236 | | - |
237 | | -class IdentityAmpFuncJAX(AmplitudeFunction): |
238 | | - """See docstring of class without JAX. |
239 | | - Designed to return jax-numpy-arrays. |
240 | | - """ |
241 | | - |
242 | | - def __init__(self): |
243 | | - if not _HAS_JAX: |
244 | | - raise ImportError("JAX not available") |
245 | | - |
246 | | - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
247 | | - """See base class. """ |
248 | | - return jnp.asarray(x) |
249 | | - |
250 | | - def derivative_by_chain_rule( |
251 | | - self, |
252 | | - deriv_by_ctrl_amps: Union[np.ndarray,jnp.ndarray], |
253 | | - x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
254 | | - """See base class. """ |
255 | | - return jnp.asarray(deriv_by_ctrl_amps) |
256 | | - |
257 | | - |
258 | | -class UnaryAnalyticAmpFuncJAX(AmplitudeFunction): |
259 | | - """See docstring of class without JAX. |
260 | | - Designed to return jax-numpy-arrays. |
261 | | - Functions need to be compatible with jit. |
262 | | - (Includes that functions need to be pure |
263 | | - (i.e. output solely depends on input)). |
264 | | - """ |
265 | | - |
266 | | - def __init__(self, |
267 | | - value_function: Callable[[float, ], float], |
268 | | - derivative_function: [Callable[[float, ], float]]): |
269 | | - if not _HAS_JAX: |
270 | | - raise ImportError("JAX not available") |
271 | | - self.value_function = jit(jnp.vectorize(value_function)) |
272 | | - self.derivative_function = jit(jnp.vectorize(derivative_function)) |
273 | | - |
274 | | - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
275 | | - """See base class. """ |
276 | | - return jnp.asarray(self.value_function(x)) |
277 | | - |
278 | | - def derivative_by_chain_rule( |
279 | | - self, |
280 | | - deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x): |
281 | | - """See base class. """ |
282 | | - du_by_dx = self.derivative_function(x) |
283 | | - # du_by_dx shape: (n_time, n_ctrl) |
284 | | - # deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl) |
285 | | - # deriv_by_opt_par shape: (n_time, n_func, n_ctrl |
286 | | - # since the function is unary we have n_ctrl = n_amps |
287 | | - return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps) |
288 | | - |
289 | | - |
290 | | -class CustomAmpFuncJAX(AmplitudeFunction): |
291 | | - """See docstring of class without JAX. |
292 | | - Designed to return jax-numpy-arrays. |
293 | | - Functions need to be compatible with jit. |
294 | | - (Includes that functions need to be pure |
295 | | - (i.e. output solely depends on input)). |
296 | | - If derivative_function=None, autodiff is used. |
297 | | - t_to_vectorize: if value_function/derivative_function not yet |
298 | | - vectorized for num_t |
299 | | - """ |
300 | | - |
301 | | - def __init__( |
302 | | - self, |
303 | | - value_function: Callable[[Union[np.ndarray, jnp.ndarray],], |
304 | | - Union[np.ndarray, jnp.ndarray]], |
305 | | - derivative_function: Callable[[Union[np.ndarray, jnp.ndarray],], |
306 | | - Union[np.ndarray, jnp.ndarray]], |
307 | | - t_to_vectorize: bool = False |
308 | | - ): |
309 | | - if not _HAS_JAX: |
310 | | - raise ImportError("JAX not available") |
311 | | - if t_to_vectorize == True: |
312 | | - self.value_function = jit(vmap(value_function),in_axes=(0,)) |
313 | | - else: |
314 | | - self.value_function = jit(value_function) |
315 | | - if derivative_function is not None: |
316 | | - if t_to_vectorize == True: |
317 | | - self.derivative_function = jit(vmap(derivative_function),in_axes=(0,)) |
318 | | - else: |
319 | | - self.derivative_function = jit(derivative_function) |
320 | | - else: |
321 | | - if t_to_vectorize == True: |
322 | | - def der_wrapper(x): |
323 | | - return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(x)),in_axes=(0,))(x),1,2) |
324 | | - else: |
325 | | - def der_wrapper(x): |
326 | | - return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(jnp.expand_dims(x,axis=0))[0,:]),in_axes=(0,))(x),1,2) |
327 | | - self.derivative_function = jit(der_wrapper) |
328 | | - |
329 | | - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
330 | | - """See base class. """ |
331 | | - return jnp.asarray(self.value_function(x)) |
332 | | - |
333 | | - def derivative_by_chain_rule( |
334 | | - self, |
335 | | - deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], |
336 | | - x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
337 | | - """See base class. """ |
338 | | - du_by_dx = self.derivative_function(x) |
339 | | - # du_by_dx: shape (time, par, ctrl) |
340 | | - # deriv_by_ctrl_amps: shape (time, func, ctrl) |
341 | | - # return: shape (time, func, par) |
342 | | - |
343 | | - return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps) |
0 commit comments