Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make default value units consistent #2898

Merged
merged 12 commits into from
Aug 7, 2024
90 changes: 72 additions & 18 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function _model_macro(mod, name, expr, isconnector)
c_evts = []
d_evts = []
kwargs = OrderedCollections.OrderedSet()
where_types = Expr[]
where_types = Union{Symbol, Expr}[]

push!(exprs.args, :(variables = []))
push!(exprs.args, :(parameters = []))
Expand Down Expand Up @@ -143,9 +143,15 @@ end
pop_structure_dict!(dict, key) = length(dict[key]) == 0 && pop!(dict, key)

function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
varclass, where_types)
varclass, where_types, meta)
if indices isa Nothing
push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing))
if !isnothing(meta) && haskey(meta, VariableUnit)
uvar = gensym()
push!(where_types, uvar)
push!(kwargs, Expr(:kw, :($a::Union{Nothing, $uvar}), nothing))
else
push!(kwargs, Expr(:kw, :($a::Union{Nothing, $type}), nothing))
end
dict[:kwargs][getname(var)] = Dict(:value => def, :type => type)
else
vartype = gensym(:T)
Expand All @@ -154,7 +160,11 @@ function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
Expr(:(::), a,
Expr(:curly, :Union, :Nothing, Expr(:curly, :AbstractArray, vartype))),
nothing))
push!(where_types, :($vartype <: $type))
if !isnothing(meta) && haskey(meta, VariableUnit)
push!(where_types, vartype)
else
push!(where_types, :($vartype <: $type))
end
dict[:kwargs][getname(var)] = Dict(:value => def, :type => AbstractArray{type})
end
if dict[varclass] isa Vector
Expand All @@ -166,7 +176,7 @@ end

function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing,
type::Type = Real)
type::Type = Real, meta = Dict{DataType, Expr}())
metatypes = [(:connection_type, VariableConnectType),
(:description, VariableDescription),
(:unit, VariableUnit),
Expand All @@ -186,29 +196,31 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
a::Symbol => begin
var = generate_var!(dict, a, varclass; indices, type)
update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
varclass, where_types)
varclass, where_types, meta)
return var, def, Dict()
end
Expr(:(::), a, type) => begin
type = getfield(mod, type)
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type)
parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
end
Expr(:(::), Expr(:call, a, b), type) => begin
type = getfield(mod, type)
def = _type_check!(def, a, type, varclass)
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type)
parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
end
Expr(:call, a, b) => begin
var = generate_var!(dict, a, b, varclass, mod; indices, type)
update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
varclass, where_types)
varclass, where_types, meta)
return var, def, Dict()
end
Expr(:(=), a, b) => begin
Base.remove_linenums!(b)
def, meta = parse_default(mod, b)
var, def, _ = parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; def, type)
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
if dict[varclass] isa Vector
dict[varclass][1][getname(var)][:default] = def
else
Expand All @@ -231,9 +243,9 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
return var, def, Dict()
end
Expr(:tuple, a, b) => begin
var, def, _ = parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; type)
meta = parse_metadata(mod, b)
var, def, _ = parse_variable_def!(
dict, mod, a, varclass, kwargs, where_types; type, meta)
if meta !== nothing
for (type, key) in metatypes
if (mt = get(meta, key, nothing)) !== nothing
Expand All @@ -253,7 +265,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
Expr(:ref, a, b...) => begin
indices = map(i -> UnitRange(i.args[2], i.args[end]), b)
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types;
def, indices, type)
def, indices, type, meta)
end
_ => error("$arg cannot be parsed")
end
Expand Down Expand Up @@ -611,16 +623,58 @@ function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs, where_
push!(exprs, ex)
end

