Skip to content

Commit b9c4046

Browse files
authored
Merge pull request #25
merge (undocumented) conveyor simulation changes
2 parents 7902a45 + 92cbcd5 commit b9c4046

11 files changed

+8361
-99
lines changed

qopt/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,12 @@
8282
__version__ = '1.3'
8383
__license__ = 'GNU GPLv3+'
8484
__author__ = 'Julian Teske, Forschungszentrum Juelich'
85+
86+
87+
try:
88+
from jax.config import config
89+
config.update("jax_enable_x64", True)
90+
#TODO: add new objects here/ import other stuff?
91+
# __all__ += []
92+
except ImportError:
93+
pass

qopt/amplitude_functions.py

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@
6464
"""
6565

6666
from abc import ABC, abstractmethod
67-
from typing import Callable
67+
from typing import Callable, Optional
6868

6969
import numpy as np
7070

71+
from typing import Union
7172

7273
class AmplitudeFunction(ABC):
7374
"""Abstract Base class of the amplitude function. """
@@ -218,3 +219,125 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray,
218219
# return: shape (time, func, par)
219220

220221
return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
222+
223+
224+
###############################################################################
225+
226+
try:
227+
import jax.numpy as jnp
228+
from jax import jit,vmap,jacfwd
229+
_HAS_JAX = True
230+
except ImportError:
231+
from unittest import mock
232+
jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock()
233+
jnp = mock.Mock()
234+
_HAS_JAX = False
235+
236+
237+
class IdentityAmpFuncJAX(AmplitudeFunction):
238+
"""See docstring of class without JAX.
239+
Designed to return jax-numpy-arrays.
240+
"""
241+
242+
def __init__(self):
243+
if not _HAS_JAX:
244+
raise ImportError("JAX not available")
245+
246+
def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
247+
"""See base class. """
248+
return jnp.asarray(x)
249+
250+
def derivative_by_chain_rule(
251+
self,
252+
deriv_by_ctrl_amps: Union[np.ndarray,jnp.ndarray],
253+
x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
254+
"""See base class. """
255+
return jnp.asarray(deriv_by_ctrl_amps)
256+
257+
258+
class UnaryAnalyticAmpFuncJAX(AmplitudeFunction):
259+
"""See docstring of class without JAX.
260+
Designed to return jax-numpy-arrays.
261+
Functions need to be compatible with jit.
262+
(Includes that functions need to be pure
263+
(i.e. output solely depends on input)).
264+
"""
265+
266+
def __init__(self,
267+
value_function: Callable[[float, ], float],
268+
derivative_function: [Callable[[float, ], float]]):
269+
if not _HAS_JAX:
270+
raise ImportError("JAX not available")
271+
self.value_function = jit(jnp.vectorize(value_function))
272+
self.derivative_function = jit(jnp.vectorize(derivative_function))
273+
274+
def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
275+
"""See base class. """
276+
return jnp.asarray(self.value_function(x))
277+
278+
def derivative_by_chain_rule(
279+
self,
280+
deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x):
281+
"""See base class. """
282+
du_by_dx = self.derivative_function(x)
283+
# du_by_dx shape: (n_time, n_ctrl)
284+
# deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl)
285+
# deriv_by_opt_par shape: (n_time, n_func, n_ctrl
286+
# since the function is unary we have n_ctrl = n_amps
287+
return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps)
288+
289+
290+
class CustomAmpFuncJAX(AmplitudeFunction):
291+
"""See docstring of class without JAX.
292+
Designed to return jax-numpy-arrays.
293+
Functions need to be compatible with jit.
294+
(Includes that functions need to be pure
295+
(i.e. output solely depends on input)).
296+
If derivative_function=None, autodiff is used.
297+
t_to_vectorize: if value_function/derivative_function not yet
298+
vectorized for num_t
299+
"""
300+
301+
def __init__(
302+
self,
303+
value_function: Callable[[Union[np.ndarray, jnp.ndarray],],
304+
Union[np.ndarray, jnp.ndarray]],
305+
derivative_function: Callable[[Union[np.ndarray, jnp.ndarray],],
306+
Union[np.ndarray, jnp.ndarray]],
307+
t_to_vectorize: bool = False
308+
):
309+
if not _HAS_JAX:
310+
raise ImportError("JAX not available")
311+
if t_to_vectorize == True:
312+
self.value_function = jit(vmap(value_function),in_axes=(0,))
313+
else:
314+
self.value_function = jit(value_function)
315+
if derivative_function is not None:
316+
if t_to_vectorize == True:
317+
self.derivative_function = jit(vmap(derivative_function),in_axes=(0,))
318+
else:
319+
self.derivative_function = jit(derivative_function)
320+
else:
321+
if t_to_vectorize == True:
322+
def der_wrapper(x):
323+
return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(x)),in_axes=(0,))(x),1,2)
324+
else:
325+
def der_wrapper(x):
326+
return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(jnp.expand_dims(x,axis=0))[0,:]),in_axes=(0,))(x),1,2)
327+
self.derivative_function = jit(der_wrapper)
328+
329+
def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
330+
"""See base class. """
331+
return jnp.asarray(self.value_function(x))
332+
333+
def derivative_by_chain_rule(
334+
self,
335+
deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray],
336+
x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
337+
"""See base class. """
338+
du_by_dx = self.derivative_function(x)
339+
# du_by_dx: shape (time, par, ctrl)
340+
# deriv_by_ctrl_amps: shape (time, func, ctrl)
341+
# return: shape (time, func, par)
342+
343+
return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)

0 commit comments

Comments
 (0)