10
10
# See the License for the specific language governing permissions and
11
11
# limitations under the License.
12
12
13
+ from packaging .version import parse
13
14
import pytest
14
15
import torch
15
16
import torch .nn as nn
19
20
from torchdyn .datasets import ToyDataset
20
21
from torchdyn .core import NeuralODE
21
22
from torchdyn .nn import GalLinear , GalConv2d , DepthCat , Augmenter , DataControl
22
- from torchdyn .numerics import odeint , Euler
23
+ from torchdyn .numerics import odeint , odeint_mshooting , Lorenz , Euler
23
24
24
25
from functools import partial
25
26
import copy
@@ -258,4 +259,60 @@ def forward(self, t, x, u, v, z, args={}):
258
259
t_eval , sol2 = odeprob (x0 , t_span = torch .linspace (0 , 5 , 10 ))
259
260
260
261
assert (sol1 == sol2 ).all ()
261
- grad (sol2 .sum (), x0 )
262
+ grad (sol2 .sum (), x0 )
263
+
264
+
265
+ @pytest .mark .skipif (parse (torch .__version__ ) < parse ("1.11.0" ),
266
+ reason = "adjoint support added in torch 1.11.0" )
267
+ def test_complex_ode ():
268
+ """Test odeint for complex numbers with a simple complex-valued ODE, corresponding
269
+ to Rabi oscillations of quantum two-level system."""
270
+ class Rabi (nn .Module ):
271
+ def __init__ (self , omega ):
272
+ super ().__init__ ()
273
+ self .sx = torch .tensor ([[0 , 1 ], [1 , 0 ]], dtype = torch .complex128 )
274
+ self .omega = omega
275
+ return
276
+ def forward (self , t , x ):
277
+ dx = - 1.0j * self .omega * self .sx @ x
278
+ dx += dx .adjoint ()
279
+ return dx
280
+
281
+ # Odeint parameters
282
+ omega = torch .randn (1 )
283
+ rabi = Rabi (omega )
284
+ tspan = torch .linspace (0. , 2. , 10 )
285
+
286
+ # Random initial state
287
+ x0 = torch .rand (2 , 2 , dtype = torch .complex128 )
288
+ x0 = 0.5 * (x0 + x0 .adjoint ()) / torch .real (x0 .trace ())
289
+ # Solve the ODE problem
290
+ t_eval , sol = odeint (f = rabi , x = x0 , t_span = tspan , solver = "dopri5" , atol = 1e-8 , rtol = 1e-6 )
291
+
292
+ # Expected solution
293
+ sx = torch .tensor ([[0 , 1 ], [1 , 0 ]], dtype = torch .complex128 )
294
+ si = torch .tensor ([[1 , 0 ], [0 , 1 ]], dtype = torch .complex128 )
295
+ U_t = torch .cos (omega * t_eval )[:, None , None ] * si
296
+ U_t += - 1j * torch .sin (omega * t_eval )[:, None , None ] * sx
297
+ sol_exp = U_t @ x0 @ U_t .adjoint ()
298
+
299
+ # Check result
300
+ assert torch .allclose (sol , sol_exp , rtol = 1e-5 , atol = 1e-5 )
301
+
302
+
303
+ @pytest .mark .parametrize ('solver' , ['mszero' ])
304
+ def test_odeint_mshooting (solver ):
305
+ x0 = torch .randn (8 , 3 ) + 15
306
+ t_span = torch .linspace (0 , 3 , 10 )
307
+ sys = Lorenz ()
308
+
309
+ odeint_mshooting (sys , x0 , t_span , solver = solver , fine_steps = 2 , maxiter = 4 )
310
+
311
+
312
+ @pytest .mark .parametrize ('solver' , ['euler' , 'rk4' , 'dopri5' ])
313
+ def test_odeint (solver ):
314
+ x0 = torch .randn (8 , 3 ) + 15
315
+ t_span = torch .linspace (0. , 2. , 10 )
316
+ sys = Lorenz ()
317
+
318
+ odeint (sys , x0 , t_span , solver = solver )
0 commit comments