using DiffEqFlux
using OrdinaryDiffEq
using Optimisers
using Flux

using PyCall
using PyCallChainRules.Torch: TorchModuleWrapper, torch
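
# Ground-truth data: sample a known ODE at `datasize` points over `tspan`.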
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
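
# The classic cubic "spiral" test problem from the neural ODE examples.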
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)' * true_A)'
end
t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))
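
# Define a small MLP in PyTorch; TorchModuleWrapper makes it callable as a
# Julia layer with ChainRules-compatible gradients through the Torch weights.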
torch_module = torch.nn.Sequential(
    torch.nn.Linear(2, 50), torch.nn.Tanh(),
    torch.nn.Linear(50, 2), torch.nn.Tanh(),
)
# Mix of Flux layers and Torch layers
jlmod = Chain(Dense(2, 2, tanh), TorchModuleWrapper(torch_module), Dense(2, 2))
# destructure flattens every parameter (Flux and Torch) into one vector p;
# re(p) rebuilds the model from such a vector
p, re = Optimisers.destructure(jlmod)
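
# Neural ODE right-hand side: rebuild the model from the flat parameter
# vector p and apply it to the state u.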
dudt(u, p, t) = re(p)(u)
prob = ODEProblem(dudt, u0, tspan)

function predict_n_ode(p)
    Array(solve(prob, Tsit5(), u0=u0, p=p, saveat=t))
end
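
# Loss: sum of squared errors between the neural ODE solution and the data.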
function loss_n_ode(p)
    pred = predict_n_ode(p)
    loss = sum(abs2, ode_data .- pred)
    loss
end

loss_n_ode(p)

data = Iterators.repeated((), 1000)  # not used below

@info "before" loss_n_ode(p)
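
# Training: Adam (Optimisers.jl) on the flat parameter vector; Flux.gradient
# (Zygote) differentiates through both the ODE solve and the Torch layers.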
function train(p; nsteps=100)
    opt = Optimisers.ADAM(0.01)
    state = Optimisers.setup(opt, p)

    for i in 1:nsteps
        # gradient of the loss with respect to the flat parameter vector
        gs, = Flux.gradient(p) do ps
            loss_n_ode(ps)
        end
        state, p = Optimisers.update(state, p, gs)
    end
    return p
end

newp = train(p)

@info "after" loss_n_ode(newp)