Skip to content

Conversation

@ChrisRackauckas-Claude
Copy link
Contributor

Summary

  • Implements reverse-mode AD rules (rrule!!) for SciMLBase.tmap and SciMLBase.responsible_map functions
  • Enables Mooncake to differentiate through ensemble solves
  • Uses Mooncake's fdata system for vector gradients (tangent field of CoDual)
  • Prepares pullback caches during forward pass for proper nested AD

Implementation Details

  • Forward pass computes primals and prepares pullback caches for each element
  • Reverse pass reads gradients from output fdata, computes input gradients via caches
  • responsible_map applies pullbacks in reverse order (for correctness with stateful functions)
  • Helper _accum_tangents function handles tangent accumulation for various types

Test Plan

  • Extension compiles successfully
  • Simple gradient computation with tmap: sum(map(x->x^2, xs)) produces correct gradients [2.0, 4.0, 6.0] for xs = [1.0, 2.0, 3.0]
  • responsible_map produces same correct gradients
  • Multi-argument case: sum(map((x,y)->x^2+y, xs, ys)) produces correct gradients for both inputs
  • SciMLBase tests pass
  • Downstream ensemble AD tests (to be validated after merge)

Related Issues

Closes SciML/DiffEqBase.jl#1256

🤖 Generated with Claude Code

This implements reverse-mode AD rules for SciMLBase.tmap and
SciMLBase.responsible_map functions, enabling Mooncake to
differentiate through ensemble solves.

Key implementation details:
- Uses Mooncake's fdata system for vector gradients (tangent field of CoDual)
- Prepares pullback caches during forward pass for nested AD
- Applies pullbacks in reverse order for responsible_map (for stateful f)

Closes SciML/DiffEqBase.jl#1256

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@ChrisRackauckas-Claude
Copy link
Contributor Author

Update on Test Results

✅ tmap/responsible_map rules work correctly in isolation:

using SciMLBase, Mooncake

f(x) = x^2
xs = [1.0, 2.0, 3.0]

function loss(xs)
    ys = SciMLBase.tmap(f, xs)
    return sum(ys)
end

cache = Mooncake.prepare_gradient_cache(loss, xs)
val, grad = Mooncake.value_and_gradient!!(cache, loss, xs)
# Value: 14.0
# Gradient: [2.0, 4.0, 6.0] ✓

⚠️ Full ensemble solve still fails:
The downstream ensemble AD test still encounters a StackOverflowError during Mooncake's rule compilation for __solve with EnsembleProblem. The stack overflow occurs before reaching tmap/responsible_map - Mooncake gets stuck in infinite recursion while compiling rules for the ensemble solve machinery.

Root cause analysis:
The issue appears to be that Mooncake's rule compilation for __solve(::EnsembleProblem, ...) enters infinite recursion. The tmap/responsible_map rules here are necessary but not sufficient - additional work is needed to either:

  1. Mark __solve for EnsembleProblem as a primitive with a custom rule
  2. Or break the compilation cycle some other way

This PR provides the foundational tmap/responsible_map rules that will be needed once the higher-level issue is resolved.

@ChrisRackauckas-Claude
Copy link
Contributor Author

MWE for the remaining issue

using OrdinaryDiffEq
using SciMLSensitivity
using DifferentiationInterface
using ADTypes: AutoMooncake
using Mooncake

function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0; 1.0]
prob = ODEProblem(fiip, u0, (0.0, 10.0), p)

N = 3
eu0 = rand(N, 2)
ep = rand(N, 4)

function sum_of_e_solution(p)
    ensemble_prob = EnsembleProblem(
        prob,
        prob_func = (prob, i, repeat) -> remake(prob, u0 = eu0[i, :], p = p[i, :], saveat = 0.1)
    )
    sol = solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories = N)
    return sum(Array(sol.u[1]))
end

# This works:
sum_of_e_solution(ep)

# This fails with StackOverflowError during rule compilation:
DifferentiationInterface.gradient(sum_of_e_solution, AutoMooncake(; config = nothing), ep)

Error:

Warning: detected a stack overflow; program state may be corrupted
Mooncake.MooncakeRuleCompilationError(...Tuple{SciMLBase.var"##__solve#799", ..., EnsembleProblem{...}, Tsit5{...}, EnsembleSerial}...)

The stack overflow occurs in Mooncake's rule compilation for __solve with EnsembleProblem, before it reaches tmap/responsible_map.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Mooncake AD backend limitations with MTK and Ensemble problems

2 participants