|
64 | 64 | """
|
65 | 65 |
|
66 | 66 | from abc import ABC, abstractmethod
|
67 |
| -from typing import Callable, Optional |
| 67 | +from typing import Callable |
68 | 68 |
|
69 | 69 | import numpy as np
|
70 | 70 |
|
71 |
| -from typing import Union |
72 | 71 |
|
73 | 72 | class AmplitudeFunction(ABC):
|
74 | 73 | """Abstract Base class of the amplitude function. """
|
@@ -219,125 +218,3 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray,
|
219 | 218 | # return: shape (time, func, par)
|
220 | 219 |
|
221 | 220 | return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
|
222 |
| - |
223 |
| - |
224 |
| -############################################################################### |
225 |
| - |
226 |
| -try: |
227 |
| - import jax.numpy as jnp |
228 |
| - from jax import jit,vmap,jacfwd |
229 |
| - _HAS_JAX = True |
230 |
| -except ImportError: |
231 |
| - from unittest import mock |
232 |
| - jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock() |
233 |
| - jnp = mock.Mock() |
234 |
| - _HAS_JAX = False |
235 |
| - |
236 |
| - |
237 |
| -class IdentityAmpFuncJAX(AmplitudeFunction): |
238 |
| - """See docstring of class without JAX. |
239 |
| - Designed to return jax-numpy-arrays. |
240 |
| - """ |
241 |
| - |
242 |
| - def __init__(self): |
243 |
| - if not _HAS_JAX: |
244 |
| - raise ImportError("JAX not available") |
245 |
| - |
246 |
| - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
247 |
| - """See base class. """ |
248 |
| - return jnp.asarray(x) |
249 |
| - |
250 |
| - def derivative_by_chain_rule( |
251 |
| - self, |
252 |
| - deriv_by_ctrl_amps: Union[np.ndarray,jnp.ndarray], |
253 |
| - x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
254 |
| - """See base class. """ |
255 |
| - return jnp.asarray(deriv_by_ctrl_amps) |
256 |
| - |
257 |
| - |
258 |
| -class UnaryAnalyticAmpFuncJAX(AmplitudeFunction): |
259 |
| - """See docstring of class without JAX. |
260 |
| - Designed to return jax-numpy-arrays. |
261 |
| - Functions need to be compatible with jit. |
262 |
| - (Includes that functions need to be pure |
263 |
| - (i.e. output solely depends on input)). |
264 |
| - """ |
265 |
| - |
266 |
| - def __init__(self, |
267 |
| - value_function: Callable[[float, ], float], |
268 |
| - derivative_function: [Callable[[float, ], float]]): |
269 |
| - if not _HAS_JAX: |
270 |
| - raise ImportError("JAX not available") |
271 |
| - self.value_function = jit(jnp.vectorize(value_function)) |
272 |
| - self.derivative_function = jit(jnp.vectorize(derivative_function)) |
273 |
| - |
274 |
| - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
275 |
| - """See base class. """ |
276 |
| - return jnp.asarray(self.value_function(x)) |
277 |
| - |
278 |
| - def derivative_by_chain_rule( |
279 |
| - self, |
280 |
| - deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x): |
281 |
| - """See base class. """ |
282 |
| - du_by_dx = self.derivative_function(x) |
283 |
| - # du_by_dx shape: (n_time, n_ctrl) |
284 |
| - # deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl) |
285 |
| - # deriv_by_opt_par shape: (n_time, n_func, n_ctrl |
286 |
| - # since the function is unary we have n_ctrl = n_amps |
287 |
| - return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps) |
288 |
| - |
289 |
| - |
290 |
| -class CustomAmpFuncJAX(AmplitudeFunction): |
291 |
| - """See docstring of class without JAX. |
292 |
| - Designed to return jax-numpy-arrays. |
293 |
| - Functions need to be compatible with jit. |
294 |
| - (Includes that functions need to be pure |
295 |
| - (i.e. output solely depends on input)). |
296 |
| - If derivative_function=None, autodiff is used. |
297 |
| - t_to_vectorize: if value_function/derivative_function not yet |
298 |
| - vectorized for num_t |
299 |
| - """ |
300 |
| - |
301 |
| - def __init__( |
302 |
| - self, |
303 |
| - value_function: Callable[[Union[np.ndarray, jnp.ndarray],], |
304 |
| - Union[np.ndarray, jnp.ndarray]], |
305 |
| - derivative_function: Callable[[Union[np.ndarray, jnp.ndarray],], |
306 |
| - Union[np.ndarray, jnp.ndarray]], |
307 |
| - t_to_vectorize: bool = False |
308 |
| - ): |
309 |
| - if not _HAS_JAX: |
310 |
| - raise ImportError("JAX not available") |
311 |
| - if t_to_vectorize == True: |
312 |
| - self.value_function = jit(vmap(value_function),in_axes=(0,)) |
313 |
| - else: |
314 |
| - self.value_function = jit(value_function) |
315 |
| - if derivative_function is not None: |
316 |
| - if t_to_vectorize == True: |
317 |
| - self.derivative_function = jit(vmap(derivative_function),in_axes=(0,)) |
318 |
| - else: |
319 |
| - self.derivative_function = jit(derivative_function) |
320 |
| - else: |
321 |
| - if t_to_vectorize == True: |
322 |
| - def der_wrapper(x): |
323 |
| - return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(x)),in_axes=(0,))(x),1,2) |
324 |
| - else: |
325 |
| - def der_wrapper(x): |
326 |
| - return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(jnp.expand_dims(x,axis=0))[0,:]),in_axes=(0,))(x),1,2) |
327 |
| - self.derivative_function = jit(der_wrapper) |
328 |
| - |
329 |
| - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
330 |
| - """See base class. """ |
331 |
| - return jnp.asarray(self.value_function(x)) |
332 |
| - |
333 |
| - def derivative_by_chain_rule( |
334 |
| - self, |
335 |
| - deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], |
336 |
| - x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
337 |
| - """See base class. """ |
338 |
| - du_by_dx = self.derivative_function(x) |
339 |
| - # du_by_dx: shape (time, par, ctrl) |
340 |
| - # deriv_by_ctrl_amps: shape (time, func, ctrl) |
341 |
| - # return: shape (time, func, par) |
342 |
| - |
343 |
| - return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps) |
0 commit comments