Description
I have been playing around with DifferentiationInterface to automatically calculate derivatives of various (interaction) potentials to be used when solving a differential equation.
Doing
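# assumed setup: DifferentialEquations provides ODEProblem/solve,
# ForwardDiff is the backend behind AutoForwardDiff()
using DifferentiationInterface, DifferentialEquations
import ForwardDiff

# some potential
V(x,y,z;a) = a*(x-y)^2 * z^2

function oscillator!(du, u, p, t)
    du[1:3] .= u[4:6]
    # probably inefficient: no preparation is reused, so every call starts from scratch
    grad = gradient(x->V(x...; a=p.a), AutoForwardDiff(), u[1:3])
    du[4:6] .= -grad
end

solve(ODEProblem(oscillator!, rand(6), (0.0,100.0), (;a=1.0)), reltol=1e-6)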
seems to be fairly slow. As such, I would like to prepare and store the gradient, and pass it as a parameter:
# some potential
V(x,y,z;a) = a*(x-y)^2 * z^2

# specific to one run
a = 1.0
single_arg_pot = x->V(x...; a=a)
prep_grad = prepare_gradient(single_arg_pot, AutoForwardDiff(), zeros(3))

function oscillator!(du, u, p, t)
    du[1:3] .= u[4:6]
    # probably inefficient
    grad = gradient(p.pot, p.prep, AutoForwardDiff(), u[1:3])
    du[4:6] = -grad
end

solve(ODEProblem(oscillator!, rand(6), (0.0,100.0), (pot=single_arg_pot, prep=prep_grad)), reltol=1e-6)
There seems, however, to be some redundancy in how the prepared gradient has to be called: both the function and the backend must be passed again, and both must match the ones used for preparation.
So I quickly hacked together a callable gradient that does all this automatically:
struct GradientStorage{F, GF}
    func::F
    prep::GF
    store::Vector{Float64}
end

function GradientStorage(f, n)
    store = zeros(n)
    prep = prepare_gradient(f, AutoForwardDiff(), store)
    return GradientStorage(f, prep, store)
end

# calling the object overwrites and returns the internal buffer
function (storage::GradientStorage)(x)
    gradient!(storage.func, storage.store, storage.prep, AutoForwardDiff(), x)
    return storage.store
end
with which I can now do
# some potential
V(x,y,z;a) = a*(x-y)^2 * z^2

# specific to one run
a = 1.0
prep_grad = GradientStorage(x->V(x...; a=a), 3)

function oscillator!(du, u, p, t)
    du[1:3] .= u[4:6]
    du[4:6] .= -p.force(u[1:3])
end

solve(ODEProblem(oscillator!, rand(6), (0.0,100.0), (;force=prep_grad)), reltol=1e-6)
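
To make the idea a bit more concrete, here is a rough sketch of a less hard-coded variant of the same wrapper. It stores the backend next to the preparation and allocates the output buffer with `similar(x)` instead of fixing it to `Vector{Float64}`. The name `PreparedGradient` and its fields are made up for illustration and are not part of DifferentiationInterface; the only package calls used are `prepare_gradient` and `gradient!`.

```julia
using DifferentiationInterface
import ForwardDiff  # backend behind AutoForwardDiff()

# Hypothetical wrapper (not a DifferentiationInterface type): bundles the
# function, the backend, the preparation result and an output buffer.
struct PreparedGradient{F,B,P,G}
    func::F
    backend::B
    prep::P
    store::G
end

# `x` is a typical input; the preparation is assumed to be reused only
# for inputs of the same type and size.
function PreparedGradient(f, backend, x)
    prep = prepare_gradient(f, backend, x)
    return PreparedGradient(f, backend, prep, similar(x))
end

# Calling the wrapper overwrites the internal buffer and returns it.
function (pg::PreparedGradient)(x)
    gradient!(pg.func, pg.store, pg.prep, pg.backend, x)
    return pg.store
end

V(x, y, z; a) = a * (x - y)^2 * z^2
force = PreparedGradient(x -> V(x...; a=1.0), AutoForwardDiff(), zeros(3))
g = force([1.0, 2.0, 3.0])
```

Two caveats with this sketch: the returned array is the shared internal buffer, so a second call overwrites the result of the first, and it would presumably break if the solver calls the right-hand side with a different element type than the one used for preparation.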
Would it make sense to add this functionality (with a less ad hoc interface) to the package directly? Or are there type-stability or other problems I am not aware of that explain why this has not been done already?