Revert "merge (undocumented) conveyor simulation changes" #29

Merged
merged 1 commit into from Apr 30, 2024

9 changes: 0 additions & 9 deletions qopt/__init__.py
@@ -82,12 +82,3 @@
__version__ = '1.3'
__license__ = 'GNU GPLv3+'
__author__ = 'Julian Teske, Forschungszentrum Juelich'


try:
    from jax.config import config
    config.update("jax_enable_x64", True)
    # TODO: add new objects here / import other stuff?
    # __all__ += []
except ImportError:
    pass
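
A minimal sketch of what the removed snippet above achieved (illustrative, not part of the diff): JAX computes in float32 by default, and the flag switches arrays and computations to float64. The jax.config.update call used here is an equivalent spelling of the removed from jax.config import config variant.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)   # without this flag the dtype below is float32
print(jnp.ones(3).dtype)                    # float64
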
125 changes: 1 addition & 124 deletions qopt/amplitude_functions.py
@@ -64,11 +64,10 @@
"""

from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Callable

import numpy as np

from typing import Union

class AmplitudeFunction(ABC):
"""Abstract Base class of the amplitude function. """
@@ -219,125 +218,3 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray,
        # return: shape (time, func, par)

        return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
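
A small NumPy check (illustrative, not part of the diff) of what this contraction does: with i = time, m = parameter, k = function and j = control index, 'imj,ikj->ikm' sums over the control index and returns the derivative with respect to the optimization parameters.

import numpy as np

n_time, n_func, n_par, n_ctrl = 3, 2, 4, 5
du_by_dx = np.random.rand(n_time, n_par, n_ctrl)              # (time, par, ctrl)
deriv_by_ctrl_amps = np.random.rand(n_time, n_func, n_ctrl)   # (time, func, ctrl)
result = np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
print(result.shape)                                           # (3, 2, 4) = (time, func, par)
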


###############################################################################

try:
    import jax.numpy as jnp
    from jax import jit, vmap, jacfwd
    _HAS_JAX = True
except ImportError:
    from unittest import mock
    jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock()
    jnp = mock.Mock()
    _HAS_JAX = False


class IdentityAmpFuncJAX(AmplitudeFunction):
    """See docstring of class without JAX.
    Designed to return jax-numpy-arrays.
    """

    def __init__(self):
        if not _HAS_JAX:
            raise ImportError("JAX not available")

    def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        return jnp.asarray(x)

    def derivative_by_chain_rule(
            self,
            deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray],
            x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        return jnp.asarray(deriv_by_ctrl_amps)


class UnaryAnalyticAmpFuncJAX(AmplitudeFunction):
    """See docstring of class without JAX.
    Designed to return jax-numpy-arrays.
    Functions need to be compatible with jit.
    (Includes that functions need to be pure
    (i.e. output solely depends on input)).
    """

    def __init__(self,
                 value_function: Callable[[float, ], float],
                 derivative_function: Callable[[float, ], float]):
        if not _HAS_JAX:
            raise ImportError("JAX not available")
        self.value_function = jit(jnp.vectorize(value_function))
        self.derivative_function = jit(jnp.vectorize(derivative_function))

    def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        return jnp.asarray(self.value_function(x))

    def derivative_by_chain_rule(
            self,
            deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x):
        """See base class. """
        du_by_dx = self.derivative_function(x)
        # du_by_dx shape: (n_time, n_ctrl)
        # deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl)
        # deriv_by_opt_par shape: (n_time, n_func, n_ctrl)
        # since the function is unary we have n_ctrl = n_amps
        return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps)
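
A hypothetical usage sketch of the removed class (assuming JAX is installed): the amplitude is sin(x) applied element-wise and cos(x) is supplied as its analytic derivative.

import jax.numpy as jnp

amp_func = UnaryAnalyticAmpFuncJAX(value_function=jnp.sin,
                                   derivative_function=jnp.cos)
x = jnp.linspace(0.0, 1.0, 5).reshape(5, 1)     # (n_time, n_ctrl)
amplitudes = amp_func(x)                        # element-wise sin(x)
derivative = amp_func.derivative_by_chain_rule(
    jnp.ones((5, 1, 1)), x)                     # shape (n_time, n_func, n_ctrl)
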


class CustomAmpFuncJAX(AmplitudeFunction):
    """See docstring of class without JAX.
    Designed to return jax-numpy-arrays.
    Functions need to be compatible with jit.
    (Includes that functions need to be pure
    (i.e. output solely depends on input)).
    If derivative_function=None, autodiff is used.
    t_to_vectorize: if value_function/derivative_function not yet
    vectorized for num_t
    """

    def __init__(
            self,
            value_function: Callable[[Union[np.ndarray, jnp.ndarray], ],
                                     Union[np.ndarray, jnp.ndarray]],
            derivative_function: Callable[[Union[np.ndarray, jnp.ndarray], ],
                                          Union[np.ndarray, jnp.ndarray]],
            t_to_vectorize: bool = False
    ):
        if not _HAS_JAX:
            raise ImportError("JAX not available")
        if t_to_vectorize:
            self.value_function = jit(vmap(value_function, in_axes=(0,)))
        else:
            self.value_function = jit(value_function)
        if derivative_function is not None:
            if t_to_vectorize:
                self.derivative_function = jit(
                    vmap(derivative_function, in_axes=(0,)))
            else:
                self.derivative_function = jit(derivative_function)
        else:
            if t_to_vectorize:
                def der_wrapper(x):
                    return jnp.swapaxes(
                        vmap(jacfwd(lambda x: value_function(x)),
                             in_axes=(0,))(x), 1, 2)
            else:
                def der_wrapper(x):
                    return jnp.swapaxes(
                        vmap(jacfwd(lambda x: value_function(
                            jnp.expand_dims(x, axis=0))[0, :]),
                             in_axes=(0,))(x), 1, 2)
            self.derivative_function = jit(der_wrapper)

    def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        return jnp.asarray(self.value_function(x))

    def derivative_by_chain_rule(
            self,
            deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray],
            x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        du_by_dx = self.derivative_function(x)
        # du_by_dx: shape (time, par, ctrl)
        # deriv_by_ctrl_amps: shape (time, func, ctrl)
        # return: shape (time, func, par)

        return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
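
A hypothetical usage sketch of the removed CustomAmpFuncJAX (assuming JAX is installed): a linear map from two optimization parameters to two control amplitudes per time step; passing derivative_function=None makes the class build the derivative via forward-mode autodiff as shown above. The matrix M is purely illustrative.

import jax.numpy as jnp

M = jnp.array([[1.0, 0.5],
               [0.0, 2.0]])                  # (n_ctrl, n_par), illustrative only

def amplitudes(x):                           # x: (n_time, n_par)
    return x @ M.T                           # -> (n_time, n_ctrl)

amp_func = CustomAmpFuncJAX(value_function=amplitudes,
                            derivative_function=None)
x = jnp.ones((10, 2))                        # 10 time steps, 2 parameters
u = amp_func(x)                              # control amplitudes, shape (10, 2)
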
2,011 changes: 67 additions & 1,944 deletions qopt/cost_functions.py

Large diffs are not rendered by default.

665 changes: 0 additions & 665 deletions qopt/matrix.py

Large diffs are not rendered by default.

369 changes: 0 additions & 369 deletions qopt/noise.py
@@ -77,8 +77,6 @@

from qopt.util import deprecated

import random
from functools import partial

def bell_curve_1dim(x: Union[np.ndarray, float],
                    stdx: float) -> Union[np.ndarray, float]:
@@ -693,370 +691,3 @@ def plot_periodogram(self, n_average: int, scaling: str = 'density',
            np.mean(spectral_density_or_spectrum, axis=0)[1:-1] -
            self.noise_spectral_density(sample_frequencies)[1:-1])
        return deviation_norm


###############################################################################

try:
    import jax.numpy as jnp
    from jax import jit, vmap
    import jax
    _HAS_JAX = True
except ImportError:
    from unittest import mock
    jit = mock.Mock()
    jnp = mock.Mock()
    vmap = mock.Mock()
    jax = mock.Mock()
    _HAS_JAX = False


@jit
def _inverse_cumulative_gaussian_distribution_function_jnp(
        z: Union[float, np.array, jnp.ndarray], std: float, mean: float):
    """
    Calculates the inverse cumulative distribution function of the Gaussian
    distribution.
    Parameters
    ----------
    z: Union[float, np.array, jnp.array]
        Function value.
    std: float
        Standard deviation of the bell curve.
    mean: float
        Mean value of the gaussian distribution.
    Returns
    -------
    selected_x: Union[float, np.ndarray, jnp.ndarray]
        Value(s) of the inverse cumulative distribution function.
    """
    return std * jnp.sqrt(2) * jax.scipy.special.erfinv(2 * z - 1) + mean
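
For reference (not part of the diff), the expression above evaluates the Gaussian quantile function

    F^{-1}(z) = \mu + \sigma \sqrt{2}\, \operatorname{erf}^{-1}(2z - 1),

with \mu = mean and \sigma = std, mapping a probability z in (0, 1) to the corresponding value of the distribution.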


@partial(jit, static_argnums=1)
def _sample_1dim_gaussian_distribution_jnp(std: float, n_samples: int,
                                           mean: float = 0) -> jnp.ndarray:
    """
    Returns 'n_samples' samples from the one dimensional bell curve.
    The samples are chosen such that the integral over the bell curve between
    two adjacent samples is always the same. The samples reproduce the correct
    standard deviation only in the limit n_samples -> inf due to the
    discreteness of the approximation. The error is to good approximation
    1/n_samples.
    Parameters
    ----------
    std: float
        Standard deviation of the bell curve.
    n_samples: int
        Number of samples returned.
    mean: float
        Mean value of the gaussian distribution. Defaults to 0.
    Returns
    -------
    selected_x: array of shape (n_samples, )
        Noise samples.
    """
    z = jnp.linspace(start=0, stop=1, num=n_samples, endpoint=False)
    z += 1 / (2 * n_samples)
    # We distribute the total probability of 1 into n_samples equal parts.
    # The z-values are in the center of each part.

    x = _inverse_cumulative_gaussian_distribution_function_jnp(
        z=jnp.expand_dims(z, 0), std=jnp.expand_dims(std, 1), mean=mean
    )
    # We use the inverse cumulative gaussian distribution to find the values x.
    # The integral over the Gaussian distribution between x[i] and x[i+1]
    # now always equals 1/n_samples.
    return x
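
A minimal NumPy sketch of the same deterministic sampling idea (illustrative, not part of the diff): split [0, 1) into n_samples slices of equal probability and map each slice midpoint through the Gaussian quantile function; the sample standard deviation approaches the requested one only as n_samples grows.

import numpy as np
from scipy.special import erfinv

def deterministic_gaussian_samples(std, n_samples, mean=0.0):
    z = np.linspace(0.0, 1.0, n_samples, endpoint=False) + 0.5 / n_samples
    return std * np.sqrt(2.0) * erfinv(2.0 * z - 1.0) + mean

samples = deterministic_gaussian_samples(std=2.0, n_samples=1000)
print(samples.std())   # close to 2.0; exact only in the limit n_samples -> inf
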


class NTGQuasiStaticJAX(NoiseTraceGenerator):
    """See docstring of class w/o JAX.
    Additional parameter: seed: int, optional: seed for jax.random.PRNGKey
    """

    def __init__(self, standard_deviation: List[float],
                 n_samples_per_trace: int,
                 n_traces: int = 1,
                 noise_samples: Optional[np.ndarray] = None,
                 always_redraw_samples: bool = True,
                 correct_std_for_discrete_sampling: bool = True,
                 sampling_mode: str = 'uncorrelated_deterministic',
                 seed: Optional[int] = None):
        if not _HAS_JAX:
            raise ImportError("JAX not available")
        n_noise_operators = len(standard_deviation)
        super().__init__(noise_samples=noise_samples,
                         n_samples_per_trace=n_samples_per_trace,
                         n_traces=n_traces,
                         n_noise_operators=n_noise_operators,
                         always_redraw_samples=always_redraw_samples)
        self.standard_deviation = jnp.asarray(standard_deviation)

        self.sampling_mode = sampling_mode
        self.seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        self.rnd_key_first = jax.random.PRNGKey(self.seed)
        self.rnd_key_arr = [self.rnd_key_first]

        if correct_std_for_discrete_sampling:
            if self.n_traces == 1:
                raise RuntimeWarning('Standard deviation cannot be estimated '
                                     'for a single trace!')
            elif self.sampling_mode == 'uncorrelated_deterministic':
                n_std_dev = len(self.standard_deviation)
                _noise_samples = _sample_1dim_gaussian_distribution_jnp(
                    self.standard_deviation, self._n_traces)
                _noise_samples = jnp.broadcast_to(
                    jnp.expand_dims(jnp.tile(_noise_samples, n_std_dev) *
                                    jnp.repeat(jnp.eye(n_std_dev),
                                               self._n_traces, axis=1), 2),
                    (n_std_dev, self._n_traces * n_std_dev,
                     self.n_samples_per_trace))

                actual_std = jnp.std(_noise_samples, axis=(1, 2))
                if jnp.any(actual_std < 1e-20):
                    raise RuntimeError('The standard deviation was '
                                       'estimated close to 0!')
                self.standard_deviation *= \
                    self.standard_deviation / actual_std

    @property
    def n_traces(self) -> int:
        """Number of traces.
        The number of requested traces must be multiplied by the number of
        standard deviations because each standard deviation is sampled
        separately.
        """
        if self._n_traces:
            if self.sampling_mode == 'uncorrelated_deterministic':
                return self._n_traces * len(self.standard_deviation)
            elif self.sampling_mode == 'monte_carlo':
                return self._n_traces
            else:
                raise ValueError('Unsupported sampling mode!')
        else:
            return self.noise_samples.shape[1]

    def _sample_noise(self) -> None:
        """
        Draws quasi static noise samples from a normal distribution.
        Each noise contribution (corresponding to one noise operator) is
        sampled separately. For each standard deviation n_traces traces are
        calculated.
        """
        if self.sampling_mode == 'uncorrelated_deterministic':
            n_std_dev = len(self.standard_deviation)
            _noise_samples = _sample_1dim_gaussian_distribution_jnp(
                self.standard_deviation, self._n_traces)
            self._noise_samples = jnp.broadcast_to(
                jnp.expand_dims(jnp.tile(_noise_samples, n_std_dev) *
                                jnp.repeat(jnp.eye(n_std_dev),
                                           self._n_traces, axis=1), 2),
                (n_std_dev, self._n_traces * n_std_dev,
                 self.n_samples_per_trace))

        elif self.sampling_mode == 'monte_carlo':
            self._noise_samples = jnp.einsum(
                'i,ijk->ijk',
                self.standard_deviation,
                jax.random.normal(
                    key=self.rnd_key_arr[-1],
                    shape=(len(self.standard_deviation), self.n_traces, 1))
            )
            self._noise_samples = jnp.repeat(
                self._noise_samples, self.n_samples_per_trace, axis=2)

            self.rnd_key_arr.append(
                jax.random.split(self.rnd_key_arr[-1], num=2)[1])

        else:
            raise ValueError('Unsupported sampling mode!')
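
A hypothetical usage sketch of the removed class (assuming JAX is installed): two noise operators sampled quasi-statically in Monte Carlo mode; the shapes follow directly from the code above.

ntg = NTGQuasiStaticJAX(standard_deviation=[1e-3, 5e-4],
                        n_samples_per_trace=50,
                        n_traces=20,
                        sampling_mode='monte_carlo',
                        seed=0)
ntg._sample_noise()
print(ntg._noise_samples.shape)   # (2, 20, 50): operators, traces, samples per trace
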


def _fast_colored_noise_jnp(spectral_density: Callable, dt: float,
                            n_samples: int, output_shape: tuple, key,
                            r_power_of_two=False
                            ) -> jnp.ndarray:
    """See docstring of function without _jnp"""
    f_max = 1 / dt
    f_nyquist = f_max / 2
    s0 = 1 / f_nyquist
    if r_power_of_two:
        actual_n_samples = int(2 ** jnp.ceil(jnp.log2(n_samples)))
    else:
        actual_n_samples = int(n_samples)

    delta_white = jax.random.normal(key, (*output_shape, actual_n_samples))
    delta_white_ft = jnp.fft.rfft(delta_white, axis=-1)
    # Only positive frequencies since FFT is real and therefore symmetric
    f = jnp.linspace(0, f_nyquist, actual_n_samples // 2 + 1)
    f = spectral_density(f[1:])
    f = jnp.pad(f, ((1, 0),))
    delta_colored = jnp.fft.irfft(delta_white_ft * jnp.sqrt(f / s0),
                                  n=actual_n_samples, axis=-1)
    # the ifft takes r//2 + 1 inputs to generate r outputs

    return delta_colored
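
A minimal usage sketch of the removed helper (assuming JAX is installed); the 1/f spectral density is purely illustrative.

import jax

key = jax.random.PRNGKey(0)
traces = _fast_colored_noise_jnp(spectral_density=lambda f: 1.0 / f,
                                 dt=1e-9,
                                 n_samples=200,
                                 output_shape=(4,),   # four independent traces
                                 key=key)
print(traces.shape)   # (4, 200)
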


class NTGColoredNoiseJAX(NoiseTraceGenerator):
    """See docstring of class w/o JAX.
    Additional parameter: seed: int, optional: seed for jax.random.PRNGKey
    """

    def __init__(self,
                 n_samples_per_trace: int,
                 noise_spectral_density: Callable,
                 dt: float,
                 n_traces: int = 1,
                 n_noise_operators: int = 1,
                 always_redraw_samples: bool = True,
                 low_frequency_extension_ratio: int = 1,
                 seed: Optional[int] = None):
        if not _HAS_JAX:
            raise ImportError("JAX not available")
        super().__init__(n_traces=n_traces,
                         n_samples_per_trace=n_samples_per_trace,
                         noise_samples=None,
                         n_noise_operators=n_noise_operators,
                         always_redraw_samples=always_redraw_samples)
        self.noise_spectral_density = noise_spectral_density
        self.dt = dt
        if low_frequency_extension_ratio < 1:
            raise ValueError("The low frequency extension ratio must be "
                             "greater than or equal to 1.")
        self.low_frequency_extension_ratio = low_frequency_extension_ratio
        if hasattr(dt, "__len__"):
            raise ValueError('dt is supposed to be a scalar value!')

        self.seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        self.rnd_key_first = jax.random.PRNGKey(self.seed)
        self.rnd_key_arr = [self.rnd_key_first]

    def _sample_noise(self, **kwargs) -> None:
        """Samples noise from an arbitrary colored spectrum. """
        if self._n_noise_operators is None:
            raise ValueError('Please specify the number of noise operators!')
        if self._n_traces is None:
            raise ValueError('Please specify the number of noise traces!')
        if self._n_samples_per_trace is None:
            raise ValueError('Please specify the number of noise samples '
                             'per trace!')

        noise_samples = _fast_colored_noise_jnp(
            spectral_density=self.noise_spectral_density,
            n_samples=(self.n_samples_per_trace *
                       self.low_frequency_extension_ratio),
            output_shape=(self.n_noise_operators, self.n_traces),
            r_power_of_two=False,
            dt=self.dt,
            key=self.rnd_key_arr[-1])
        self._noise_samples = noise_samples[:, :, :self.n_samples_per_trace]

        self.rnd_key_arr.append(
            jax.random.split(self.rnd_key_arr[-1], num=2)[1])

    def plot_periodogram(self, n_average: int, scaling: str = 'density',
                         log_plot: Optional[str] = None, draw_plot=True):
        """Creates noise samples and plots the corresponding periodogram.
        Parameters
        ----------
        n_average: int
            Number of periodograms which are averaged.
        scaling: {'density', 'spectrum'}, optional
            If 'density' then the power spectral density in units of V**2/Hz
            is plotted.
            If 'spectrum' then the power spectrum in units of V**2 is plotted.
            Defaults to 'density'.
        log_plot: {None, 'semilogy', 'semilogx', 'loglog'}, optional
            If None, then the plot is not plotted logarithmically. If
            'semilogy' only the y-axis is plotted logarithmically, if
            'semilogx' only the x-axis is plotted logarithmically, if
            'loglog' both axes are plotted logarithmically. Defaults to None.
        draw_plot: bool, optional
            If True, then the periodogram is plotted. Defaults to True.
        Returns
        -------
        deviation_norm: float
            The vector norm of the deviation between the actual power
            spectral density and the power spectral density found in the
            periodogram.
        """

        noise_samples = fast_colored_noise(
            spectral_density=self.noise_spectral_density,
            n_samples=self.n_samples_per_trace,
            output_shape=(n_average,),
            r_power_of_two=False,
            dt=self.dt
        )

        sample_frequencies, spectral_density_or_spectrum = signal.periodogram(
            x=noise_samples,
            fs=1 / self.dt,
            return_onesided=True,
            scaling=scaling,
            axis=-1
        )

        if scaling == 'density':
            y_label = 'Power Spectral Density (V**2/Hz)'
        elif scaling == 'spectrum':
            y_label = 'Power Spectrum (V**2)'
        else:
            raise ValueError('Unexpected scaling argument.')

        if draw_plot:
            plt.figure()

            if log_plot is None:
                plot_function = plt.plot
            elif log_plot == 'semilogy':
                plot_function = plt.semilogy
            elif log_plot == 'semilogx':
                plot_function = plt.semilogx
            elif log_plot == 'loglog':
                plot_function = plt.loglog
            else:
                raise ValueError('Unexpected plotting mode')

            plot_function(sample_frequencies,
                          np.mean(spectral_density_or_spectrum, axis=0),
                          label='Periodogram')
            plot_function(sample_frequencies,
                          self.noise_spectral_density(sample_frequencies),
                          label='Spectral Noise Density')

            plt.ylabel(y_label)
            plt.xlabel('Frequency (Hz)')
            plt.legend(['Periodogram', 'Spectral Noise Density'])
            plt.show()

        deviation_norm = np.linalg.norm(
            np.mean(spectral_density_or_spectrum, axis=0)[1:-1] -
            self.noise_spectral_density(sample_frequencies)[1:-1])
        return deviation_norm
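
A hypothetical usage sketch of the removed NTGColoredNoiseJAX (assuming JAX is installed), again with an illustrative 1/f spectrum:

ntg = NTGColoredNoiseJAX(n_samples_per_trace=200,
                         noise_spectral_density=lambda f: 1.0 / f,
                         dt=1e-9,
                         n_traces=8,
                         n_noise_operators=1,
                         seed=0)
ntg._sample_noise()
print(ntg._noise_samples.shape)   # (1, 8, 200)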

814 changes: 2 additions & 812 deletions qopt/optimize.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion qopt/plotting.py
@@ -112,7 +112,7 @@ def plot_bloch_vector_evolution(
    states = [
        qt.Qobj((prop * initial_state).data) for prop in forward_propagators
    ]
    a = np.empty((3, len(states)), dtype=complex)  # for numerical integrity
    a = np.empty((3, len(states)))
    x, y, z = qt.sigmax(), qt.sigmay(), qt.sigmaz()
    for i, state in enumerate(states):
        a[:, i] = [qt.expect(x, state),
471 changes: 3 additions & 468 deletions qopt/simulator.py

Large diffs are not rendered by default.

1,398 changes: 11 additions & 1,387 deletions qopt/solver_algorithms.py

Large diffs are not rendered by default.

2,105 changes: 0 additions & 2,105 deletions qopt/solver_algorithms_copy_original.py

This file was deleted.

465 changes: 1 addition & 464 deletions qopt/transfer_function.py

Large diffs are not rendered by default.