Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions lib/ModelingToolkitBase/src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,17 @@ function (xn::Num)(k::ShiftIndex)
vars = Set{SymbolicT}()
SU.search_variables!(vars, x)
if length(vars) != 1
error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
error(lazy"Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
end
var = only(vars)
if operation(var) === getindex
var = arguments(var)[1]
end
if !iscall(var)
throw(ArgumentError("Cannot shift time-independent variable $var"))
throw(ArgumentError(lazy"Cannot shift time-independent variable $var"))
end
if length(arguments(var)) != 1
error("Cannot shift an expression with multiple independent variables $x.")
error(lazy"Cannot shift an expression with multiple independent variables $x.")
end
t = only(arguments(var))

Expand All @@ -154,14 +154,14 @@ function (xn::Symbolics.Arr)(k::ShiftIndex)
vars = Set{SymbolicT}()
SU.search_variables!(vars, x)
if length(vars) != 1
error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
error(lazy"Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
end
var = only(vars)
if !iscall(var)
throw(ArgumentError("Cannot shift time-independent variable $var"))
throw(ArgumentError(lazy"Cannot shift time-independent variable $var"))
end
if length(arguments(var)) != 1
error("Cannot shift an expression with multiple independent variables $x.")
error(lazy"Cannot shift an expression with multiple independent variables $x.")
end
t = only(arguments(var))

Expand Down
2 changes: 1 addition & 1 deletion lib/ModelingToolkitBase/src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ tovar(s::Union{Num, Symbolics.Arr}) = wrap(tovar(unwrap(s)))
function toparam_validate(s::SymbolicT)
if iscall(s)
error(
"""
lazy"""
`@parameters` cannot create time-dependent parameters. Encountered $s. Use \
`@discretes` for this purpose.
"""
Expand Down
14 changes: 7 additions & 7 deletions lib/ModelingToolkitBase/src/problems/odeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ function generate_ODENLStepData(sys, u0, p, mm, nlstep_compile, nlstep_scc)
)
end

@fallback_iip_specialize function SciMLBase.ODEFunction{iip, spec}(
sys::System; u0 = nothing, p = nothing, tgrad = false, jac = false,
Base.@nospecializeinfer @fallback_iip_specialize function SciMLBase.ODEFunction{iip, spec}(
sys::System; @nospecialize(u0 = nothing), @nospecialize(p = nothing), tgrad = false, jac = false,
t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false,
steady_state = false, checkbounds = false, sparsity = false, analytic = nothing,
simplify = false, cse = true, initialization_data = nothing, expression = Val{false},
steady_state = false, checkbounds = false, sparsity = false, @nospecialize(analytic = nothing),
simplify = false, cse = true, @nospecialize(initialization_data = nothing), expression = Val{false},
check_compatibility = true, nlstep = false, nlstep_compile = true, nlstep_scc = false,
kwargs...
) where {iip, spec}
Expand Down Expand Up @@ -93,9 +93,9 @@ end
maybe_codegen_scimlfn(expression, ODEFunction{iip, spec}, args; kwargs...)
end

@fallback_iip_specialize function SciMLBase.ODEProblem{iip, spec}(
sys::System, op, tspan;
callback = nothing, check_length = true, eval_expression = false,
Base.@nospecializeinfer @fallback_iip_specialize function SciMLBase.ODEProblem{iip, spec}(
sys::System, @nospecialize(op), tspan;
@nospecialize(callback = nothing), check_length = true, eval_expression = false,
expression = Val{false}, eval_module = @__MODULE__, check_compatibility = true,
kwargs...
) where {iip, spec}
Expand Down
27 changes: 17 additions & 10 deletions lib/ModelingToolkitBase/src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,9 @@ end

Returns a function `condition(u,t,integrator)`, condition(out,u,t,integrator)` returning the `condition(cb)`.
"""
function compile_condition(
cbs::Union{AbstractCallback, Vector{<:AbstractCallback}}, sys, dvs, ps;
Base.@nospecializeinfer function compile_condition(
@nospecialize(cbs::Union{AbstractCallback, Vector{<:AbstractCallback}}),
sys, @nospecialize(dvs), @nospecialize(ps);
eval_expression = false, eval_module = @__MODULE__, kwargs...
)
u = map(value, dvs)
Expand Down Expand Up @@ -760,10 +761,11 @@ function generate_continuous_callbacks(
push!(_cbs, cb)
end
sort!(OrderedDict(cb_classes), by = cb -> cb[1])
compiled_callbacks = [
generate_callback(cb, sys; kwargs...)
for ((rf, reinit), cb) in cb_classes
]
# Use explicit loop to avoid Generator type inference overhead
compiled_callbacks = Vector{Any}(undef, length(cb_classes))
for (i, ((rf, reinit), cb)) in enumerate(cb_classes)
compiled_callbacks[i] = generate_callback(cb, sys; kwargs...)
end
if length(compiled_callbacks) == 1
return only(compiled_callbacks)
else
Expand All @@ -777,7 +779,12 @@ function generate_discrete_callbacks(
)
dbs = discrete_events(sys)
isempty(dbs) && return nothing
return [generate_callback(db, sys; kwargs...) for db in dbs]
# Use explicit loop to avoid Generator type inference overhead
result = Vector{Any}(undef, length(dbs))
for (i, db) in enumerate(dbs)
result[i] = generate_callback(db, sys; kwargs...)
end
return result
end

EMPTY_AFFECT(args...) = nothing
Expand Down Expand Up @@ -985,9 +992,9 @@ end
"""
Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates.
"""
function compile_equational_affect(
aff::Union{AffectSystem, Vector{Equation}}, sys; reset_jumps = false,
eval_expression = false, eval_module = @__MODULE__, op = nothing, kwargs...
Base.@nospecializeinfer function compile_equational_affect(
@nospecialize(aff::Union{AffectSystem, Vector{Equation}}), sys; reset_jumps = false,
eval_expression = false, eval_module = @__MODULE__, @nospecialize(op = nothing), kwargs...
)
if aff isa AbstractVector
aff = make_affect(aff; iv = get_iv(sys))
Expand Down
4 changes: 2 additions & 2 deletions lib/ModelingToolkitBase/src/systems/codegen_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ generated functions, and `args` are the arguments.

All other keyword arguments are forwarded to `build_function`.
"""
function build_function_wrapper(
sys::AbstractSystem, expr, args...; p_start = 2,
Base.@nospecializeinfer function build_function_wrapper(
sys::AbstractSystem, @nospecialize(expr), @nospecialize(args...); p_start = 2,
p_end = is_time_dependent(sys) ? length(args) - 1 : length(args),
wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity,
add_observed = true, filter_observed = Returns(true),
Expand Down
8 changes: 4 additions & 4 deletions lib/ModelingToolkitBase/src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ function connector_type(sys::AbstractSystem)
for s in unkvars
vtype = get_connection_type(s)
if vtype === Stream
isarray(s) && error("Array stream variables are not supported. Got $s.")
isarray(s) && error(lazy"Array stream variables are not supported. Got $s.")
n_stream += 1
elseif vtype === Flow
n_flow += 1
Expand Down Expand Up @@ -227,10 +227,10 @@ function validate_causal_variables_connection(allvars::Vector{SymbolicT})
for var in allvars
vtype = getvariabletype(var)
vtype === VARIABLE ||
throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`."))
throw(ArgumentError(lazy"Expected $var to be of kind `$VARIABLE`. Got `$vtype`."))
end
if !allunique(allvars)
throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries."))
throw(ArgumentError(lazy"Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries."))
end
sh1 = SU.shape(allvars[1])::SU.ShapeVecT
sz1 = SU.SmallV{Int}()
Expand All @@ -245,7 +245,7 @@ function validate_causal_variables_connection(allvars::Vector{SymbolicT})
push!(sz2, length(x))
end
if !isequal(sz1, sz2)
throw(ArgumentError("Expected all connection variables to have the same size. Got variables $(allvars[1]) and $v with sizes $sz1 and $sz2 respectively."))
throw(ArgumentError(lazy"Expected all connection variables to have the same size. Got variables $(allvars[1]) and $v with sizes $sz1 and $sz2 respectively."))

end
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ModelingToolkitBase/src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ function compile_functional_affect(
for oexpr in obs_exprs
invalid_vars = invalid_variables(sys, oexpr)
if length(invalid_vars) > 0
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
error(lazy"Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
end
end
end
Expand All @@ -253,7 +253,7 @@ function compile_functional_affect(
end
invalid_vars = unassignable_variables(sys, mexpr)
if length(invalid_vars) > 0
error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
error(lazy"Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion lib/ModelingToolkitBase/src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ function MTKParameters(
end
end
if !SU.isconst(val)
error("Could not evaluate value of parameter $sym. Missing values for variables in expression $val.")
error(lazy"Could not evaluate value of parameter $sym. Missing values for variables in expression $val.")
end
if ctype <: FnType
ctype = fntype_to_function_type(ctype)
Expand Down
55 changes: 42 additions & 13 deletions lib/ModelingToolkitBase/src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ function get_temporary_value(p, floatT = Float64)
elseif stype <: AbstractArray
zeros(eltype(stype), size(p))
else
error("Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization")
error(lazy"Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization")
end
end

Expand Down Expand Up @@ -1727,6 +1727,16 @@ Macro for writing problem/function constructors. Expects a function definition w
parameters for `iip` and `specialize`. Generates fallbacks with
`specialize = SciMLBase.FullSpecialize` and `iip = true`.
"""
# Unwrap `@nospecialize(arg)` to get the underlying argument expression.
# Returns the argument unchanged if not wrapped in @nospecialize.
function _unwrap_nospecialize(arg)
if Meta.isexpr(arg, :macrocall) && length(arg.args) >= 3 &&
arg.args[1] in (Symbol("@nospecialize"), GlobalRef(Base, Symbol("@nospecialize")))
return arg.args[3]
end
return arg
end

macro fallback_iip_specialize(ex)
@assert Meta.isexpr(ex, :function)
# fnname is ODEProblem{iip, spec}(args...) where {iip, spec}
Expand All @@ -1745,8 +1755,27 @@ macro fallback_iip_specialize(ex)
# the function should have keyword arguments
@assert Meta.isexpr(args[1], :parameters)

# arguments to call with
call_args = map(args) do arg
# Create signature args with @nospecialize stripped (for fallback function signatures)
sig_args = map(args) do arg
unwrapped = _unwrap_nospecialize(arg)
# Handle :parameters specially - unwrap each kwarg inside
if Meta.isexpr(unwrapped, :parameters)
new_params = map(unwrapped.args) do kwarg
kw = _unwrap_nospecialize(kwarg)
# Convert :(=) to :kw if needed
if Meta.isexpr(kw, :(=))
Expr(:kw, kw.args...)
else
kw
end
end
return Expr(:parameters, new_params...)
end
return unwrapped
end

# arguments to call with (for forwarding calls)
call_args = map(sig_args) do arg
# keyword args are in `Expr(:parameters)` so any `Expr(:kw)` here
# are optional positional arguments. Analyze `:(f(a, b = 1; k = 1, l...))`
# to understand
Expand All @@ -1755,7 +1784,7 @@ macro fallback_iip_specialize(ex)
end
call_kwargs = map(call_args[1].args) do arg
Meta.isexpr(arg, :...) && return arg
@assert Meta.isexpr(arg, :kw)
@assert Meta.isexpr(arg, :kw) "Expected keyword argument, got $(arg)"
return Expr(:kw, arg.args[1], arg.args[1])
end
call_args[1] = Expr(:parameters, call_kwargs...)
Expand All @@ -1772,25 +1801,25 @@ macro fallback_iip_specialize(ex)
)
# `ODEProblem{iip}`
fnname_iip = Expr(:curly, fnname_name, curly_args[1])
# `ODEProblem{iip}(args...)`
fncall_iip = Expr(:call, fnname_iip, args...)
# ODEProblem{iip}(args...) where {iip}
# `ODEProblem{iip}(sig_args...)` - use sig_args (no @nospecialize) for fallback signature
fncall_iip = Expr(:call, fnname_iip, sig_args...)
# ODEProblem{iip}(sig_args...) where {iip}
fnwhere_iip = Expr(:where, fncall_iip, where_args[1])
fn_iip = Expr(:function, fnwhere_iip, callexpr_iip)

# `ODEProblem{true}(call_args...)`
callexpr_base = Expr(:call, Expr(:curly, fnname_name, true), call_args...)
# `ODEProblem(args...)`
fncall_base = Expr(:call, fnname_name, args...)
# `ODEProblem(sig_args...)` - use sig_args for fallback signature
fncall_base = Expr(:call, fnname_name, sig_args...)
fn_base = Expr(:function, fncall_base, callexpr_base)

# Handle case when this is a problem constructor and `u0map` is a `StaticArray`,
# where `iip` should default to `false`.
fn_sarr = nothing
if occursin("Problem", string(fnname_name))
# args should at least contain an argument for the `u0map`
@assert length(args) > 2
u0_arg = args[3]
# sig_args should at least contain an argument for the `u0map`
@assert length(sig_args) > 2
u0_arg = sig_args[3]
# should not have a type-annotation
@assert !Meta.isexpr(u0_arg, :(::))
if Meta.isexpr(u0_arg, :kw)
Expand All @@ -1801,7 +1830,7 @@ macro fallback_iip_specialize(ex)
end

callexpr_sarr = Expr(:call, Expr(:curly, fnname_name, false), call_args...)
fncall_sarr = Expr(:call, fnname_name, args[1], args[2], u0_arg, args[4:end]...)
fncall_sarr = Expr(:call, fnname_name, sig_args[1], sig_args[2], u0_arg, sig_args[4:end]...)
fn_sarr = Expr(:function, fncall_sarr, callexpr_sarr)
end
return quote
Expand Down
10 changes: 5 additions & 5 deletions lib/ModelingToolkitBase/src/systems/system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ function System(eqs::Vector{Equation}, iv; kwargs...)
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
for dv in noisedvs
dv ∈ allunknowns ||
throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
throw(ArgumentError(lazy"Variable $dv in noise equations is not an unknown of the system."))
end
end

Expand Down Expand Up @@ -905,20 +905,20 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
if !iscall(var)
SU.query(isequal(iv), var) && (
var ∈ sts ||
throw(ArgumentError("Time-dependent variable $var is not an unknown of the system."))
throw(ArgumentError(lazy"Time-dependent variable $var is not an unknown of the system."))
)
elseif length(arguments(var)) > 1
throw(ArgumentError("Too many arguments for variable $var."))
throw(ArgumentError(lazy"Too many arguments for variable $var."))
elseif length(arguments(var)) == 1
if iscall(var) && operation(var) isa Differential
var = only(arguments(var))
end
arg = only(arguments(var))
operation(var)(iv) ∈ sts ||
throw(ArgumentError("Variable $var is not a variable of the System. Called variables must be variables of the System."))
throw(ArgumentError(lazy"Variable $var is not a variable of the System. Called variables must be variables of the System."))

isequal(arg, iv) || isparameter(arg) || isconst(arg) && symtype(arg) <: Real ||
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
throw(ArgumentError(lazy"Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))

isparameter(arg) && !isequal(arg, iv) && push!(auxps, arg)
else
Expand Down
Loading
Loading