|
64 | 64 | """
|
65 | 65 |
|
66 | 66 | from abc import ABC, abstractmethod
|
67 |
| -from typing import Callable |
| 67 | +from typing import Callable, Optional |
68 | 68 |
|
69 | 69 | import numpy as np
|
70 | 70 |
|
| 71 | +from typing import Union |
71 | 72 |
|
72 | 73 | class AmplitudeFunction(ABC):
|
73 | 74 | """Abstract Base class of the amplitude function. """
|
@@ -218,3 +219,125 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray,
|
218 | 219 | # return: shape (time, func, par)
|
219 | 220 |
|
220 | 221 | return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
|
| 222 | + |
| 223 | + |
| 224 | +############################################################################### |
| 225 | + |
| 226 | +try: |
| 227 | + import jax.numpy as jnp |
| 228 | + from jax import jit,vmap,jacfwd |
| 229 | + _HAS_JAX = True |
| 230 | +except ImportError: |
| 231 | + from unittest import mock |
| 232 | + jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock() |
| 233 | + jnp = mock.Mock() |
| 234 | + _HAS_JAX = False |
| 235 | + |
| 236 | + |
| 237 | +class IdentityAmpFuncJAX(AmplitudeFunction): |
| 238 | + """See docstring of class without JAX. |
| 239 | + Designed to return jax-numpy-arrays. |
| 240 | + """ |
| 241 | + |
| 242 | + def __init__(self): |
| 243 | + if not _HAS_JAX: |
| 244 | + raise ImportError("JAX not available") |
| 245 | + |
| 246 | + def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
| 247 | + """See base class. """ |
| 248 | + return jnp.asarray(x) |
| 249 | + |
| 250 | + def derivative_by_chain_rule( |
| 251 | + self, |
| 252 | + deriv_by_ctrl_amps: Union[np.ndarray,jnp.ndarray], |
| 253 | + x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
| 254 | + """See base class. """ |
| 255 | + return jnp.asarray(deriv_by_ctrl_amps) |
| 256 | + |
| 257 | + |
| 258 | +class UnaryAnalyticAmpFuncJAX(AmplitudeFunction): |
| 259 | + """See docstring of class without JAX. |
| 260 | + Designed to return jax-numpy-arrays. |
| 261 | + Functions need to be compatible with jit. |
| 262 | + (Includes that functions need to be pure |
| 263 | + (i.e. output solely depends on input)). |
| 264 | + """ |
| 265 | + |
| 266 | + def __init__(self, |
| 267 | + value_function: Callable[[float, ], float], |
| 268 | + derivative_function: [Callable[[float, ], float]]): |
| 269 | + if not _HAS_JAX: |
| 270 | + raise ImportError("JAX not available") |
| 271 | + self.value_function = jit(jnp.vectorize(value_function)) |
| 272 | + self.derivative_function = jit(jnp.vectorize(derivative_function)) |
| 273 | + |
| 274 | + def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
| 275 | + """See base class. """ |
| 276 | + return jnp.asarray(self.value_function(x)) |
| 277 | + |
| 278 | + def derivative_by_chain_rule( |
| 279 | + self, |
| 280 | + deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x): |
| 281 | + """See base class. """ |
| 282 | + du_by_dx = self.derivative_function(x) |
| 283 | + # du_by_dx shape: (n_time, n_ctrl) |
| 284 | + # deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl) |
| 285 | + # deriv_by_opt_par shape: (n_time, n_func, n_ctrl |
| 286 | + # since the function is unary we have n_ctrl = n_amps |
| 287 | + return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps) |
| 288 | + |
| 289 | + |
| 290 | +class CustomAmpFuncJAX(AmplitudeFunction): |
| 291 | + """See docstring of class without JAX. |
| 292 | + Designed to return jax-numpy-arrays. |
| 293 | + Functions need to be compatible with jit. |
| 294 | + (Includes that functions need to be pure |
| 295 | + (i.e. output solely depends on input)). |
| 296 | + If derivative_function=None, autodiff is used. |
| 297 | + t_to_vectorize: if value_function/derivative_function not yet |
| 298 | + vectorized for num_t |
| 299 | + """ |
| 300 | + |
| 301 | + def __init__( |
| 302 | + self, |
| 303 | + value_function: Callable[[Union[np.ndarray, jnp.ndarray],], |
| 304 | + Union[np.ndarray, jnp.ndarray]], |
| 305 | + derivative_function: Callable[[Union[np.ndarray, jnp.ndarray],], |
| 306 | + Union[np.ndarray, jnp.ndarray]], |
| 307 | + t_to_vectorize: bool = False |
| 308 | + ): |
| 309 | + if not _HAS_JAX: |
| 310 | + raise ImportError("JAX not available") |
| 311 | + if t_to_vectorize == True: |
| 312 | + self.value_function = jit(vmap(value_function),in_axes=(0,)) |
| 313 | + else: |
| 314 | + self.value_function = jit(value_function) |
| 315 | + if derivative_function is not None: |
| 316 | + if t_to_vectorize == True: |
| 317 | + self.derivative_function = jit(vmap(derivative_function),in_axes=(0,)) |
| 318 | + else: |
| 319 | + self.derivative_function = jit(derivative_function) |
| 320 | + else: |
| 321 | + if t_to_vectorize == True: |
| 322 | + def der_wrapper(x): |
| 323 | + return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(x)),in_axes=(0,))(x),1,2) |
| 324 | + else: |
| 325 | + def der_wrapper(x): |
| 326 | + return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(jnp.expand_dims(x,axis=0))[0,:]),in_axes=(0,))(x),1,2) |
| 327 | + self.derivative_function = jit(der_wrapper) |
| 328 | + |
| 329 | + def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
| 330 | + """See base class. """ |
| 331 | + return jnp.asarray(self.value_function(x)) |
| 332 | + |
| 333 | + def derivative_by_chain_rule( |
| 334 | + self, |
| 335 | + deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], |
| 336 | + x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: |
| 337 | + """See base class. """ |
| 338 | + du_by_dx = self.derivative_function(x) |
| 339 | + # du_by_dx: shape (time, par, ctrl) |
| 340 | + # deriv_by_ctrl_amps: shape (time, func, ctrl) |
| 341 | + # return: shape (time, func, par) |
| 342 | + |
| 343 | + return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps) |
0 commit comments