Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: detect observed variables and dependent parameters dependent on discrete parameters #3106

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ end
function has_observed_with_lhs(sys, sym)
has_observed(sys) || return false
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return any(isequal(sym), ic.observed_syms)
return haskey(ic.observed_syms_to_timeseries, sym)
else
return any(isequal(sym), [eq.lhs for eq in observed(sys)])
end
Expand All @@ -740,7 +740,7 @@ end
function has_parameter_dependency_with_lhs(sys, sym)
has_parameter_dependencies(sys) || return false
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return any(isequal(sym), ic.dependent_pars)
return haskey(ic.dependent_pars_to_timeseries, unwrap(sym))
else
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
end
Expand All @@ -762,11 +762,20 @@ for traitT in [
allsyms = vars(sym; op = Symbolics.Operator)
for s in allsyms
s = unwrap(s)
if is_variable(sys, s) || is_independent_variable(sys, s) ||
has_observed_with_lhs(sys, s)
if is_variable(sys, s) || is_independent_variable(sys, s)
push!(ts_idxs, ContinuousTimeseries())
elseif is_timeseries_parameter(sys, s)
push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx)
elseif is_time_dependent(sys) && iscall(s) && issym(operation(s)) &&
is_variable(sys, operation(s)(get_iv(sys)))
# DDEs case, to detect x(t - k)
push!(ts_idxs, ContinuousTimeseries())
elseif has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if (ts = get(ic.observed_syms_to_timeseries, s, nothing)) !== nothing
union!(ts_idxs, ts)
elseif (ts = get(ic.dependent_pars_to_timeseries, s, nothing)) !== nothing
union!(ts_idxs, ts)
end
end
end
end
Expand Down
11 changes: 10 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,16 @@ function build_explicit_observed_function(sys, ts;
ivs = independent_variables(sys)
dep_vars = scalarize(setdiff(vars, ivs))

obs = param_only ? Equation[] : observed(sys)
obs = observed(sys)
if param_only
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
obs = filter(obs) do eq
!(ContinuousTimeseries() in ic.observed_syms_to_timeseries[eq.lhs])
end
else
obs = Equation[]
end
end

cs = collect_constants(obs)
if !isempty(cs) > 0
Expand Down
87 changes: 58 additions & 29 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const UnknownIndexMap = Dict{
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
const TunableIndexMap = Dict{BasicSymbolic,
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}

struct IndexCache
unknown_idx::UnknownIndexMap
Expand All @@ -48,8 +49,9 @@ struct IndexCache
tunable_idx::TunableIndexMap
constant_idx::ParamIndexMap
nonnumeric_idx::NonnumericMap
observed_syms::Set{BasicSymbolic}
dependent_pars::Set{Union{BasicSymbolic, CallWithMetadata}}
observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType}
dependent_pars_to_timeseries::Dict{
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
tunable_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
Expand Down Expand Up @@ -91,20 +93,6 @@ function IndexCache(sys::AbstractSystem)
end
end

observed_syms = Set{BasicSymbolic}()
for eq in observed(sys)
if symbolic_type(eq.lhs) != NotSymbolic()
sym = eq.lhs
ttsym = default_toterm(sym)
rsym = renamespace(sys, sym)
rttsym = renamespace(sys, ttsym)
push!(observed_syms, sym)
push!(observed_syms, ttsym)
push!(observed_syms, rsym)
push!(observed_syms, rttsym)
end
end

tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
Expand Down Expand Up @@ -267,38 +255,79 @@ function IndexCache(sys::AbstractSystem)
end
end

for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
keys(const_idxs), keys(nonnumeric_idxs),
observed_syms, independent_variable_symbols(sys)))
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
symbol_to_variable[getname(sym)] = sym
end
end

