Skip to content

Potential precision loss in AutoDiff through the loss function #931

Open
@IvanBioli

Description

@IvanBioli

Description
I have been testing the numerical accuracy of automatic differentiation in NeuralPDE.jl and observed significant numerical errors when differentiating through the loss function defined by NeuralPDE. Specifically, forward-mode Jacobian-vector products (JVPs) of the residual vector that defines the loss function show relative errors on the order of 1e-8, while related computations (such as evaluating the loss itself, or JVPs of direct model evaluations) remain at the expected machine precision (~1e-16).

This suggests that some internal operations in NeuralPDE's loss formulation might be inadvertently using lower-precision arithmetic (Float32 instead of Float64), or might otherwise be introducing unexpected numerical instability.

Additional context and expected behavior
To investigate this issue, I implemented functions that explicitly define the loss as $L(\theta) = \sum_{i=1}^n (r_i(\theta))^2$, where $\theta$ are the neural network parameters, and the residuals are defined as:

  • For internal points: $r_i(\theta) = \mathcal{D}u(x_i) - f(x_i)$ where $\mathcal{D}$ is the differential operator.
  • For boundary points: $r_i(\theta) = u(x_i) - g(x_i)$ (with optional weighting constants).

I have attached a minimal working example below that demonstrates the issue. The key observations are:

  1. As a sanity check of my implementation, the relative error between loss(θ) and loss_neuralpdes(θ) is on the order of 1e-16.
  2. For direct neural network evaluations, the relative difference between the JVP obtained from the explicit ForwardDiff Jacobian and the one obtained from a DifferentiationInterface pushforward (AutoForwardDiff) is also on the order of 1e-16.
  3. However, the same comparison for the residual function (which defines the loss) gives a relative difference on the order of 1e-8, indicating a significant loss of numerical precision.

Minimal Reproducible Example 👇

using NeuralPDE, Lux, LuxCUDA, Random, ComponentArrays
using Optimization
using OptimizationOptimisers
import ModelingToolkit: Interval
using Plots
using Printf

################################# AUXILIARY FUNCTIONS ####################################
using Optimization: OptimizationProblem
using NeuralPDE: NeuralPDE, PINNRepresentation, recursive_eltype, EltypeAdaptor, safe_get_device, GridTraining
using Statistics: Statistics, mean

# only for PhysicsInformedNN
"""
    merge_strategy_with_residual_vector(pinnrep, strategy::GridTraining,
                                        datafree_pde_loss_function,
                                        datafree_bc_loss_function)

Generate the `GridTraining` point sets for the PDE equations and boundary conditions in
`pinnrep`. Wrap every data-free loss function in a closure that returns its residuals
on the matching point set.

Returns a tuple `(pde_residual_closures, bc_residual_closures)`. Each element is a vector
with one `θ -> residual` closure per loss function.
"""
function merge_strategy_with_residual_vector(pinnrep::PINNRepresentation,
        strategy::GridTraining, datafree_pde_loss_function, datafree_bc_loss_function)
    # The element type of the parameters decides the element type of the training points.
    T = recursive_eltype(pinnrep.flat_init_params)

    raw_sets = generate_training_sets(pinnrep.domains, strategy.dx, pinnrep.eqs,
        pinnrep.bcs, T, pinnrep.dict_indvars, pinnrep.dict_depvars)
    # First entry: interior (PDE) point sets; second entry: boundary point sets.
    pde_points, bc_points = EltypeAdaptor{T}()(raw_sets)

    # Pair each loss function with its own point set (zip stops at the shorter input).
    residuals_for(loss_fns, point_sets) = [get_residual_vector(pinnrep, fn, pts, T, strategy)
                                           for (fn, pts) in zip(loss_fns, point_sets)]

    return residuals_for(datafree_pde_loss_function, pde_points),
    residuals_for(datafree_bc_loss_function, bc_points)
end

