Skip to content

Commit

Permalink
fix: better handling of (possibly scalarized) array parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Aug 1, 2024
1 parent a50f143 commit 6af4c99
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 24 deletions.
76 changes: 52 additions & 24 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,25 +242,42 @@ function wrap_array_vars(
# tunables are scalarized and concatenated, so we need to have assignments
# for the non-scalarized versions
array_tunables = Dict{Any, AbstractArray{Int}}()
for p in ps
idx = parameter_index(sys, p)
idx isa ParameterIndex || continue
idx.portion isa SciMLStructures.Tunable || continue
idx.idx isa AbstractArray || continue
array_tunables[p] = idx.idx
end
# Other parameters may be scalarized arrays but used in the vector form
other_array_parameters = Assignment[]
other_array_parameters = Dict{Any, Any}()

for p in ps
p = unwrap(p)
if iscall(p) && operation(p) == getindex
p = arguments(p)[1]
end
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
scal = collect(p)
# all scalarized variables are in `ps`
all(x -> any(isequal(x), ps), scal) || continue
(haskey(array_tunables, p) || haskey(other_array_parameters, p)) && continue

idx = parameter_index(sys, p)
if Symbolics.isarraysymbolic(p)
idx === nothing || continue
push!(other_array_parameters, p collect(p))
elseif iscall(p) && operation(p) == getindex
idx === nothing && continue
# all of the scalarized variables are in `ps`
all(x -> any(isequal(x), ps), collect(p))|| continue
push!(other_array_parameters, p collect(p))
if idx === nothing
idxs = map(Base.Fix1(parameter_index, sys), scal)
if all(x -> x isa ParameterIndex && x.portion isa SciMLStructures.Tunable, idxs)
idxs = map(x -> x.idx, idxs)
end
if all(x -> x isa Int, idxs)
if vec(idxs) == idxs[begin]:idxs[end]
idxs = reshape(idxs[begin]:idxs[end], size(idxs))
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
idxs = reshape(idxs[begin]:-1:idxs[end], size(idxs))
end
array_tunables[p] = idxs
else
other_array_parameters[p] = scal
end
elseif idx isa Int
continue
elseif idx.portion != SciMLStructures.Tunable()
other_array_parameters[p] = scal
else
array_tunables[p] = idx.idx
end
end
for (k, inds) in array_vars
Expand All @@ -278,7 +295,8 @@ function wrap_array_vars(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_tunables],
other_array_parameters
[k Code.MakeArray(v, typeof(v))
for (k, v) in other_array_parameters]
),
expr.body,
false
Expand All @@ -294,7 +312,9 @@ function wrap_array_vars(
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_tunables]
for (k, v) in array_tunables],
[k Code.MakeArray(v, typeof(v))
for (k, v) in other_array_parameters]
),
expr.body,
false
Expand All @@ -310,7 +330,9 @@ function wrap_array_vars(
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k :(view($(expr.args[uind + 2].name), $v))
for (k, v) in array_tunables]
for (k, v) in array_tunables],
[k Code.MakeArray(v, typeof(v))
for (k, v) in other_array_parameters]
),
expr.body,
false
Expand Down Expand Up @@ -499,15 +521,18 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
return unwrap(sym) in 1:length(parameter_symbols(sys))
end
return any(isequal(sym), parameter_symbols(sys)) ||
hasname(sym) && is_parameter(sys, getname(sym))
hasname(sym) && !(iscall(sym) && operation(sym) == getindex) &&
is_parameter(sys, getname(sym))
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return is_parameter(ic, sym)
end

named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)]
named_parameters = [getname(x)
for x in parameter_symbols(sys)
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
return any(isequal(sym), named_parameters) ||
count(NAMESPACE_SEPARATOR, string(sym)) == 1 &&
count(isequal(sym),
Expand Down Expand Up @@ -543,7 +568,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
return sym
end
idx = findfirst(isequal(sym), parameter_symbols(sys))
if idx === nothing && hasname(sym)
if idx === nothing && hasname(sym) && !(iscall(sym) && operation(sym) == getindex)
idx = parameter_index(sys, getname(sym))
end
return idx
Expand All @@ -559,13 +584,16 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
return idx
end
end
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
pnames = [getname(x)
for x in parameter_symbols(sys)
if hasname(x) && !(iscall(x) && operation(x) == getindex)]
idx = findfirst(isequal(sym), pnames)
if idx !== nothing
return idx
elseif count(NAMESPACE_SEPARATOR, string(sym)) == 1
return findfirst(isequal(sym),
Symbol.(
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys))))
nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, pnames))
end
return nothing
end
Expand Down
14 changes: 14 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1250,3 +1250,17 @@ end
prob = ODEProblem(ssys, [], (0.0, 1.0), [])
@test prob[x] == prob[y] == prob[z] == 1.0
end

@testset "Scalarized parameters in array functions" begin
@variables u(t)[1:2] x(t)[1:2] o(t)[1:2]
@parameters p[1:2, 1:2] [tunable = false]
@named sys = ODESystem(
[D(u) ~ (sum(u) + sum(x) + sum(p) + sum(o)) * x, o ~ prod(u) * x],
t, [u..., x..., o...], [p...])
sys1, = structural_simplify(sys, ([x...], []))
fn1, = ModelingToolkit.generate_function(sys1; expression = Val{false})
@test_nowarn fn1(ones(4), 2ones(2), 3ones(2, 2), 4.0)
sys2, = structural_simplify(sys, ([x...], []); split = false)
fn2, = ModelingToolkit.generate_function(sys2; expression = Val{false})
@test_nowarn fn2(ones(4), 2ones(6), 4.0)
end

0 comments on commit 6af4c99

Please sign in to comment.