Description
I have been testing the numerical accuracy of automatic differentiation through the loss function defined by NeuralPDE.jl and observed significant numerical errors. Specifically, forward-mode Jacobian-vector products (JVPs) of the residual vector that defines the loss exhibit relative errors on the order of 1e-8, while related computations (such as JVPs of direct model evaluations, or the loss value itself) remain at the expected machine precision (~1e-16).
This suggests that some internal operation in NeuralPDE's loss formulation may be inadvertently using lower-precision arithmetic (Float32 instead of Float64), or otherwise introducing unexpected numerical error.
Additional context and expected behavior
To investigate this issue, I implemented helper functions that explicitly build the residual vector entering the loss (the assembled scalar loss is written out just after this list):
- For interior points: $r_i(\theta) = \mathcal{D}u_\theta(x_i) - f(x_i)$, where $\mathcal{D}$ is the differential operator.
- For boundary points: $r_i(\theta) = u_\theta(x_i) - g(x_i)$ (with optional weighting constants).
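For reference, this is how the scalar loss is assembled from these residuals in `get_full_residual` and `loss` below, where $w_k$ is the weight and $N_k$ the number of training points of the $k$-th equation or boundary condition:

$$\mathrm{loss}(\theta) = \sum_{k} \frac{w_k}{N_k} \sum_{i=1}^{N_k} r_{k,i}(\theta)^2$$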
I have attached a minimal working example below that demonstrates the issue. The key observations are:
- As a sanity check of my implementation, the relative error between `loss(θ)` and `loss_neuralpdes(θ)` is on the order of 1e-16.
- The error when computing JVPs via `AutoForwardDiff` for direct neural network evaluations is also on the order of 1e-16.
- However, the error in the JVPs computed for the residual function (which contributes to the loss) is on the order of 1e-8, indicating a significant loss of numerical precision.
Minimal Reproducible Example 👇
using NeuralPDE, Lux, LuxCUDA, Random, ComponentArrays
using Optimization
using OptimizationOptimisers
import ModelingToolkit: Interval
using Plots
using Printf
################################# AUXILIARY FUNCTIONS ####################################
using Optimization: OptimizationProblem
using NeuralPDE: NeuralPDE, PINNRepresentation, recursive_eltype, EltypeAdaptor, safe_get_device, GridTraining, generate_training_sets
using Statistics: Statistics, mean
# only for PhysicsInformedNN
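# Builds per-equation closures θ -> r(θ) returning the raw residual vectors on the
# GridTraining points (mirrors NeuralPDE's internal loss assembly, but without the
# mean-square reduction).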
function merge_strategy_with_residual_vector(pinnrep::PINNRepresentation,
        strategy::GridTraining, datafree_pde_loss_function, datafree_bc_loss_function)
    (; domains, eqs, bcs, dict_indvars, dict_depvars) = pinnrep
    eltypeθ = recursive_eltype(pinnrep.flat_init_params)
    adaptor = EltypeAdaptor{eltypeθ}()
    train_sets = generate_training_sets(domains, strategy.dx, eqs, bcs, eltypeθ,
        dict_indvars, dict_depvars)
    # the points in the domain and on the boundary
    pde_train_sets, bcs_train_sets = train_sets |> adaptor
    pde_loss_functions = [get_residual_vector(pinnrep, _loss, _set, eltypeθ, strategy)
                          for (_loss, _set) in zip(datafree_pde_loss_function, pde_train_sets)]
    bc_loss_functions = [get_residual_vector(pinnrep, _loss, _set, eltypeθ, strategy)
                         for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)]
    return pde_loss_functions, bc_loss_functions
end
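# Moves the training set to the device/eltype of the parameters and returns a closure
# θ -> residuals evaluated on that set.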
function get_residual_vector(
        init_params, loss_function, train_set, eltype0, ::GridTraining; τ = nothing)
    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
    train_set = train_set |> safe_get_device(init_params) |> EltypeAdaptor{eltype0}()
    return θ -> loss_function(train_set, θ)
end
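# Assembles a single residual row vector whose sum of squares reproduces NeuralPDE's
# scalar loss (assuming NonAdaptiveLoss and no additional_loss, as asserted below).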
function get_full_residual(prob::OptimizationProblem, symprob::PINNRepresentation)
    # Get PDE and BC residuals
    pde_residuals, bc_residuals = merge_strategy_with_residual_vector(symprob,
        symprob.strategy, symprob.loss_functions.datafree_pde_loss_functions,
        symprob.loss_functions.datafree_bc_loss_functions)
    # Setup weights for PDE and BCs
    flat_init_params = prob.u0
    adaloss = discretization.adaptive_loss  # NOTE: relies on the global `discretization` defined below
    @assert isnothing(adaloss) # FIXME: assuming no adaptive loss
    num_additional_loss = 0
    adaloss === nothing && (adaloss = NonAdaptiveLoss{eltype(flat_init_params)}())
    # setup for all adaptive losses
    num_pde_losses = length(pde_residuals)
    num_bc_losses = length(bc_residuals)
    adaloss_T = eltype(adaloss.pde_loss_weights)
    # this will error if the user has provided a number of initial weights that is more than 1
    # and doesn't match the number of loss functions
    adaloss.pde_loss_weights = ones(adaloss_T, num_pde_losses) .* adaloss.pde_loss_weights
    adaloss.bc_loss_weights = ones(adaloss_T, num_bc_losses) .* adaloss.bc_loss_weights
    adaloss.additional_loss_weights = ones(adaloss_T, num_additional_loss) .*
                                      adaloss.additional_loss_weights
    function full_residual(θ)
        pde_losses = [pde_residual(θ) for pde_residual in pde_residuals]
        bc_losses = [bc_residual(θ) for bc_residual in bc_residuals]
        # scale so that sum(abs2, full_res) reproduces the weighted mean-square loss
        weighted_pde_losses = sqrt.(adaloss.pde_loss_weights) .* pde_losses ./ sqrt.(length.(pde_losses))
        weighted_bc_losses = sqrt.(adaloss.bc_loss_weights) .* bc_losses ./ sqrt.(length.(bc_losses))
        # full_res = hcat(Iterators.flatten((weighted_pde_losses, weighted_bc_losses))...)
        full_res = hcat(hcat(weighted_pde_losses...), hcat(weighted_bc_losses...))
        return full_res
    end
    return full_residual
end
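# Returns the PDE training grid (one point per column), used below for the model-evaluation JVP test.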
function get_quadpoints(symprob::PINNRepresentation, strategy::GridTraining)
    (; domains, eqs, dict_indvars, dict_depvars) = symprob
    eltypeθ = recursive_eltype(symprob.flat_init_params)
    train_sets = hcat(generate_training_sets(domains, strategy.dx, eqs, [], eltypeθ,
        dict_indvars, dict_depvars)[1]...)
    return train_sets
end
################################# NEURALPDES TUTORIAL ####################################
@parameters t x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
Dt = Differential(t)
t_min = 0.0
t_max = 2.0
x_min = 0.0
x_max = 2.0
y_min = 0.0
y_max = 2.0
# 2D PDE
eq = Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y))
analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)
# Initial and boundary conditions
bcs = [u(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
    u(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
    u(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
    u(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
    u(t, x, y_max) ~ analytic_sol_func(t, x, y_max)]
# Space and time domains
domains = [t ∈ Interval(t_min, t_max),
    x ∈ Interval(x_min, x_max),
    y ∈ Interval(y_min, y_max)]
# Neural network
inner = 25
chain = Chain(Dense(3, inner, σ), Dense(inner, 1))
strategy = GridTraining(0.1)
ps, st = Lux.setup(Random.default_rng(), chain)
ps = ps |> ComponentArray .|> Float64
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)
@named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u(t, x, y)])
prob = discretize(pde_system, discretization)
symprob = symbolic_discretize(pde_system, discretization)
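# `prob` is the OptimizationProblem carrying NeuralPDE's own loss; `symprob` is the
# PINNRepresentation consumed by the helper functions above.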
# Callback from the NeuralPDE tutorial (not used in the accuracy tests below)
callback = function (p, l)
    println("Current loss is: $l")
    return false
end
# Definition of the residual vector
residual = get_full_residual(prob, symprob)
loss = θ -> sum(abs2, residual(θ))
loss_neuralpdes = θ -> prob.f(θ, prob.p)
################################# TESTS ON THE ACCURACY ####################################
using ForwardDiff, Zygote, DifferentiationInterface, LinearAlgebra
# Sanity check
θ = prob.u0
rel_err = (loss_neuralpdes(θ) - loss(θ)) / loss_neuralpdes(θ)
println("Error on the loss: \t\t\t $rel_err") # In the order of 1e-16
# Test of AutoForwardDiff for differentiation of model evaluations
x = get_quadpoints(symprob, strategy)
fun = ps -> chain(x, ps, st)[1]
v = randn(length(θ))
J_fwd = ForwardDiff.jacobian(fun, θ)
jvp_explicit = J_fwd * v
jvp_pushforward = DifferentiationInterface.pushforward(
    fun,
    AutoForwardDiff(),
    θ,
    (v,),
)[1]
println("AutoForwardDiff error on model jvp:\t $(norm(jvp_explicit - jvp_pushforward[:]) / norm(jvp_explicit))") # In the order of 1e-16
# Check with JVPs
v = randn(length(θ))
J_fwd = ForwardDiff.jacobian(residual, θ)
jvp_explicit = J_fwd * v
jvp_pushforward = DifferentiationInterface.pushforward(
    residual,
    AutoForwardDiff(),
    θ,
    (v,),
)[1]
println("AutoForwardDiff error on residual jvp:\t $(norm(jvp_explicit - jvp_pushforward[:]) / norm(jvp_explicit))") # In the order of 1e-8!!!