Skip to content

Commit

Permalink
give a hint if a user uses = operator wrong
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Oct 2, 2024
1 parent b5521a5 commit 78b22ca
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/GraphPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,20 @@ macro model(model_specification)
return esc(GraphPPL.model_macro_interior(DefaultBackend, model_specification))
end

function __init__()
if isdefined(Base.Experimental, :register_error_hint)
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs
if any(x -> x <: VariableRef, argtypes)
print(io, "\nOne of the arguments to ")
printstyled(io, "`$(exc.f)`", color = :cyan)
print(io, " is of type ")
printstyled(io, "`GraphPPL.VariableRef`", color = :cyan)
print(io, ". Did you mean to create a new random variable with ")
printstyled(io, "`:=`", color = :cyan)
print(io, " operator instead?")
end
end
end
end

end # module
13 changes: 13 additions & 0 deletions test/model_macro_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2030,3 +2030,16 @@ end

@test !isnothing(GraphPPL.create_model(somemodel(a = 1, b = 2)))
end

@testitem "model should warn users against incorrect usages of `=` operator with random variables" begin
using GraphPPL, Distributions
import GraphPPL: @model

@model function somemodel()
a ~ Normal(0, 1)
t = exp(a)
y ~ Normal(0, t)
end

@test_throws "One of the arguments to `exp` is of type `GraphPPL.VariableRef`. Did you mean to create a new random variable with `:=` operator instead?" GraphPPL.create_model(somemodel())
end

0 comments on commit 78b22ca

Please sign in to comment.