"""
    get_residual_vector(init_params, loss_function, train_set, eltype0, ::GridTraining;
                        τ = nothing)

Return a closure `θ -> loss_function(points, θ)` that evaluates a data-free loss
function on a fixed set of grid training points.

`train_set` is prepared once, up front. It is moved to the device reported by
`safe_get_device` for the parameters, then converted to element type `eltype0`.
`init_params` may be a `PINNRepresentation`, in which case its `init_params` field is
used. The keyword `τ` is accepted for signature compatibility and is ignored.
"""
function get_residual_vector(
        init_params, loss_function, train_set, eltype0, ::GridTraining; τ = nothing)
    params = init_params isa PINNRepresentation ? init_params.init_params : init_params
    # Bind the prepared points to a fresh local that is never reassigned. Capturing the
    # reassigned argument `train_set` in the closure can make Julia box it (`Core.Box`),
    # which makes every residual evaluation type-unstable.
    points = train_set |> safe_get_device(params) |> EltypeAdaptor{eltype0}()
    return θ -> loss_function(points, θ)
end


"""
    get_full_residual(prob::OptimizationProblem, symprob::PINNRepresentation;
                      adaptive_loss = symprob.adaloss)

Build a function `full_residual(θ)` whose squared sum is the NeuralPDE training loss
`prob.f(θ, prob.p)` for a `GridTraining` discretization.

Each PDE or BC residual block `r_k` (with `n_k` entries) is scaled by `sqrt(w_k / n_k)`,
where `w_k` is the block's loss weight. As a result,
`sum(abs2, full_residual(θ)) == Σ_k w_k * mean(abs2, r_k)`.

# Keywords
- `adaptive_loss`: the adaptive-loss setting of the discretization. Only `nothing` is
  supported, which means uniform unit weights. Defaults to the value stored in
  `symprob`, so the function no longer depends on a global `discretization`.

# Returns
A closure mapping the parameters `θ` to all weighted residual blocks concatenated with
`hcat` (PDE blocks first, then BC blocks).

# Throws
`ArgumentError` if a non-`nothing` adaptive loss is supplied.
"""
function get_full_residual(prob::OptimizationProblem, symprob::PINNRepresentation;
        adaptive_loss = symprob.adaloss)
    # FIXME: adaptive / user-weighted losses are not supported yet.
    adaptive_loss === nothing || throw(ArgumentError(
        "get_full_residual only supports `adaptive_loss = nothing`, got " *
        "$(typeof(adaptive_loss))"))

    (; loss_functions, strategy) = symprob
    pde_residuals, bc_residuals = merge_strategy_with_residual_vector(symprob, strategy,
        loss_functions.datafree_pde_loss_functions,
        loss_functions.datafree_bc_loss_functions)

    # Uniform weights, built the same way NeuralPDE does when no adaptive loss is given.
    adaloss = NonAdaptiveLoss{eltype(prob.u0)}()
    T = eltype(adaloss.pde_loss_weights)
    # Precompute sqrt(w_k) once. These locals are never reassigned, so the closure
    # below captures them without boxing.
    pde_scale = sqrt.(ones(T, length(pde_residuals)) .* adaloss.pde_loss_weights)
    bc_scale = sqrt.(ones(T, length(bc_residuals)) .* adaloss.bc_loss_weights)

    function full_residual(θ)
        pde_losses = [r(θ) for r in pde_residuals]
        bc_losses = [r(θ) for r in bc_residuals]

        # Divide by sqrt(n_k) so the squared sum matches NeuralPDE's mean-based loss.
        weighted_pde = pde_scale .* pde_losses ./ sqrt.(length.(pde_losses))
        weighted_bc = bc_scale .* bc_losses ./ sqrt.(length.(bc_losses))

        # reduce(hcat, ...) avoids splatting and still works if one of the lists is empty.
        return reduce(hcat, vcat(weighted_pde, weighted_bc))
    end

    return full_residual
end

"""
    get_quadpoints(symprob::PINNRepresentation, strategy::GridTraining)

Return the interior (PDE) grid training points of `symprob` as a single matrix. The
matrix is the horizontal concatenation of every PDE point set.
"""
function get_quadpoints(symprob::PINNRepresentation, strategy::GridTraining)
    T = recursive_eltype(symprob.flat_init_params)
    # Pass an empty list of boundary conditions: only the PDE point sets are needed.
    all_sets = generate_training_sets(symprob.domains, strategy.dx, symprob.eqs, [], T,
        symprob.dict_indvars, symprob.dict_depvars)
    pde_sets = first(all_sets)
    return hcat(pde_sets...)
end

################################# NEURALPDES TUTORIAL ####################################
# Symbolic independent variables (t, x, y) and the unknown function u(t, x, y).
@parameters t x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
Dt = Differential(t)
t_min = 0.0
t_max = 2.0
x_min = 0.0
x_max = 2.0
y_min = 0.0
y_max = 2.0