dependent_pars = Set{Union{BasicSymbolic, CallWithMetadata}}()
dependent_pars_to_timeseries = Dict{
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()

for eq in parameter_dependencies(sys)
sym = eq.lhs
vs = vars(eq.rhs)
timeseries = TimeseriesSetType()
if is_time_dependent(sys)
for v in vs
if (idx = get(disc_idxs, v, nothing)) !== nothing
push!(timeseries, idx.clock_idx)
end
end
end
ttsym = default_toterm(sym)
rsym = renamespace(sys, sym)
rttsym = renamespace(sys, ttsym)
for s in [sym, ttsym, rsym, rttsym]
push!(dependent_pars, s)
for s in (sym, ttsym, rsym, rttsym)
dependent_pars_to_timeseries[s] = timeseries
if hasname(s) && (!iscall(s) || operation(s) != getindex)
symbol_to_variable[getname(s)] = sym
end
end
end

observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}()
for eq in observed(sys)
if symbolic_type(eq.lhs) != NotSymbolic()
sym = eq.lhs
vs = vars(eq.rhs)
timeseries = TimeseriesSetType()
if is_time_dependent(sys)
for v in vs
if (idx = get(disc_idxs, v, nothing)) !== nothing
push!(timeseries, idx.clock_idx)
elseif haskey(unk_idxs, v)
push!(timeseries, ContinuousTimeseries())
elseif haskey(observed_syms_to_timeseries, v)
union!(timeseries, observed_syms_to_timeseries[v])
elseif haskey(dependent_pars_to_timeseries, v)
union!(timeseries, dependent_pars_to_timeseries[v])
elseif iscall(v) && issym(operation(v)) &&
is_variable(sys, operation(v)(get_iv(sys)))
push!(timeseries, ContinuousTimeseries())
end
end
end
ttsym = default_toterm(sym)
rsym = renamespace(sys, sym)
rttsym = renamespace(sys, ttsym)
for s in (sym, ttsym, rsym, rttsym)
observed_syms_to_timeseries[s] = timeseries
end
end
end

for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
keys(const_idxs), keys(nonnumeric_idxs),
keys(observed_syms_to_timeseries), independent_variable_symbols(sys)))
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
symbol_to_variable[getname(sym)] = sym
end
end

return IndexCache(
unk_idxs,
disc_idxs,
callback_to_clocks,
tunable_idxs,
const_idxs,
nonnumeric_idxs,
observed_syms,
dependent_pars,
observed_syms_to_timeseries,
dependent_pars_to_timeseries,
disc_buffer_templates,
BufferTemplate(Real, tunable_buffer_size),
const_buffer_sizes,
Expand Down
11 changes: 11 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1467,3 +1467,14 @@ end
obsfn(buf, prob.u0, prob.p, 0.0)
@test buf ≈ [1.0, 1.0, 2.0]
end

# https://github.com/SciML/SciMLBase.jl/issues/786
@testset "Observed variables dependent on discrete parameters" begin
@variables x(t) obs(t)
@parameters c(t)
@mtkbuild sys = ODESystem(
[D(x) ~ c * cos(x), obs ~ c], t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
sol = solve(prob, Tsit5())
@test sol[obs] ≈ 1:7
end
Comment on lines +1472 to +1480
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this MWE I provided in the issue is not a sufficient testcase. I'd propose the following

    @variables x1(t)=0 x2(t)=0 obs1(t) obs2(t)
    @parameters c1(t)=1 c2=1
    @mtkbuild sys = ODESystem(
        [D(x1) ~ c1,
        D(x2) ~ c2,
        obs1 ~ x1*c1,
        obs2 ~ x2*c2], t; discrete_events = [[1.0] => [c1 ~ 0, c2 ~ 0]])
    prob = ODEProblem(sys, [x1=>0, x2=>0], (0.0, 2))
    sol = solve(prob, Tsit5())

    # tests that should pass (?)
    @test sol([0,2], idxs=c1) == [1.0, 0.0]
    @test sol([0,0.9,1.1,2], idxs=obs1)  [0, 0.9, 0, 0]
    @test sol[obs1] == sol(sol.t, idxs=obs1) # errors because of mixed timeseries

    # the following tests check an observable which depends on a parameter which is not declared time dependent, which is done in the docs on discrete events
    # i don't know how this should be handled. Personally, as a user i'd expect all parameters to be discrete timeseries implicitly
    # Depending on your API design, those failures might be by design and don't need fixing.
    @test sol([0,2], idxs=c2) == [1.0, 0.0]
    @test sol([0,0.9,1.1,2], idxs=obs2)  [0, 0.9, 0, 0]
    @test sol[obs2] == sol(sol.t, idxs=obs2)

Loading