Skip to content

Commit d7754bc

Browse files
authored
Merge pull request #26 from qutech/revert-25-personaltest
Revert unintended merge
2 parents b9c4046 + d00f647 commit d7754bc

11 files changed

+86
-8348
lines changed

qopt/__init__.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,3 @@
8282
__version__ = '1.3'
8383
__license__ = 'GNU GPLv3+'
8484
__author__ = 'Julian Teske, Forschungszentrum Juelich'
85-
86-
87-
try:
88-
from jax.config import config
89-
config.update("jax_enable_x64", True)
90-
#TODO: add new objects here/ import other stuff?
91-
# __all__ += []
92-
except ImportError:
93-
pass

qopt/amplitude_functions.py

Lines changed: 1 addition & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,10 @@
6464
"""
6565

6666
from abc import ABC, abstractmethod
67-
from typing import Callable, Optional
67+
from typing import Callable
6868

6969
import numpy as np
7070

71-
from typing import Union
7271

7372
class AmplitudeFunction(ABC):
7473
"""Abstract Base class of the amplitude function. """
@@ -219,125 +218,3 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray,
219218
# return: shape (time, func, par)
220219

221220
return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
222-
223-
224-
###############################################################################
225-
226-
try:
227-
import jax.numpy as jnp
228-
from jax import jit,vmap,jacfwd
229-
_HAS_JAX = True
230-
except ImportError:
231-
from unittest import mock
232-
jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock()
233-
jnp = mock.Mock()
234-
_HAS_JAX = False
235-
236-
237-
class IdentityAmpFuncJAX(AmplitudeFunction):
238-
"""See docstring of class without JAX.
239-
Designed to return jax-numpy-arrays.
240-
"""
241-
242-
def __init__(self):
243-
if not _HAS_JAX:
244-
raise ImportError("JAX not available")
245-
246-
def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
247-
"""See base class. """
248-
return jnp.asarray(x)
249-
250-
def derivative_by_chain_rule(
251-
self,
252-
deriv_by_ctrl_amps: Union[np.ndarray,jnp.ndarray],
253-
x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
254-
"""See base class. """
255-
return jnp.asarray(deriv_by_ctrl_amps)
256-
257-
258-
class UnaryAnalyticAmpFuncJAX(AmplitudeFunction):
259-
"""See docstring of class without JAX.
260-
Designed to return jax-numpy-arrays.
261-
Functions need to be compatible with jit.
262-
(Includes that functions need to be pure
263-
(i.e. output solely depends on input)).
264-
"""
265-
266-
def __init__(self,
267-
value_function: Callable[[float, ], float],
268-
derivative_function: [Callable[[float, ], float]]):
269-
if not _HAS_JAX:
270-
raise ImportError("JAX not available")
271-
self.value_function = jit(jnp.vectorize(value_function))
272-
self.derivative_function = jit(jnp.vectorize(derivative_function))
273-
274-
def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
275-
"""See base class. """
276-
return jnp.asarray(self.value_function(x))
277-
278-
def derivative_by_chain_rule(
279-
self,
280-
deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x):
281-
"""See base class. """
282-
du_by_dx = self.derivative_function(x)
283-
# du_by_dx shape: (n_time, n_ctrl)
284-
# deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl)
285-
# deriv_by_opt_par shape: (n_time, n_func, n_ctrl
286-
# since the function is unary we have n_ctrl = n_amps
287-
return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps)
288-
289-
290-
class CustomAmpFuncJAX(AmplitudeFunction):
291-
"""See docstring of class without JAX.
292-
Designed to return jax-numpy-arrays.
293-
Functions need to be compatible with jit.
294-
(Includes that functions need to be pure
295-
(i.e. output solely depends on input)).
296-
If derivative_function=None, autodiff is used.
297-
t_to_vectorize: if value_function/derivative_function not yet
298-
vectorized for num_t
299-
"""
300-
301-
def __init__(
302-
self,
303-
value_function: Callable[[Union[np.ndarray, jnp.ndarray],],
304-
Union[np.ndarray, jnp.ndarray]],
305-
derivative_function: Callable[[Union[np.ndarray, jnp.ndarray],],
306-
Union[np.ndarray, jnp.ndarray]],
307-
t_to_vectorize: bool = False
308-
):
309-
if not _HAS_JAX:
310-
raise ImportError("JAX not available")
311-
if t_to_vectorize == True:
312-
self.value_function = jit(vmap(value_function),in_axes=(0,))
313-
else:
314-
self.value_function = jit(value_function)
315-
if derivative_function is not None:
316-
if t_to_vectorize == True:
317-
self.derivative_function = jit(vmap(derivative_function),in_axes=(0,))
318-
else:
319-
self.derivative_function = jit(derivative_function)
320-
else:
321-
if t_to_vectorize == True:
322-
def der_wrapper(x):
323-
return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(x)),in_axes=(0,))(x),1,2)
324-
else:
325-
def der_wrapper(x):
326-
return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(jnp.expand_dims(x,axis=0))[0,:]),in_axes=(0,))(x),1,2)
327-
self.derivative_function = jit(der_wrapper)
328-
329-
def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
330-
"""See base class. """
331-
return jnp.asarray(self.value_function(x))
332-
333-
def derivative_by_chain_rule(
334-
self,
335-
deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray],
336-
x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
337-
"""See base class. """
338-
du_by_dx = self.derivative_function(x)
339-
# du_by_dx: shape (time, par, ctrl)
340-
# deriv_by_ctrl_amps: shape (time, func, ctrl)
341-
# return: shape (time, func, par)
342-
343-
return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)

0 commit comments

Comments
 (0)