2 changes: 2 additions & 0 deletions src/NeuralPDE.jl
@@ -75,6 +75,8 @@ include("eltype_matching.jl")

include("pinn_types.jl")
include("symbolic_utilities.jl")
include("gpu_utils.jl")
using .GPUUtils: transform_power_ops, should_apply_gpu_transform
include("training_strategies.jl")
include("adaptive_losses.jl")

7 changes: 7 additions & 0 deletions src/discretize.jl
@@ -167,6 +167,13 @@ function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars)
pinnrep, eqs; bc_indvars, eq_params,
param_estim, default_p
)

# GPU-only rewrite of integer powers to multiplication
# to ensure stable symbolic differentiation and GPU AD (Issue #914).
if should_apply_gpu_transform(pinnrep.init_params)
expr_loss_function = transform_power_ops(expr_loss_function)
end

u = get_u()
_loss_function = @RuntimeGeneratedFunction(expr_loss_function)
return (cord, θ) -> _loss_function(cord, θ, phi, derivative, integral, u, default_p)
111 changes: 111 additions & 0 deletions src/gpu_utils.jl
@@ -0,0 +1,111 @@
module GPUUtils

using Symbolics, SymbolicUtils
using MLDataDevices: get_device, AbstractGPUDevice

export transform_power_ops, should_apply_gpu_transform

"""
transform_power_ops(expr)

Transform integer power operations into explicit multiplication chains
compatible with symbolic differentiation.

This function rewrites expressions of the form `u^n` (where `n` is a literal
non-negative integer) into equivalent multiplication expressions
`u * u * ... * u` (`n` factors), with `u^0` reduced to `1` and `u^1` to `u`.
This transformation enables automatic differentiation through the Symbolics.jl
chain rule without requiring special-cased derivative rules for power operations.
Member:
This doesn't seem like it would generate more efficient code?

Author:
You're right, it wouldn't. I did this to avoid the NaNs we were seeing in GPU backward passes for expressions like u(x)^2 and Dx(u(x)^3) (issue #914). I'll update the comment to make the intent clearer, but would you prefer to handle this a different way?
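
For context, a rough before/after sketch of the offending term (illustration only, not the exact reproduction from #914):

using Symbolics
using NeuralPDE.GPUUtils: transform_power_ops

@variables x u(..)
Dx = Differential(x)

# Before the rewrite, the expanded derivative still carries an integer power,
# which is the term we suspect produced NaNs in the GPU backward pass:
expand_derivatives(Dx(u(x)^3))                        # 3(u(x)^2)*Differential(x)(u(x))

# After the rewrite, the power is unrolled to a product, so symbolic
# differentiation only needs the product rule:
expand_derivatives(Dx(transform_power_ops(u(x)^3)))   # product-rule expansion of u*u*u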


Example:
- `u^2` → `u * u`
- `u^3` → `u * u * u`
- `u^4` → `u * u * u * u`
"""
function transform_power_ops(expr)
count = Ref(0)

# Extract base expression from ModelingToolkit wrapper if present
was_num = expr isa Symbolics.Num
base_expr = was_num ? Symbolics.unwrap(expr) : expr

transformed = Symbolics.postwalk(base_expr) do node
# Process call nodes of BasicSymbolic expressions (Symbolics v6+);
# leaf symbols have no operation, so skip them to avoid errors from `operation`
if node isa SymbolicUtils.BasicSymbolic && SymbolicUtils.iscall(node)
op = Symbolics.operation(node)
args = Symbolics.arguments(node)

# Match power operations
if op === ^
base = args[1]
exponent = args[2]

# Transform only when exponent is a literal integer or integer-valued number
if exponent isa Integer || (exponent isa Number && exponent == floor(exponent))
n = Int(exponent)
count[] += 1

if n == 0
return 1
elseif n == 1
return base
elseif n == 2
# Use SymbolicUtils.term to prevent auto-simplification
return SymbolicUtils.term(*, base, base)
elseif n == 3
return SymbolicUtils.term(*, base, base, base)
else
# Unroll arbitrary exponents: u^n → u * u * ... * u (n factors)
factors = [base for _ in 1:n]
return SymbolicUtils.term(*, factors...)
end
end
end
end

return node
end

# Debug logging
if count[] > 0 && get(ENV, "NEURALPDE_DEBUG", "0") == "1"
@info "GPU power transformation: expanded $(count[]) power operations to multiplication chains"
end

# Re-attach ModelingToolkit wrapper if the input was wrapped
return was_num ? Symbolics.Num(transformed) : transformed
end

"""
should_apply_gpu_transform(init_params)

Determine whether symbolic power operation transformation should be applied
based on the target computational device.

This function detects whether `init_params` resides on a GPU device. When a GPU
device is detected, power operations are expanded into multiplication chains so
that symbolic differentiation and the GPU backward pass remain stable.

Arguments:
- `init_params`: Model initialization parameters, typically from a Lux neural network

Returns:
- `true` if parameters are allocated on a GPU device
- `false` otherwise, or if `init_params` is `nothing`
"""
function should_apply_gpu_transform(init_params)
init_params === nothing && return false

# Allow explicit override via environment variable for development and testing
if get(ENV, "NEURALPDE_GPU_POWER_REWRITE", "0") == "1"
return true
end

# Detect GPU devices using the MLDataDevices.jl abstraction
try
return get_device(init_params) isa AbstractGPUDevice
catch
# If device detection fails, default to CPU mode (no transformation)
return false
end
end

end # module GPUUtils
144 changes: 144 additions & 0 deletions test/gpu_nonlinear_tests.jl
Member:
This isn't in the runtests, so it won't be run.

Author:
Dr Chris, thank you for pointing this out. I might be missing something here: I added these as @testitems and was relying on ReTestItems to pick them up via the :cuda tags. If you'd prefer them to be explicitly included in runtests.jl or moved alongside the existing CUDA PDE tests, I'd be happy to adjust.
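
For reference, a sketch of how I was expecting these to be picked up (assuming ReTestItems' tag filtering is what the CI job uses; the exact wiring may differ):

using ReTestItems
# Run only the test items tagged :cuda from this file; the :gpu_nonlinear tag
# can be used the same way to select the device-independent items.
runtests("test/gpu_nonlinear_tests.jl"; tags = [:cuda])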

@@ -0,0 +1,144 @@
@testsetup module GPUNonlinearTestSetup
using NeuralPDE
using Symbolics: expand_derivatives
using Lux, Optimization, OptimizationOptimisers
using Random, ComponentArrays, LuxCUDA, CUDA
using NeuralPDE.GPUUtils

function callback(p, l)
if p.iter == 1 || p.iter % 100 == 0
println("GPU Nonlinear Test - Loss: $l after $(p.iter) iterations")
end
return false
end

const gpud = CUDA.functional() ? gpu_device() : nothing
export gpud, callback, transform_power_ops, should_apply_gpu_transform
end

@testitem "Symbolic Power Transformation" tags = [:gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
using Symbolics

@variables x u(..)
Dx = Differential(x)

# Test basic transformation: u^2 → u * u
expr2 = u(x)^2
transformed2 = transform_power_ops(expr2)
@test Symbolics.simplify(transformed2 - u(x)*u(x)) == 0

# Test: u^3 → u * u * u
expr3 = u(x)^3
transformed3 = transform_power_ops(expr3)
@test Symbolics.simplify(transformed3 - u(x)*u(x)*u(x)) == 0

# Test derivative compatibility: symbolic differentiation should work after transformation
expr_deriv = Dx(u(x)^3)
transformed_deriv = transform_power_ops(expr_deriv)
expanded = expand_derivatives(transformed_deriv)

# Should not crash and should produce a valid expression
@test !isnothing(expanded)
@test expanded isa Union{Num, SymbolicUtils.Term}

# Test non-integer exponents: should not be transformed
expr_nonint = u(x)^2.5
transformed_nonint = transform_power_ops(expr_nonint)
@test Symbolics.simplify(transformed_nonint - u(x)^2.5) == 0

# Test edge cases: u^0 = 1, u^1 = u
@test transform_power_ops(u(x)^1) == u(x)
@test transform_power_ops(u(x)^0) == 1
end

@testitem "GPU Device Detection" tags = [:gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
using ComponentArrays

# Test with nothing: should return false
@test should_apply_gpu_transform(nothing) == false

# Test with CPU parameters: should return false
cpu_params = ComponentArray(a = [1.0, 2.0, 3.0])
@test should_apply_gpu_transform(cpu_params) == false

# Test with GPU parameters (if CUDA available)
if CUDA.functional()
gpu_params = ComponentArray(a = [1.0, 2.0, 3.0]) |> gpud
@test should_apply_gpu_transform(gpu_params) == true
end
end

@testitem "Nonlinear PDE u^2 - CUDA" tags = [:cuda, :gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
import DomainSets: Interval

CUDA.functional() || return # Skip if CUDA not available

Random.seed!(100)

@parameters x
@variables u(..)
Dx = Differential(x)

# Simple nonlinear PDE: u^2 = 0 with boundary condition u(0) = 0
# This tests the symbolic transformation of power operations in PDE equations
eq = u(x)^2 ~ 0.0
bcs = [u(0.0) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0)]

# Neural network: small configuration for unit testing
inner = 10
chain = Chain(Dense(1, inner, tanh), Dense(inner, inner, tanh), Dense(inner, 1))

strategy = GridTraining(0.1)
ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f64

discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

@named pde_system = PDESystem(eq, bcs, domains, [x], [u(x)])
prob = discretize(pde_system, discretization)

# Solve: power transformation should enable differentiation
res = solve(prob, Adam(0.01); maxiters = 200, callback = callback)

# Verify solution integrity: no NaN or Inf values
@test !any(isnan, res.u)
@test all(isfinite, res.u)
end

@testitem "Nonlinear PDE Dx(u^3) - CUDA" tags = [:cuda, :gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
import DomainSets: Interval

CUDA.functional() || return # Skip if CUDA not available

Random.seed!(200)

@parameters x
@variables u(..)
Dx = Differential(x)

# Test case from issue #914: Dx(u^3)
# This case produced NaN in automatic differentiation prior to power operation expansion
# The fix transforms u^3 → u * u * u, enabling chain rule application
eq = Dx(u(x)^3) ~ 0.0
bcs = [u(0.0) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0)]

# Neural network: small configuration for unit testing
inner = 10
chain = Chain(Dense(1, inner, tanh), Dense(inner, inner, tanh), Dense(inner, 1))

strategy = QuasiRandomTraining(1000)
ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f64

discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

@named pde_system = PDESystem(eq, bcs, domains, [x], [u(x)])
prob = discretize(pde_system, discretization)

# Solve: this was the case that generated NaN before the fix
res = solve(prob, Adam(0.01); maxiters = 200, callback = callback)

# Verify solution: the main assertion that the fix works
@test !any(isnan, res.u)
@test all(isfinite, res.u)
end