Making prepared gradients callable #719

@henrik-wolf

I have been playing around with DifferentiationInterface to automatically calculate derivatives of various (interaction) potentials to be used when solving a differential equation.
Doing

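# assumed setup, not shown in the original post: solve/ODEProblem from OrdinaryDiffEq (or DifferentialEquations)
using DifferentiationInterface, OrdinaryDiffEq
import ForwardDiff

# some potential
V(x,y,z;a) = a*(x-y)^2 * z^2

function oscillator!(du, u, p, t)
	du[1:3] .= u[4:6]

	# probably inefficient: no preparation, and a new gradient vector is allocated on every call
	grad = gradient(x->V(x...; a=p.a), AutoForwardDiff(), u[1:3])
	du[4:6] .= .-grad
end
solve(ODEProblem(oscillator!, rand(6), (0.0,100.0), (;a=1.0)), reltol=1e-6)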

seems to be fairly slow. As such, I would like to prepare and store the gradient, and pass it as a parameter:

# some potential
V(x,y,z;a) = a*(x-y)^2 * z^2

# specific to one run
a = 1.0
single_arg_pot = x->V(x...; a=a)
prep_grad = prepare_gradient(single_arg_pot, AutoForwardDiff(), zeros(3))

function oscillator!(du, u, p, t)
	du[1:3] .= u[4:6]

	# reuses the preparation; the function and backend must match those given to prepare_gradient
	grad = gradient(p.pot, p.prep, AutoForwardDiff(), u[1:3])
	du[4:6] = -grad
end
solve(ODEProblem(oscillator!, rand(6), (0.0,100.0), (pot=single_arg_pot, prep=prep_grad)), reltol=1e-6)

There seems, however, to be some redundancy in how a prepared gradient has to be called: I have to pass the same function and the same backend again, even though both were already given to prepare_gradient.

So I quickly hacked together a callable gradient that handles all of this automatically:

struct GradientStorage{F, GF}
	func::F
	prep::GF
	store::Vector{Float64}
end

function GradientStorage(f, n)
	store = zeros(n)
	prep = prepare_gradient(f, AutoForwardDiff(), store)
	return GradientStorage(f, prep, store)
end

function (storage::GradientStorage)(x)
	# write the gradient into the preallocated buffer and return that same buffer
	gradient!(storage.func, storage.store, storage.prep, AutoForwardDiff(), x)
	return storage.store
end

with which I can now do

# some potential
V(x,y,z;a) = a*(x-y)^2 * z^2

# specific to one run
a=1.0
prep_grad = GradientStorage(x->V(x...; a=a), 3)

function oscillator!(du, u, p, t)
	du[1:3] .= u[4:6]

	# force = -∇V, evaluated with the stored function, preparation and backend
	du[4:6] .= .-p.force(u[1:3])
end
solve(ODEProblem(oscillator!, rand(6), (0.0,100.0), (;force=prep_grad)), reltol=1e-6)

Would it make sense to add this functionality (with a less ad hoc interface) to this package directly? Are there type-stability or other problems I am not aware of that explain why this has not been done already?
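
To make the request more concrete, here is a rough sketch of what a less ad hoc version might look like. The name `PreparedGradient` and its fields are hypothetical and not part of DifferentiationInterface. The sketch assumes the same argument order as the calls above (preparation object before the backend) and only uses `prepare_gradient` and `gradient!`. Compared to `GradientStorage`, it stores the backend alongside the function and parametrises the buffer type instead of hard-coding `Vector{Float64}` and `AutoForwardDiff()`.

```julia
using DifferentiationInterface
import ForwardDiff  # loads the ForwardDiff backend for AutoForwardDiff()

# Hypothetical wrapper: bundles everything a prepared gradient call needs.
# All fields are type parameters, so field access stays concretely typed.
struct PreparedGradient{F,B,P,G}
    f::F        # function to differentiate
    backend::B  # AD backend, e.g. AutoForwardDiff()
    prep::P     # result of prepare_gradient
    buffer::G   # preallocated output, same shape as the input
end

# `x` is only a template for the size and element type of later inputs.
function PreparedGradient(f, backend, x)
    prep = prepare_gradient(f, backend, x)
    return PreparedGradient(f, backend, prep, similar(x))
end

# Calling the object overwrites and returns the internal buffer.
function (pg::PreparedGradient)(x)
    gradient!(pg.f, pg.buffer, pg.prep, pg.backend, x)
    return pg.buffer
end

# Same potential as above
V(x, y, z; a) = a * (x - y)^2 * z^2

# `let` keeps `a` a typed local, so the closure does not refer to a global
force = let a = 1.0
    PreparedGradient(x -> V(x...; a=a), AutoForwardDiff(), zeros(3))
end

g = force([1.0, 2.0, 3.0])
```

One caveat of this design is that every call returns the same buffer, so a caller who needs to keep a result across calls has to copy it. The preparation is also only valid for inputs that match the template passed to the constructor.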