function convert_units(varunits::DynamicQuantities.Quantity, value)
DynamicQuantities.ustrip(DynamicQuantities.uconvert(
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
end

function convert_units(
varunits::DynamicQuantities.Quantity, value::AbstractArray{T}) where {T}
DynamicQuantities.ustrip.(DynamicQuantities.uconvert.(
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
end

function convert_units(varunits::Unitful.FreeUnits, value)
Unitful.ustrip(varunits, value)
end

function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
Unitful.ustrip.(varunits, value)
end

function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
vv, def, metadata_with_exprs = parse_variable_def!(
dict, mod, arg, varclass, kwargs, where_types)
name = getname(vv)

varexpr = quote
$name = if $name === nothing
$setdefault($vv, $def)
else
$setdefault($vv, $name)
varexpr = if haskey(metadata_with_exprs, VariableUnit)
unit = metadata_with_exprs[VariableUnit]
quote
$name = if $name === nothing
$setdefault($vv, $def)
else
try
$setdefault($vv, $convert_units($unit, $name))
catch e
if isa(e, $(DynamicQuantities.DimensionError)) ||
isa(e, $(Unitful.DimensionError))
error("Unable to convert units for \'" * string(:($$vv)) * "\'")
elseif isa(e, MethodError)
error("No or invalid units provided for \'" * string(:($$vv)) *
"\'")
else
rethrow(e)
end
end
end
end
else
quote
$name = if $name === nothing
$setdefault($vv, $def)
else
$setdefault($vv, $name)
end
end
end

Expand Down
28 changes: 28 additions & 0 deletions test/dq_units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,31 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
maj2 = MassActionJump(γ, [S => 1], [S => -1])
@named js4 = JumpSystem([maj1, maj2], ModelingToolkit.t_nounits, [S], [β, γ])

@mtkmodel ParamTest begin
@parameters begin
a, [unit = u"m"]
end
@variables begin
b(t), [unit = u"kg"]
end
end

@named sys = ParamTest()

@named sys = ParamTest(a = 3.0u"cm")
@test ModelingToolkit.getdefault(sys.a) ≈ 0.03

@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")

@mtkmodel ArrayParamTest begin
@parameters begin
a[1:2], [unit = u"m"]
end
end

@named sys = ArrayParamTest()

@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
@test ModelingToolkit.getdefault(sys.a) ≈ [0.01, 0.03]
32 changes: 17 additions & 15 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
using ModelingToolkit, Test
using ModelingToolkit: get_connector_type, get_defaults, get_gui_metadata,
get_systems, get_ps, getdefault, getname, readable_code,
scalarize, symtype, VariableDescription, RegularConnector
scalarize, symtype, VariableDescription, RegularConnector,
get_unit
using URIs: URI
using Distributions
using DynamicQuantities, OrdinaryDiffEq
Expand Down Expand Up @@ -53,8 +54,9 @@ end
end
end

@named p = Pin(; v = π)
@test getdefault(p.v) == π
@named p = Pin(; v = π * u"V")

@test getdefault(p.v) ≈ π
@test Pin.isconnector == true

@mtkmodel OnePort begin
Expand All @@ -76,7 +78,6 @@ end

@test OnePort.isconnector == false

resistor_log = "$(@__DIR__)/logo/resistor.svg"
@mtkmodel Resistor begin
@extend v, i = oneport = OnePort()
@parameters begin
Expand Down Expand Up @@ -105,14 +106,14 @@ end
@parameters begin
C, [unit = u"F"]
end
@extend OnePort(; v = 0.0)
@extend OnePort(; v = 0.0u"V")
@icon "https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg"
@equations begin
D(v) ~ i / C
end
end

@named capacitor = Capacitor(C = 10, v = 10.0)
@named capacitor = Capacitor(C = 10u"F", v = 10.0u"V")
@test getdefault(capacitor.v) == 10.0

@mtkmodel Voltage begin
Expand All @@ -127,9 +128,9 @@ end

@mtkmodel RC begin
@structural_parameters begin
R_val = 10
C_val = 10
k_val = 10
R_val = 10u"Ω"
C_val = 10u"F"
k_val = 10u"V"
end
@components begin
resistor = Resistor(; R = R_val)
Expand All @@ -147,9 +148,9 @@ end
end
end

C_val = 20
R_val = 20
res__R = 100
C_val = 20u"F"
R_val = 20u"Ω"
res__R = 100u"Ω"
@mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R)
prob = ODEProblem(rc, [], (0, 1e9))
sol = solve(prob, Rodas5P())
Expand All @@ -160,11 +161,12 @@ resistor = getproperty(rc, :resistor; namespace = false)
@test getname(rc.resistor.R) === getname(resistor.R)
@test getname(rc.resistor.v) === getname(resistor.v)
# Test that `resistor.R` overrides `R_val` in the argument.
@test getdefault(rc.resistor.R) == res__R != R_val
@test getdefault(rc.resistor.R) * get_unit(rc.resistor.R) == res__R != R_val
# Test that `C_val` passed via argument is set as default of C.
@test getdefault(rc.capacitor.C) == C_val
@test getdefault(rc.capacitor.C) * get_unit(rc.capacitor.C) == C_val
# Test that `k`'s default value is unchanged.
@test getdefault(rc.constant.k) == RC.structure[:kwargs][:k_val][:value]
@test getdefault(rc.constant.k) * get_unit(rc.constant.k) ==
eval(RC.structure[:kwargs][:k_val][:value])
@test getdefault(rc.capacitor.v) == 0.0

@test get_gui_metadata(rc.resistor).layout == Resistor.structure[:icon] ==
Expand Down
28 changes: 28 additions & 0 deletions test/units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,31 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
maj2 = MassActionJump(γ, [S => 1], [S => -1])
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])

@mtkmodel ParamTest begin
@parameters begin
a, [unit = u"m"]
end
@variables begin
b(t), [unit = u"kg"]
end
end

@named sys = ParamTest()

@named sys = ParamTest(a = 3.0u"cm")
@test ModelingToolkit.getdefault(sys.a) ≈ 0.03

@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")

@mtkmodel ArrayParamTest begin
@parameters begin
a[1:2], [unit = u"m"]
end
end

@named sys = ArrayParamTest()

@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
@test ModelingToolkit.getdefault(sys.a) ≈ [0.01, 0.03]
Loading