diff --git a/lib/ModelingToolkitBase/src/discretedomain.jl b/lib/ModelingToolkitBase/src/discretedomain.jl index 5b55f77da3..71225ff84c 100644 --- a/lib/ModelingToolkitBase/src/discretedomain.jl +++ b/lib/ModelingToolkitBase/src/discretedomain.jl @@ -121,17 +121,17 @@ function (xn::Num)(k::ShiftIndex) vars = Set{SymbolicT}() SU.search_variables!(vars, x) if length(vars) != 1 - error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.") + error(lazy"Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.") end var = only(vars) if operation(var) === getindex var = arguments(var)[1] end if !iscall(var) - throw(ArgumentError("Cannot shift time-independent variable $var")) + throw(ArgumentError(lazy"Cannot shift time-independent variable $var")) end if length(arguments(var)) != 1 - error("Cannot shift an expression with multiple independent variables $x.") + error(lazy"Cannot shift an expression with multiple independent variables $x.") end t = only(arguments(var)) @@ -154,14 +154,14 @@ function (xn::Symbolics.Arr)(k::ShiftIndex) vars = Set{SymbolicT}() SU.search_variables!(vars, x) if length(vars) != 1 - error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.") + error(lazy"Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.") end var = only(vars) if !iscall(var) - throw(ArgumentError("Cannot shift time-independent variable $var")) + throw(ArgumentError(lazy"Cannot shift time-independent variable $var")) end if length(arguments(var)) != 1 - error("Cannot shift an expression with multiple independent variables $x.") + error(lazy"Cannot shift an expression with multiple independent variables $x.") end t = only(arguments(var)) diff --git a/lib/ModelingToolkitBase/src/parameters.jl b/lib/ModelingToolkitBase/src/parameters.jl index bded758767..434a96c076 100644 --- a/lib/ModelingToolkitBase/src/parameters.jl +++ b/lib/ModelingToolkitBase/src/parameters.jl @@ -67,7 +67,7 @@ tovar(s::Union{Num, Symbolics.Arr}) = wrap(tovar(unwrap(s))) function toparam_validate(s::SymbolicT) if iscall(s) error( - """ + lazy""" `@parameters` cannot create time-dependent parameters. Encountered $s. Use \ `@discretes` for this purpose. """ diff --git a/lib/ModelingToolkitBase/src/problems/odeproblem.jl b/lib/ModelingToolkitBase/src/problems/odeproblem.jl index d33bc04d37..55e87643df 100644 --- a/lib/ModelingToolkitBase/src/problems/odeproblem.jl +++ b/lib/ModelingToolkitBase/src/problems/odeproblem.jl @@ -14,11 +14,11 @@ function generate_ODENLStepData(sys, u0, p, mm, nlstep_compile, nlstep_scc) ) end -@fallback_iip_specialize function SciMLBase.ODEFunction{iip, spec}( - sys::System; u0 = nothing, p = nothing, tgrad = false, jac = false, +Base.@nospecializeinfer @fallback_iip_specialize function SciMLBase.ODEFunction{iip, spec}( + sys::System; @nospecialize(u0 = nothing), @nospecialize(p = nothing), tgrad = false, jac = false, t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false, - steady_state = false, checkbounds = false, sparsity = false, analytic = nothing, - simplify = false, cse = true, initialization_data = nothing, expression = Val{false}, + steady_state = false, checkbounds = false, sparsity = false, @nospecialize(analytic = nothing), + simplify = false, cse = true, @nospecialize(initialization_data = nothing), expression = Val{false}, check_compatibility = true, nlstep = false, nlstep_compile = true, nlstep_scc = false, kwargs... ) where {iip, spec} @@ -93,9 +93,9 @@ end maybe_codegen_scimlfn(expression, ODEFunction{iip, spec}, args; kwargs...) end -@fallback_iip_specialize function SciMLBase.ODEProblem{iip, spec}( - sys::System, op, tspan; - callback = nothing, check_length = true, eval_expression = false, +Base.@nospecializeinfer @fallback_iip_specialize function SciMLBase.ODEProblem{iip, spec}( + sys::System, @nospecialize(op), tspan; + @nospecialize(callback = nothing), check_length = true, eval_expression = false, expression = Val{false}, eval_module = @__MODULE__, check_compatibility = true, kwargs... ) where {iip, spec} diff --git a/lib/ModelingToolkitBase/src/systems/callbacks.jl b/lib/ModelingToolkitBase/src/systems/callbacks.jl index e70a62af62..65a549fac0 100644 --- a/lib/ModelingToolkitBase/src/systems/callbacks.jl +++ b/lib/ModelingToolkitBase/src/systems/callbacks.jl @@ -710,8 +710,9 @@ end Returns a function `condition(u,t,integrator)`, condition(out,u,t,integrator)` returning the `condition(cb)`. """ -function compile_condition( - cbs::Union{AbstractCallback, Vector{<:AbstractCallback}}, sys, dvs, ps; +Base.@nospecializeinfer function compile_condition( + @nospecialize(cbs::Union{AbstractCallback, Vector{<:AbstractCallback}}), + sys, @nospecialize(dvs), @nospecialize(ps); eval_expression = false, eval_module = @__MODULE__, kwargs... ) u = map(value, dvs) @@ -760,10 +761,11 @@ function generate_continuous_callbacks( push!(_cbs, cb) end sort!(OrderedDict(cb_classes), by = cb -> cb[1]) - compiled_callbacks = [ - generate_callback(cb, sys; kwargs...) - for ((rf, reinit), cb) in cb_classes - ] + # Use explicit loop to avoid Generator type inference overhead + compiled_callbacks = Vector{Any}(undef, length(cb_classes)) + for (i, ((rf, reinit), cb)) in enumerate(cb_classes) + compiled_callbacks[i] = generate_callback(cb, sys; kwargs...) + end if length(compiled_callbacks) == 1 return only(compiled_callbacks) else @@ -777,7 +779,12 @@ function generate_discrete_callbacks( ) dbs = discrete_events(sys) isempty(dbs) && return nothing - return [generate_callback(db, sys; kwargs...) for db in dbs] + # Use explicit loop to avoid Generator type inference overhead + result = Vector{Any}(undef, length(dbs)) + for (i, db) in enumerate(dbs) + result[i] = generate_callback(db, sys; kwargs...) + end + return result end EMPTY_AFFECT(args...) = nothing @@ -985,9 +992,9 @@ end """ Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates. """ -function compile_equational_affect( - aff::Union{AffectSystem, Vector{Equation}}, sys; reset_jumps = false, - eval_expression = false, eval_module = @__MODULE__, op = nothing, kwargs... +Base.@nospecializeinfer function compile_equational_affect( + @nospecialize(aff::Union{AffectSystem, Vector{Equation}}), sys; reset_jumps = false, + eval_expression = false, eval_module = @__MODULE__, @nospecialize(op = nothing), kwargs... ) if aff isa AbstractVector aff = make_affect(aff; iv = get_iv(sys)) diff --git a/lib/ModelingToolkitBase/src/systems/codegen_utils.jl b/lib/ModelingToolkitBase/src/systems/codegen_utils.jl index a59e3efbbc..bbba53f470 100644 --- a/lib/ModelingToolkitBase/src/systems/codegen_utils.jl +++ b/lib/ModelingToolkitBase/src/systems/codegen_utils.jl @@ -246,8 +246,8 @@ generated functions, and `args` are the arguments. All other keyword arguments are forwarded to `build_function`. """ -function build_function_wrapper( - sys::AbstractSystem, expr, args...; p_start = 2, +Base.@nospecializeinfer function build_function_wrapper( + sys::AbstractSystem, @nospecialize(expr), @nospecialize(args...); p_start = 2, p_end = is_time_dependent(sys) ? length(args) - 1 : length(args), wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity, add_observed = true, filter_observed = Returns(true), diff --git a/lib/ModelingToolkitBase/src/systems/connectors.jl b/lib/ModelingToolkitBase/src/systems/connectors.jl index b9c70e9026..4098541256 100644 --- a/lib/ModelingToolkitBase/src/systems/connectors.jl +++ b/lib/ModelingToolkitBase/src/systems/connectors.jl @@ -153,7 +153,7 @@ function connector_type(sys::AbstractSystem) for s in unkvars vtype = get_connection_type(s) if vtype === Stream - isarray(s) && error("Array stream variables are not supported. Got $s.") + isarray(s) && error(lazy"Array stream variables are not supported. Got $s.") n_stream += 1 elseif vtype === Flow n_flow += 1 @@ -227,10 +227,10 @@ function validate_causal_variables_connection(allvars::Vector{SymbolicT}) for var in allvars vtype = getvariabletype(var) vtype === VARIABLE || - throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`.")) + throw(ArgumentError(lazy"Expected $var to be of kind `$VARIABLE`. Got `$vtype`.")) end if !allunique(allvars) - throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries.")) + throw(ArgumentError(lazy"Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries.")) end sh1 = SU.shape(allvars[1])::SU.ShapeVecT sz1 = SU.SmallV{Int}() @@ -245,7 +245,7 @@ function validate_causal_variables_connection(allvars::Vector{SymbolicT}) push!(sz2, length(x)) end if !isequal(sz1, sz2) - throw(ArgumentError("Expected all connection variables to have the same size. Got variables $(allvars[1]) and $v with sizes $sz1 and $sz2 respectively.")) + throw(ArgumentError(lazy"Expected all connection variables to have the same size. Got variables $(allvars[1]) and $v with sizes $sz1 and $sz2 respectively.")) end end diff --git a/lib/ModelingToolkitBase/src/systems/imperative_affect.jl b/lib/ModelingToolkitBase/src/systems/imperative_affect.jl index d6b01738b6..02836bc2b6 100644 --- a/lib/ModelingToolkitBase/src/systems/imperative_affect.jl +++ b/lib/ModelingToolkitBase/src/systems/imperative_affect.jl @@ -238,7 +238,7 @@ function compile_functional_affect( for oexpr in obs_exprs invalid_vars = invalid_variables(sys, oexpr) if length(invalid_vars) > 0 - error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") + error(lazy"Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") end end end @@ -253,7 +253,7 @@ function compile_functional_affect( end invalid_vars = unassignable_variables(sys, mexpr) if length(invalid_vars) > 0 - error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") + error(lazy"Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") end end end diff --git a/lib/ModelingToolkitBase/src/systems/parameter_buffer.jl b/lib/ModelingToolkitBase/src/systems/parameter_buffer.jl index 042629b30a..d69e0bd0ea 100644 --- a/lib/ModelingToolkitBase/src/systems/parameter_buffer.jl +++ b/lib/ModelingToolkitBase/src/systems/parameter_buffer.jl @@ -169,7 +169,7 @@ function MTKParameters( end end if !SU.isconst(val) - error("Could not evaluate value of parameter $sym. Missing values for variables in expression $val.") + error(lazy"Could not evaluate value of parameter $sym. Missing values for variables in expression $val.") end if ctype <: FnType ctype = fntype_to_function_type(ctype) diff --git a/lib/ModelingToolkitBase/src/systems/problem_utils.jl b/lib/ModelingToolkitBase/src/systems/problem_utils.jl index fd37948481..d895e77787 100644 --- a/lib/ModelingToolkitBase/src/systems/problem_utils.jl +++ b/lib/ModelingToolkitBase/src/systems/problem_utils.jl @@ -556,7 +556,7 @@ function get_temporary_value(p, floatT = Float64) elseif stype <: AbstractArray zeros(eltype(stype), size(p)) else - error("Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization") + error(lazy"Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization") end end @@ -1727,6 +1727,16 @@ Macro for writing problem/function constructors. Expects a function definition w parameters for `iip` and `specialize`. Generates fallbacks with `specialize = SciMLBase.FullSpecialize` and `iip = true`. """ +# Unwrap `@nospecialize(arg)` to get the underlying argument expression. +# Returns the argument unchanged if not wrapped in @nospecialize. +function _unwrap_nospecialize(arg) + if Meta.isexpr(arg, :macrocall) && length(arg.args) >= 3 && + arg.args[1] in (Symbol("@nospecialize"), GlobalRef(Base, Symbol("@nospecialize"))) + return arg.args[3] + end + return arg +end + macro fallback_iip_specialize(ex) @assert Meta.isexpr(ex, :function) # fnname is ODEProblem{iip, spec}(args...) where {iip, spec} @@ -1745,8 +1755,27 @@ macro fallback_iip_specialize(ex) # the function should have keyword arguments @assert Meta.isexpr(args[1], :parameters) - # arguments to call with - call_args = map(args) do arg + # Create signature args with @nospecialize stripped (for fallback function signatures) + sig_args = map(args) do arg + unwrapped = _unwrap_nospecialize(arg) + # Handle :parameters specially - unwrap each kwarg inside + if Meta.isexpr(unwrapped, :parameters) + new_params = map(unwrapped.args) do kwarg + kw = _unwrap_nospecialize(kwarg) + # Convert :(=) to :kw if needed + if Meta.isexpr(kw, :(=)) + Expr(:kw, kw.args...) + else + kw + end + end + return Expr(:parameters, new_params...) + end + return unwrapped + end + + # arguments to call with (for forwarding calls) + call_args = map(sig_args) do arg # keyword args are in `Expr(:parameters)` so any `Expr(:kw)` here # are optional positional arguments. Analyze `:(f(a, b = 1; k = 1, l...))` # to understand @@ -1755,7 +1784,7 @@ macro fallback_iip_specialize(ex) end call_kwargs = map(call_args[1].args) do arg Meta.isexpr(arg, :...) && return arg - @assert Meta.isexpr(arg, :kw) + @assert Meta.isexpr(arg, :kw) "Expected keyword argument, got $(arg)" return Expr(:kw, arg.args[1], arg.args[1]) end call_args[1] = Expr(:parameters, call_kwargs...) @@ -1772,25 +1801,25 @@ macro fallback_iip_specialize(ex) ) # `ODEProblem{iip}` fnname_iip = Expr(:curly, fnname_name, curly_args[1]) - # `ODEProblem{iip}(args...)` - fncall_iip = Expr(:call, fnname_iip, args...) - # ODEProblem{iip}(args...) where {iip} + # `ODEProblem{iip}(sig_args...)` - use sig_args (no @nospecialize) for fallback signature + fncall_iip = Expr(:call, fnname_iip, sig_args...) + # ODEProblem{iip}(sig_args...) where {iip} fnwhere_iip = Expr(:where, fncall_iip, where_args[1]) fn_iip = Expr(:function, fnwhere_iip, callexpr_iip) # `ODEProblem{true}(call_args...)` callexpr_base = Expr(:call, Expr(:curly, fnname_name, true), call_args...) - # `ODEProblem(args...)` - fncall_base = Expr(:call, fnname_name, args...) + # `ODEProblem(sig_args...)` - use sig_args for fallback signature + fncall_base = Expr(:call, fnname_name, sig_args...) fn_base = Expr(:function, fncall_base, callexpr_base) # Handle case when this is a problem constructor and `u0map` is a `StaticArray`, # where `iip` should default to `false`. fn_sarr = nothing if occursin("Problem", string(fnname_name)) - # args should at least contain an argument for the `u0map` - @assert length(args) > 2 - u0_arg = args[3] + # sig_args should at least contain an argument for the `u0map` + @assert length(sig_args) > 2 + u0_arg = sig_args[3] # should not have a type-annotation @assert !Meta.isexpr(u0_arg, :(::)) if Meta.isexpr(u0_arg, :kw) @@ -1801,7 +1830,7 @@ macro fallback_iip_specialize(ex) end callexpr_sarr = Expr(:call, Expr(:curly, fnname_name, false), call_args...) - fncall_sarr = Expr(:call, fnname_name, args[1], args[2], u0_arg, args[4:end]...) + fncall_sarr = Expr(:call, fnname_name, sig_args[1], sig_args[2], u0_arg, sig_args[4:end]...) fn_sarr = Expr(:function, fncall_sarr, callexpr_sarr) end return quote diff --git a/lib/ModelingToolkitBase/src/systems/system.jl b/lib/ModelingToolkitBase/src/systems/system.jl index ecff2ee932..0ab26ae927 100644 --- a/lib/ModelingToolkitBase/src/systems/system.jl +++ b/lib/ModelingToolkitBase/src/systems/system.jl @@ -747,7 +747,7 @@ function System(eqs::Vector{Equation}, iv; kwargs...) collect_vars!(noisedvs, noiseps, noiseeqs, iv) for dv in noisedvs dv ∈ allunknowns || - throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system.")) + throw(ArgumentError(lazy"Variable $dv in noise equations is not an unknown of the system.")) end end @@ -905,20 +905,20 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) if !iscall(var) SU.query(isequal(iv), var) && ( var ∈ sts || - throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")) + throw(ArgumentError(lazy"Time-dependent variable $var is not an unknown of the system.")) ) elseif length(arguments(var)) > 1 - throw(ArgumentError("Too many arguments for variable $var.")) + throw(ArgumentError(lazy"Too many arguments for variable $var.")) elseif length(arguments(var)) == 1 if iscall(var) && operation(var) isa Differential var = only(arguments(var)) end arg = only(arguments(var)) operation(var)(iv) ∈ sts || - throw(ArgumentError("Variable $var is not a variable of the System. Called variables must be variables of the System.")) + throw(ArgumentError(lazy"Variable $var is not a variable of the System. Called variables must be variables of the System.")) isequal(arg, iv) || isparameter(arg) || isconst(arg) && symtype(arg) <: Real || - throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds.")) + throw(ArgumentError(lazy"Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds.")) isparameter(arg) && !isequal(arg, iv) && push!(auxps, arg) else diff --git a/lib/ModelingToolkitBase/src/utils.jl b/lib/ModelingToolkitBase/src/utils.jl index 61edb259b2..c163f811a6 100644 --- a/lib/ModelingToolkitBase/src/utils.jl +++ b/lib/ModelingToolkitBase/src/utils.jl @@ -122,7 +122,7 @@ end function check_parameters(ps, iv) for p in ps isequal(iv, p) && - throw(ArgumentError("Independent variable $iv not allowed in parameters.")) + throw(ArgumentError(lazy"Independent variable $iv not allowed in parameters.")) end return end @@ -142,9 +142,9 @@ end function check_variables(dvs, iv) for dv in dvs isequal(iv, dv) && - throw(ArgumentError("Independent variable $iv not allowed in dependent variables.")) + throw(ArgumentError(lazy"Independent variable $iv not allowed in dependent variables.")) (is_delay_var(iv, dv) || SU.query(isequal(iv), dv)) || - throw(ArgumentError("Variable $dv is not a function of independent variable $iv.")) + throw(ArgumentError(lazy"Variable $dv is not a function of independent variable $iv.")) end return end @@ -154,7 +154,7 @@ function check_lhs(eq::Equation, ::Type{Differential}, dvs::Set) _iszero(v) && return op = operation(v) op isa Differential && isone(op.order) && only(arguments(v)) in dvs && return - error("$v is not a valid LHS. Please run mtkcompile before simulation.") + error(lazy"$v is not a valid LHS. Please run mtkcompile before simulation.") end function check_lhs(eqs::Vector{Equation}, ::Type{Differential}, dvs::Set) for eq in eqs @@ -199,7 +199,7 @@ function (icp::IndepvarCheckPredicate)(ex::SymbolicT) end @noinline function throw_multiple_iv(iv, newiv) - throw(ArgumentError("Differential w.r.t. variable ($newiv) other than the independent variable ($iv) are not allowed.")) + throw(ArgumentError(lazy"Differential w.r.t. variable ($newiv) other than the independent variable ($iv) are not allowed.")) end """ @@ -590,10 +590,10 @@ function check_operator_variables(eqs, ::Type{op}) where {op} is_tmp_fine = iszero(nd) end is_tmp_fine || - error("The LHS cannot contain nondifferentiated variables. Please run `mtkcompile` or use the DAE form.\nGot $eq") + error(lazy"The LHS cannot contain nondifferentiated variables. Please run `mtkcompile` or use the DAE form.\nGot $eq") for v in tmp v in ops && - error("The LHS operator must be unique. Please run `mtkcompile` or use the DAE form. $v appears in LHS more than once.") + error(lazy"The LHS operator must be unique. Please run `mtkcompile` or use the DAE form. $v appears in LHS more than once.") push!(ops, v) end empty!(tmp) @@ -1048,7 +1048,7 @@ function get_substitutions(sys) end @noinline function throw_missingvars_in_sys(vars) - throw(ArgumentError("$vars are either missing from the variable map or missing from the system's unknowns/parameters list.")) + throw(ArgumentError(lazy"$vars are either missing from the variable map or missing from the system's unknowns/parameters list.")) end function promote_to_concrete(vs; tofloat = true, use_union = true) @@ -1165,7 +1165,7 @@ Return the `DiCMOBiGraph` denoting the dependencies between observed equations ` function observed_dependency_graph(eqs::Vector{Equation}) for eq in eqs if symbolic_type(eq.lhs) == NotSymbolic() - error("All equations must be observed equations of the form `var ~ expr`. Got $eq") + error(lazy"All equations must be observed equations of the form `var ~ expr`. Got $eq") end end graph, assigns = observed2graph(eqs, getproperty.(eqs, (:lhs,))) @@ -1518,7 +1518,7 @@ function get_stable_index(x::SymbolicT) return Moshi.Match.@match x begin BSImpl.Term(; f, args) && if f === getindex end => return SU.StableIndex{Int}(x) BSImpl.Term(; f, args) && if f isa Operator end => return get_stable_index(args[1]) - _ => throw(ArgumentError("Invalid variable $x for `get_stable_index`.")) + _ => throw(ArgumentError(lazy"Invalid variable $x for `get_stable_index`.")) end end