diff --git a/docs/src/tutorials/disturbance_modeling.md b/docs/src/tutorials/disturbance_modeling.md index 598697c71c..1fee6b88aa 100644 --- a/docs/src/tutorials/disturbance_modeling.md +++ b/docs/src/tutorials/disturbance_modeling.md @@ -203,35 +203,30 @@ using Test but we may also generate the functions ``f`` and ``g`` for state estimation: -!!! warning "Example currently disabled" - - This example is currently disabled due to compatibility issues with `generate_control_function` and analysis points in the current ModelingToolkit stack. - -```julia -inputs = [ssys.u] -disturbance_inputs = [ssys.d1, ssys.d2] -P = ssys.system_model +```@example DISTURBANCE_MODELING +P = model_with_disturbance.system_model outputs = [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w] (f_oop, f_ip), x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function( - model_with_disturbance, inputs; known_disturbance_inputs = disturbance_inputs) + model_with_disturbance, [:u]; known_disturbance_inputs = [:d1, :d2]) +inputs = ModelingToolkit.inputs(io_sys) g = ModelingToolkit.build_explicit_observed_function( - io_sys, outputs; inputs) + io_sys, outputs; inputs = inputs[1:1]) -op = ModelingToolkit.inputs(io_sys) .=> 0 +op = inputs .=> 0 x0 = ModelingToolkit.get_u0(io_sys, op) p = MTKParameters(io_sys, op) u = zeros(1) # Control input -w = zeros(length(disturbance_inputs)) # Disturbance input (known disturbances are provided as arguments) -@test f_oop(x0, u, p, t, w) == zeros(5) +w = zeros(length(inputs) - 1) # Disturbance input (known disturbances are provided as arguments) +@test f_oop(x0, u, p, 0.0, w) == zeros(5) @test g(x0, u, p, 0.0) == [0, 0, 0, 0] # Non-zero disturbance inputs should result in non-zero state derivatives. We call `sort` since we do not generally know the order of the state variables w = [1.0, 2.0] -@test sort(f_oop(x0, u, p, t, w)) == [0, 0, 0, 1, 2] +@test sort(f_oop(x0, u, p, 0.0, w)) == [0, 0, 0, 1, 2] ``` ## Input signal library diff --git a/lib/ModelingToolkitBase/src/systems/analysis_points.jl b/lib/ModelingToolkitBase/src/systems/analysis_points.jl index defd11aa91..3a6134be67 100644 --- a/lib/ModelingToolkitBase/src/systems/analysis_points.jl +++ b/lib/ModelingToolkitBase/src/systems/analysis_points.jl @@ -876,6 +876,7 @@ function generate_control_function( dist_ap_name::Union{ Nothing, Symbol, Vector{Symbol}, AnalysisPoint, Vector{AnalysisPoint}, } = nothing; + known_disturbance_inputs = nothing, system_modifier = identity, kwargs... ) @@ -885,16 +886,32 @@ function generate_control_function( sys, (du, _) = open_loop(sys, input_ap) push!(u, du) end - if dist_ap_name === nothing + + # Handle known disturbance inputs + kd = [] + if known_disturbance_inputs !== nothing + known_dist_ap = canonicalize_ap(sys, known_disturbance_inputs) + for dist_ap in known_dist_ap + sys, (du, _) = open_loop(sys, dist_ap) + push!(kd, du) + end + end + + if dist_ap_name === nothing && isempty(kd) return ModelingToolkitBase.generate_control_function(system_modifier(sys), u; kwargs...) end - dist_ap_name = canonicalize_ap(sys, dist_ap_name) d = [] - for dist_ap in dist_ap_name - sys, (du, _) = open_loop(sys, dist_ap) - push!(d, du) + if dist_ap_name !== nothing + dist_ap_name = canonicalize_ap(sys, dist_ap_name) + for dist_ap in dist_ap_name + sys, (du, _) = open_loop(sys, dist_ap) + push!(d, du) + end end - return ModelingToolkitBase.generate_control_function(system_modifier(sys), u, d; kwargs...) + return ModelingToolkitBase.generate_control_function( + system_modifier(sys), u, isempty(d) ? nothing : d; + known_disturbance_inputs = isempty(kd) ? nothing : kd, + kwargs...) end diff --git a/test/downstream/test_disturbance_model.jl b/test/downstream/test_disturbance_model.jl index 897add70c9..3cc3101aa8 100644 --- a/test/downstream/test_disturbance_model.jl +++ b/test/downstream/test_disturbance_model.jl @@ -198,6 +198,21 @@ f, x_sym, disturbance_argument = true, split = false ) +# Test symbol-based API with known_disturbance_inputs keyword +f_kd, x_sym_kd, p_sym_kd, io_sys_kd = ModelingToolkit.generate_control_function( + model_with_disturbance, [:u]; + known_disturbance_inputs = [:d1, :d2], split = false +) +@test length(ModelingToolkit.inputs(io_sys_kd)) == 1 + 2 # 1 control + 2 known disturbance +op_kd = ModelingToolkit.inputs(io_sys_kd) .=> 0 +x0_kd = ModelingToolkit.get_u0(io_sys_kd, op_kd) +p_kd = ModelingToolkit.get_p(io_sys_kd, op_kd) +u_kd = zeros(1) +w_kd = zeros(2) +@test f_kd[1](x0_kd, u_kd, p_kd, 0.0, w_kd) == zeros(length(x0_kd)) +w_kd = [1.0, 2.0] +@test sort(f_kd[1](x0_kd, u_kd, p_kd, 0.0, w_kd)) == [0, 0, 0, 1, 2] + measurement = ModelingToolkit.build_explicit_observed_function( io_sys, outputs, inputs = ModelingToolkit.inputs(io_sys)[1:1] )