Fix #914: rewrite integer powers to multiplication on GPU #1021
base: master
Changes from 3 commits
@@ -0,0 +1,111 @@

module GPUUtils

using Symbolics, SymbolicUtils
using MLDataDevices: get_device, AbstractGPUDevice

export transform_power_ops, should_apply_gpu_transform

"""
    transform_power_ops(expr)

Transform integer power operations into explicit multiplication chains
compatible with symbolic differentiation.

This function rewrites expressions of the form `u^n` (where `n` is a non-negative
integer) into equivalent multiplication expressions `u * u * ... * u` (n factors).
This transformation enables automatic differentiation through the Symbolics.jl
chain rule without requiring special-cased derivative rules for power operations.

Examples:
- `u^2` → `u * u`
- `u^3` → `u * u * u`
- `u^4` → `u * u * u * u`
"""
function transform_power_ops(expr)
    count = Ref(0)

    # Extract the base expression from the ModelingToolkit wrapper if present
    was_num = expr isa Symbolics.Num
    base_expr = was_num ? Symbolics.unwrap(expr) : expr

    transformed = Symbolics.postwalk(base_expr) do node
        # Process BasicSymbolic nodes (symbolic expressions in Symbolics v6+)
        if node isa SymbolicUtils.BasicSymbolic
            op = Symbolics.operation(node)
            args = Symbolics.arguments(node)

            # Match power operations
            if op === ^
                base = args[1]
                exponent = args[2]

                # Transform only when the exponent is a non-negative literal integer
                # or integer-valued number; negative exponents are left untouched
                if (exponent isa Integer || (exponent isa Number && exponent == floor(exponent))) && exponent >= 0
                    n = Int(exponent)
                    count[] += 1

                    if n == 0
                        return 1
                    elseif n == 1
                        return base
                    elseif n == 2
                        # Use SymbolicUtils.term to prevent auto-simplification
                        return SymbolicUtils.term(*, base, base)
                    elseif n == 3
                        return SymbolicUtils.term(*, base, base, base)
                    else
                        # Unroll arbitrary exponents: u^n → u * u * ... * u (n factors)
                        factors = [base for _ in 1:n]
                        return SymbolicUtils.term(*, factors...)
                    end
                end
            end
        end

        return node
    end

    # Debug logging
    if count[] > 0 && get(ENV, "NEURALPDE_DEBUG", "0") == "1"
        @info "GPU power transformation: expanded $(count[]) power operations to multiplication chains"
    end

    # Re-attach the ModelingToolkit wrapper if the input was wrapped
    return was_num ? Symbolics.Num(transformed) : transformed
end

"""
    should_apply_gpu_transform(init_params)

Determine whether the symbolic power-operation transformation should be applied,
based on the target computational device.

This function detects whether `init_params` corresponds to GPU device parameters.
When a GPU device is detected, power operations are expanded into multiplication
chains to enable automatic differentiation on GPU accelerators.

Arguments:
- `init_params`: model initialization parameters, typically from a Lux neural network

Returns:
- `true` if the parameters are allocated on a GPU device
- `false` otherwise, or if `init_params` is `nothing`
"""
function should_apply_gpu_transform(init_params)
    init_params === nothing && return false

    # Allow explicit override via environment variable for development and testing
    if get(ENV, "NEURALPDE_GPU_POWER_REWRITE", "0") == "1"
        return true
    end

    # Detect GPU devices using the MLDataDevices.jl abstraction
    try
        return get_device(init_params) isa AbstractGPUDevice
    catch
        # If device detection fails, default to CPU mode (no transformation)
        return false
    end
end

end # module GPUUtils
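As a quick sanity check, here is a minimal sketch of how the two exported helpers might be exercised interactively, assuming the module is loaded as `NeuralPDE.GPUUtils` (as in the test setup below); the inputs are illustrative only:

```julia
using Symbolics
using NeuralPDE.GPUUtils: transform_power_ops, should_apply_gpu_transform

@variables x u(..)

# Integer powers are unrolled into multiplication chains; non-integer powers pass through.
transform_power_ops(u(x)^3)    # expected: u(x) * u(x) * u(x)
transform_power_ops(u(x)^2.5)  # expected: unchanged, u(x)^2.5

# `nothing` and plain CPU arrays should not trigger the rewrite.
should_apply_gpu_transform(nothing)           # false
should_apply_gpu_transform(rand(Float32, 4))  # false on a CPU-only machine
```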
Member
This isn't in the runtests, so it won't be run.
Author
Dr Chris, thank you for pointing this out. I might be missing something here.
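For reference, a minimal sketch of how the new test items could be wired into the suite, assuming the package runs its tests through ReTestItems.jl; the `GROUP` environment toggle below is illustrative, not the repo's actual configuration:

```julia
# test/runtests.jl sketch (hypothetical): pick up @testitem blocks tagged :gpu_nonlinear.
using ReTestItems, NeuralPDE

const GROUP = get(ENV, "GROUP", "all")  # hypothetical CI group switch

if GROUP == "all" || GROUP == "gpu_nonlinear"
    ReTestItems.runtests(NeuralPDE; tags = [:gpu_nonlinear])
end
```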
@@ -0,0 +1,144 @@

@testsetup module GPUNonlinearTestSetup
using NeuralPDE
using Symbolics: expand_derivatives
using Lux, Optimization, OptimizationOptimisers
using Random, ComponentArrays, LuxCUDA, CUDA
using NeuralPDE.GPUUtils

function callback(p, l)
    if p.iter == 1 || p.iter % 100 == 0
        println("GPU Nonlinear Test - Loss: $l after $(p.iter) iterations")
    end
    return false
end

const gpud = CUDA.functional() ? gpu_device() : nothing
export gpud, callback, transform_power_ops, should_apply_gpu_transform
end

@testitem "Symbolic Power Transformation" tags = [:gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
    using Symbolics, SymbolicUtils

    @variables x u(..)
    Dx = Differential(x)

    # Test basic transformation: u^2 → u * u
    expr2 = u(x)^2
    transformed2 = transform_power_ops(expr2)
    @test Symbolics.simplify(transformed2 - u(x)*u(x)) == 0

    # Test: u^3 → u * u * u
    expr3 = u(x)^3
    transformed3 = transform_power_ops(expr3)
    @test Symbolics.simplify(transformed3 - u(x)*u(x)*u(x)) == 0

    # Test derivative compatibility: symbolic differentiation should work after transformation
    expr_deriv = Dx(u(x)^3)
    transformed_deriv = transform_power_ops(expr_deriv)
    expanded = expand_derivatives(transformed_deriv)

    # Should not crash and should produce a valid expression
    @test !isnothing(expanded)
    @test expanded isa Union{Num, SymbolicUtils.BasicSymbolic}

    # Test non-integer exponents: these should not be transformed
    expr_nonint = u(x)^2.5
    transformed_nonint = transform_power_ops(expr_nonint)
    @test Symbolics.simplify(transformed_nonint - u(x)^2.5) == 0

    # Test edge cases: u^0 = 1, u^1 = u
    @test transform_power_ops(u(x)^1) == u(x)
    @test transform_power_ops(u(x)^0) == 1
end

@testitem "GPU Device Detection" tags = [:gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
    using ComponentArrays, CUDA

    # Test with nothing: should return false
    @test should_apply_gpu_transform(nothing) == false

    # Test with CPU parameters: should return false
    cpu_params = ComponentArray(a = [1.0, 2.0, 3.0])
    @test should_apply_gpu_transform(cpu_params) == false

    # Test with GPU parameters (only if CUDA is available)
    if CUDA.functional()
        gpu_params = ComponentArray(a = [1.0, 2.0, 3.0]) |> gpud
        @test should_apply_gpu_transform(gpu_params) == true
    end
end

@testitem "Nonlinear PDE u^2 - CUDA" tags = [:cuda, :gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
    using Lux, Random, ComponentArrays, CUDA, Optimization, OptimizationOptimisers
    import DomainSets: Interval

    CUDA.functional() || return # Skip if CUDA is not available

    Random.seed!(100)

    @parameters x
    @variables u(..)
    Dx = Differential(x)

    # Simple nonlinear PDE: u^2 = 0 with boundary condition u(0) = 0.
    # This exercises the symbolic transformation of power operations in PDE equations.
    eq = u(x)^2 ~ 0.0
    bcs = [u(0.0) ~ 0.0]
    domains = [x ∈ Interval(0.0, 1.0)]

    # Neural network: small configuration for unit testing
    inner = 10
    chain = Chain(Dense(1, inner, tanh), Dense(inner, inner, tanh), Dense(inner, 1))

    strategy = GridTraining(0.1)
    ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f64

    discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

    @named pde_system = PDESystem(eq, bcs, domains, [x], [u(x)])
    prob = discretize(pde_system, discretization)

    # Solve: the power transformation should enable differentiation
    res = solve(prob, Adam(0.01); maxiters = 200, callback = callback)

    # Verify solution integrity: no NaN or Inf values
    @test !any(isnan, res.u)
    @test all(isfinite, res.u)
end

@testitem "Nonlinear PDE Dx(u^3) - CUDA" tags = [:cuda, :gpu_nonlinear] setup = [GPUNonlinearTestSetup] begin
    using Lux, Random, ComponentArrays, CUDA, Optimization, OptimizationOptimisers
    import DomainSets: Interval

    CUDA.functional() || return # Skip if CUDA is not available

    Random.seed!(200)

    @parameters x
    @variables u(..)
    Dx = Differential(x)

    # Test case from issue #914: Dx(u^3).
    # This case produced NaN in automatic differentiation prior to power operation expansion.
    # The fix transforms u^3 → u * u * u, enabling chain-rule application.
    eq = Dx(u(x)^3) ~ 0.0
    bcs = [u(0.0) ~ 0.0]
    domains = [x ∈ Interval(0.0, 1.0)]

    # Neural network: small configuration for unit testing
    inner = 10
    chain = Chain(Dense(1, inner, tanh), Dense(inner, inner, tanh), Dense(inner, 1))

    strategy = QuasiRandomTraining(1000)
    ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f64

    discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

    @named pde_system = PDESystem(eq, bcs, domains, [x], [u(x)])
    prob = discretize(pde_system, discretization)

    # Solve: this was the case that generated NaN before the fix
    res = solve(prob, Adam(0.01); maxiters = 200, callback = callback)

    # Verify the solution: the main assertion that the fix works
    @test !any(isnan, res.u)
    @test all(isfinite, res.u)
end
Member
This doesn't seem like it would generate more efficient code?
Author
You're right, it doesn't. I did that to avoid the NaNs we were seeing on GPU backward passes for expressions like u(x)^2 and Dx(u(x)^3) (issue #914). I'll update the comment to make the intent clearer, but would you prefer to handle this error a different way?
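For context on the NaN point above, here is a minimal sketch, using only the helpers added in this PR plus Symbolics, of what the rewrite changes at the symbolic level; it is illustrative rather than the final implementation:

```julia
using Symbolics
using NeuralPDE.GPUUtils: transform_power_ops

@variables x u(..)
Dx = Differential(x)

# Without the rewrite, differentiating u(x)^3 goes through the power rule and
# keeps a power term (proportional to u(x)^2 * Dx(u(x))) in the expanded result.
expand_derivatives(Dx(u(x)^3))

# With the rewrite, u(x)^3 becomes u(x) * u(x) * u(x) first, so expansion only
# needs the product rule; this mirrors the "Symbolic Power Transformation" test above.
expand_derivatives(transform_power_ops(Dx(u(x)^3)))
```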