2 changes: 2 additions & 0 deletions src/NeuralPDE.jl
@@ -75,6 +75,8 @@ include("eltype_matching.jl")

include("pinn_types.jl")
include("symbolic_utilities.jl")
include("gpu_utils.jl")
using .GPUUtils: transform_power_ops, should_apply_gpu_transform
include("training_strategies.jl")
include("adaptive_losses.jl")

7 changes: 7 additions & 0 deletions src/discretize.jl
@@ -167,6 +167,13 @@ function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars)
pinnrep, eqs; bc_indvars, eq_params,
param_estim, default_p
)

# GPU-only rewrite of integer powers to multiplication
# to ensure stable symbolic differentiation and GPU AD (Issue #914).
if should_apply_gpu_transform(pinnrep.init_params)
expr_loss_function = transform_power_ops(expr_loss_function)
end

u = get_u()
_loss_function = @RuntimeGeneratedFunction(expr_loss_function)
return (cord, θ) -> _loss_function(cord, θ, phi, derivative, integral, u, default_p)
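For context, a minimal sketch (not part of the diff; the variables and the import of transform_power_ops are illustrative) of the effect the gated rewrite has on symbolic differentiation:

# Illustration only: compare differentiation of a power term before and after
# the u^3 -> u * u * u rewrite that transform_power_ops performs.
using Symbolics
using NeuralPDE.GPUUtils: transform_power_ops

@variables x u(..)
Dx = Differential(x)

original = Dx(u(x)^3)                      # power form as written in the PDE
rewritten = transform_power_ops(original)  # power unrolled into multiplication

# Both expand via the chain rule; the rewritten form contains no u^n term,
# which is the property relied on to avoid the GPU NaNs reported in #914.
expand_derivatives(original)
expand_derivatives(rewritten)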
96 changes: 96 additions & 0 deletions src/gpu_utils.jl
@@ -0,0 +1,96 @@
module GPUUtils

using Symbolics, SymbolicUtils
using MLDataDevices: get_device, AbstractGPUDevice

export transform_power_ops, should_apply_gpu_transform

"""
transform_power_ops(expr)

Rewrite integer powers (e.g. `u^2`, `u^3`) into explicit multiplication.

This exists to avoid NaNs observed on GPU during training for expressions
like `u(x)^2` and `Dx(u(x)^3)` (see #914). It is not intended as a
performance optimization.
"""

function transform_power_ops(expr)
count = Ref(0)

# Extract base expression from ModelingToolkit wrapper if present
was_num = expr isa Symbolics.Num
base_expr = was_num ? Symbolics.unwrap(expr) : expr

transformed = Symbolics.postwalk(base_expr) do node
# Process BasicSymbolic call nodes (symbolic expressions in Symbolics v6+);
# leaf symbols have no operation/arguments, so guard with iscall
if node isa SymbolicUtils.BasicSymbolic && SymbolicUtils.iscall(node)
op = Symbolics.operation(node)
args = Symbolics.arguments(node)

# Match power operations
if op === ^
base = args[1]
exponent = args[2]

# Transform only when the exponent is a literal non-negative integer
# (or an integer-valued float); symbolic and negative exponents are left unchanged
if (exponent isa Integer || (exponent isa AbstractFloat && isinteger(exponent))) && exponent >= 0
n = Int(exponent)
count[] += 1

if n == 0
return 1
elseif n == 1
return base
elseif n == 2
# Use SymbolicUtils.term to prevent auto-simplification
return SymbolicUtils.term(*, base, base)
elseif n == 3
return SymbolicUtils.term(*, base, base, base)
else
# Unroll arbitrary exponents: u^n → u * u * ... * u (n factors)
factors = [base for _ in 1:n]
return SymbolicUtils.term(*, factors...)
end
end
end
end

return node
end

# Debug logging
if count[] > 0 && get(ENV, "NEURALPDE_DEBUG", "0") == "1"
@info "GPU power transformation: expanded $(count[]) power operations to multiplication chains"
end

# Re-attach ModelingToolkit wrapper if the input was wrapped
return was_num ? Symbolics.Num(transformed) : transformed
end

"""
should_apply_gpu_transform(init_params)

Return `true` when GPU-specific symbolic rewrites should be applied.

This gates the power-rewriting logic to GPU code paths only (see #914).
"""
function should_apply_gpu_transform(init_params)
init_params === nothing && return false

# Allow explicit override via environment variable for development and testing
if get(ENV, "NEURALPDE_GPU_POWER_REWRITE", "0") == "1"
return true
end

# Detect GPU devices using the MLDataDevices.jl abstraction
try
return get_device(init_params) isa AbstractGPUDevice
catch
# If device detection fails, default to CPU mode (no transformation)
return false
end
end

end # module GPUUtils
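As a usage note, a small sketch (assumptions: a CPU-only session; the environment variables shown are the ones defined in this file) of exercising the new helpers directly, including the development override:

# Sketch: exercising GPUUtils outside of discretize.jl
using NeuralPDE
using NeuralPDE.GPUUtils: transform_power_ops, should_apply_gpu_transform
using Symbolics, ComponentArrays

@variables x u(..)

# Force the rewrite path and its debug logging without a GPU (development override)
ENV["NEURALPDE_GPU_POWER_REWRITE"] = "1"
ENV["NEURALPDE_DEBUG"] = "1"

params = ComponentArray(w = rand(4))
should_apply_gpu_transform(params)    # true, because of the override above
transform_power_ops(u(x)^2 + u(x)^3)  # logs how many powers were expanded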
162 changes: 162 additions & 0 deletions test/gpu_nonlinear_tests.jl
Member (reviewer):
This isn't in the runtests, so it won't be run.

Author:
Dr. Chris, thank you for pointing this out. I might be missing something here.
I added them as @testitem blocks and was relying on ReTestItems to pick them up via the :cuda tags.
If you'd prefer them to be explicitly included in runtests.jl or moved alongside the existing CUDA PDE tests, I'd be happy to adjust.
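For reference, a hypothetical sketch (not part of this PR) of what an explicit hook in test/runtests.jl could look like if tag-based discovery alone is not preferred; the GROUP environment variable convention is an assumption borrowed from common SciML test layouts:

# Hypothetical addition to test/runtests.jl: run the :cuda-tagged test items
# (which include gpu_nonlinear_tests.jl) only when the CUDA group is selected.
using ReTestItems, NeuralPDE

const GROUP = get(ENV, "GROUP", "All")

if GROUP == "All" || GROUP == "CUDA"
    ReTestItems.runtests(NeuralPDE; tags = [:cuda])
end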

@@ -0,0 +1,162 @@
@testsetup module GPUNonlinearTestSetup
using NeuralPDE
using Symbolics: expand_derivatives
using Lux, Optimization, OptimizationOptimisers
using Random, ComponentArrays, LuxCUDA, CUDA
using NeuralPDE.GPUUtils

function callback(p, l)
if p.iter == 1 || p.iter % 100 == 0
println("GPU Nonlinear Test - Loss: $l after $(p.iter) iterations")
end
return false
end

const gpud = CUDA.functional() ? gpu_device() : nothing
export gpud, callback, transform_power_ops, should_apply_gpu_transform
end

@testitem "Symbolic Power Transformation" tags = [:gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
using Symbolics
using SymbolicUtils
using ModelingToolkit

@variables x u(..)
Dx = Differential(x)

# Test basic transformation: u^2 → u * u
expr2 = u(x)^2
transformed2 = transform_power_ops(expr2)
@test isequal(Symbolics.simplify(transformed2 - u(x)*u(x)), 0)

# Test: u^3 → u * u * u
expr3 = u(x)^3
transformed3 = transform_power_ops(expr3)
@test isequal(Symbolics.simplify(transformed3 - u(x)*u(x)*u(x)), 0)

# Test derivative compatibility: symbolic differentiation should work after transformation
expr_deriv = Dx(u(x)^3)
transformed_deriv = transform_power_ops(expr_deriv)
expanded = expand_derivatives(transformed_deriv)

# Should not crash and should produce a valid expression
@test !isnothing(expanded)
@test expanded isa Union{Num, SymbolicUtils.BasicSymbolic}

# Test non-integer exponents: should not be transformed
expr_nonint = u(x)^2.5
transformed_nonint = transform_power_ops(expr_nonint)
@test isequal(Symbolics.simplify(transformed_nonint - u(x)^2.5), 0)

# Test edge cases: u^0 = 1, u^1 = u
@test isequal(transform_power_ops(u(x)^1), u(x))
@test isequal(transform_power_ops(u(x)^0), 1)
end

@testitem "GPU Device Detection" tags = [:gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
using ComponentArrays
using CUDA

# Test with nothing: should return false
@test should_apply_gpu_transform(nothing) == false

# Test with CPU parameters: should return false
cpu_params = ComponentArray(a = [1.0, 2.0, 3.0])
@test should_apply_gpu_transform(cpu_params) == false

# Test with GPU parameters (if CUDA available)
if CUDA.functional()
gpu_params = ComponentArray(a = [1.0, 2.0, 3.0]) |> gpud
@test should_apply_gpu_transform(gpu_params) == true
end
end

@testitem "Nonlinear PDE u^2 - CUDA" tags = [:cuda, :gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
using CUDA
using Random
import DomainSets: Interval
using ModelingToolkit
using Lux
using ComponentArrays
using NeuralPDE
using Optimization
using OptimizationOptimisers

CUDA.functional() || return # Skip if CUDA not available

Random.seed!(100)

@parameters x
@variables u(..)
Dx = Differential(x)

# Simple nonlinear PDE: u^2 = 0 with boundary condition u(0) = 0
# This tests the symbolic transformation of power operations in PDE equations
eq = u(x)^2 ~ 0.0
bcs = [u(0.0) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0)]

# Neural network: small configuration for unit testing
inner = 10
chain = Chain(Dense(1, inner, tanh), Dense(inner, inner, tanh), Dense(inner, 1))

strategy = GridTraining(0.1)
ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f64

discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

@named pde_system = PDESystem(eq, bcs, domains, [x], [u(x)])
prob = discretize(pde_system, discretization)

# Solve: power transformation should enable differentiation
res = solve(prob, Adam(0.01); maxiters = 200, callback = callback)

# Verify solution integrity: no NaN or Inf values
@test !any(isnan, res.u)
@test all(isfinite, res.u)
end

@testitem "Nonlinear PDE Dx(u^3) - CUDA" tags = [:cuda, :gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
using CUDA
using Random
import DomainSets: Interval
using ModelingToolkit
using Lux
using ComponentArrays
using NeuralPDE
using Optimization
using OptimizationOptimisers

CUDA.functional() || return # Skip if CUDA not available

Random.seed!(200)

@parameters x
@variables u(..)
Dx = Differential(x)

# Test case from issue #914: Dx(u^3)
# This case produced NaN in automatic differentiation prior to power operation expansion
# The fix transforms u^3 → u * u * u, enabling chain rule application
eq = Dx(u(x)^3) ~ 0.0
bcs = [u(0.0) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0)]

# Neural network: small configuration for unit testing
inner = 10
chain = Chain(Dense(1, inner, tanh), Dense(inner, inner, tanh), Dense(inner, 1))

strategy = QuasiRandomTraining(1000)
ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f64

discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

@named pde_system = PDESystem(eq, bcs, domains, [x], [u(x)])
prob = discretize(pde_system, discretization)

# Solve: this was the case that generated NaN before the fix
res = solve(prob, Adam(0.01); maxiters = 200, callback = callback)

# Verify solution: the main assertion that the fix works
@test !any(isnan, res.u)
@test all(isfinite, res.u)
end