Skip to content

Commit 0984bf2

Browse files
committed
add example for mix of flux and torch layers with diffeq; currently fails
1 parent c9cf723 commit 0984bf2

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

Diff for: examples/diffeqflux/simple_mix_node.jl

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using DiffEqFlux
2+
using OrdinaryDiffEq
3+
using Optimisers
4+
using Flux
5+
6+
using PyCall
7+
using PyCallChainRules.Torch: TorchModuleWrapper, torch
8+
9+
u0 = Float32[2.; 0.]
10+
datasize = 30
11+
tspan = (0.0f0, 1.5f0)
12+
13+
function trueODEfunc(du,u,p,t)
14+
true_A = [-0.1 2.0; -2.0 -0.1]
15+
du .= ((u.^3)'true_A)'
16+
end
17+
t = range(tspan[1],tspan[2],length=datasize)
18+
prob = ODEProblem(trueODEfunc,u0,tspan)
19+
ode_data = Array(solve(prob,Tsit5(),saveat=t))
20+
21+
torch_module = torch.nn.Sequential(
22+
torch.nn.Linear(2, 50), torch.nn.Tanh(),
23+
torch.nn.Linear(50, 2), torch.nn.Tanh(),
24+
)
25+
# Mix of Flux layers and Torch layers
26+
jlmod = Chain(Dense(2, 2, tanh), TorchModuleWrapper(torch_module), Dense(2, 2,))
27+
p, re = Optimisers.destructure(jlmod)
28+
29+
dudt(u, p, t) = re(p)(u)
30+
prob = ODEProblem(dudt, u0, tspan)
31+
32+
function predict_n_ode(p)
33+
Array(solve(prob,Tsit5(),u0=u0,p=p,saveat=t))
34+
end
35+
36+
function loss_n_ode(p)
37+
pred = predict_n_ode(p)
38+
loss = sum(abs2,ode_data .- pred)
39+
loss
40+
end
41+
42+
loss_n_ode(p)
43+
44+
data = Iterators.repeated((), 1000)
45+
46+
47+
@info "before" loss_n_ode(p)
48+
49+
function train(p;nsteps=100)
50+
opt = Optimisers.ADAM(0.01)
51+
state = Optimisers.setup(opt, p)
52+
53+
for i in 1:nsteps
54+
gs, = Flux.gradient(p) do ps
55+
loss_n_ode(ps)
56+
end
57+
state, p = Optimisers.update(state, p, gs)
58+
end
59+
return p
60+
end
61+
62+
newp = train(p)
63+
64+
@info "after" loss_n_ode(newp)

0 commit comments

Comments
 (0)