diff --git a/src/nlp/consistency.jl b/src/nlp/consistency.jl index 7c1be07..bcf140b 100644 --- a/src/nlp/consistency.jl +++ b/src/nlp/consistency.jl @@ -16,20 +16,21 @@ function consistent_nlps( nlps; exclude = [jth_hess, jth_hess_coord, jth_hprod, ghjvprod], linear_api = false, + reimplemented = [], test_meta = true, test_slack = true, test_qn = true, test_derivative = true, rtol = 1.0e-8, ) - consistent_counters(nlps, linear_api = linear_api) + consistent_counters(nlps, linear_api = linear_api, reimplemented = reimplemented) test_meta && consistent_meta(nlps, rtol = rtol) consistent_functions(nlps, linear_api = linear_api, rtol = rtol, exclude = exclude) - consistent_counters(nlps, linear_api = linear_api) + consistent_counters(nlps, linear_api = linear_api, reimplemented = reimplemented) for nlp in nlps reset!(nlp) end - consistent_counters(nlps, linear_api = linear_api) + consistent_counters(nlps, linear_api = linear_api, reimplemented = reimplemented) if test_derivative for nlp in nlps @test length(gradient_check(nlp)) == 0 @@ -50,7 +51,7 @@ function consistent_nlps( linear_api = linear_api, exclude = [hess, hess_coord, hprod, jth_hess, jth_hess_coord, jth_hprod, ghjvprod] ∪ exclude, ) - consistent_counters([nlps; qnmodels], linear_api = linear_api) + consistent_counters([nlps; qnmodels], linear_api = linear_api, reimplemented = reimplemented) end if test_slack && has_inequalities(nlps[1]) @@ -61,7 +62,7 @@ function consistent_nlps( linear_api = linear_api, exclude = [jth_hess, jth_hess_coord, jth_hprod] ∪ exclude, ) - consistent_counters(slack_nlps, linear_api = linear_api) + consistent_counters(slack_nlps, linear_api = linear_api, reimplemented = reimplemented) end end @@ -79,7 +80,7 @@ function consistent_meta(nlps; rtol = 1.0e-8) end end -function consistent_counters(nlps; linear_api = false) +function consistent_counters(nlps; linear_api = false, reimplemented = String[]) N = length(nlps) V = zeros(Int, N) check_fields = filter( @@ -96,8 +97,11 @@ function consistent_counters(nlps; linear_api = false) end if linear_api V = [sum_counters(nlp) for nlp in nlps] - @test all(V .== V[1]) + @test (reimplemented != []) | all(V .== V[1]) for field in setdiff(collect(fieldnames(Counters)), check_fields) + if any(x -> occursin(x, string(field)), reimplemented) + continue + end V = [eval(field)(nlp) for nlp in nlps] @testset "Field $field" begin for i = 1:(N - 1)