# 2D PDE: u_t = u_xx + u_yy
eq = Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y))

# Closed-form solution of the PDE above; used to impose initial and boundary data.
analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)
# Initial and boundary conditions
bcs = [u(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
    u(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
    u(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
    u(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
    u(t, x, y_max) ~ analytic_sol_func(t, x, y_max)]

# Space and time domains. `∈` is required: without it the array literal mixes
# space-concatenation and commas and fails to parse.
domains = [t ∈ Interval(t_min, t_max),
    x ∈ Interval(x_min, x_max),
    y ∈ Interval(y_min, y_max)]

# Neural network: 3 inputs (t, x, y) -> one hidden σ layer -> 1 output.
inner = 25
chain = Chain(Dense(3, inner, σ), Dense(inner, 1))

# Grid-based training points with step 0.1 (read back later as `strategy.dx`).
strategy = GridTraining(0.1)
ps, st = Lux.setup(Random.default_rng(), chain)
# Promote the initial parameters to Float64 so the pipeline runs in double precision.
ps = ps |> ComponentArray .|> Float64
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

@named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u(t, x, y)])
# `prob` is the OptimizationProblem. `symprob` is the PINNRepresentation whose generated
# loss functions the residual helpers rebuild.
prob = discretize(pde_system, discretization)
symprob = symbolic_discretize(pde_system, discretization)

# Optimization callback (not used further in this script).
callback = function (p, l)
    println("Current loss is: $l")
    return false
end

# Definition of the residual vector
residual = get_full_residual(prob, symprob)
loss = θ -> sum(abs2, residual(θ))
loss_neuralpdes = θ -> prob.f(θ, prob.p)

################################# TESTS ON THE ACCURACY ####################################
using ForwardDiff, Zygote, DifferentiationInterface, LinearAlgebra

"""
    jvp_relative_error(f, θ, v)

Compute the Jacobian-vector product `J * v` of `f` at `θ` in two ways:
1. from the explicit `ForwardDiff.jacobian`;
2. from `DifferentiationInterface.pushforward` with `AutoForwardDiff()`.

Return their relative discrepancy `‖Jv₁ - Jv₂‖ / ‖Jv₁‖`. The pushforward result is
flattened with `vec`, because `ForwardDiff.jacobian` flattens a matrix-valued output.
"""
function jvp_relative_error(f, θ, v)
    jvp_explicit = ForwardDiff.jacobian(f, θ) * v
    jvp_pushforward = DifferentiationInterface.pushforward(
        f,
        AutoForwardDiff(),
        θ,
        (v,),
    )[1]
    return norm(jvp_explicit - vec(jvp_pushforward)) / norm(jvp_explicit)
end

# Seeded RNG so the random tangent directions (and the printed errors) are reproducible.
rng = Random.Xoshiro(931)

# Sanity check: the hand-built residual must reproduce NeuralPDE's loss.
θ = prob.u0
reference_loss = loss_neuralpdes(θ)
rel_err = (reference_loss - loss(θ)) / reference_loss
println("Error on the loss: \t\t\t $rel_err") # In the order of 1e-16

# Test of AutoForwardDiff for differentiation of model evaluations.
# Named `quad_points` (not `x`) so the symbolic parameter `x` is not overwritten.
quad_points = get_quadpoints(symprob, strategy)
model_output = p -> chain(quad_points, p, st)[1]
v = randn(rng, length(θ))
model_err = jvp_relative_error(model_output, θ, v)
println("AutoForwardDiff error on model jvp:\t $model_err") # In the order of 1e-16

# Same comparison for the residual vector.
# NOTE(review): the residual contains second spatial derivatives of the network. If
# NeuralPDE evaluates them with a finite-difference stencil of step h (rather than by
# AD), the two JVP routes differ only by round-off, and that round-off is amplified by
# roughly 1/h^2. With h ≈ eps()^(1/4) this gives about eps()/h^2 ≈ 1e-8, matching the
# observation. Confirm which derivative backend `symprob` uses.
v = randn(rng, length(θ))
residual_err = jvp_relative_error(residual, θ, v)
println("AutoForwardDiff error on residual jvp:\t $residual_err") # In the order of 1e-8!!!

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions