Probit node :out marginalisation not defined for q_in, which is needed for binary linear classification. #425
Comments
I think it again has something to do with missings in your data (see ReactiveBayes/RxInfer.jl#201). The Probit node shouldn't be used in a mean-field context. This works:

using RxInfer
using LinearAlgebra  # provides dot
using Random

Random.seed!(123)

# Synthetic data: N points in D dimensions, labelled by the sign of a noisy linear score
N = 100
D = 2
θ_true = [1.5, -1.0]
X = randn(N, D)
X_vector = [vec(X[i, :]) for i in 1:N]
z = X * θ_true + randn(N) * 0.1
y = Float64.(z .> 0)

@model function linear_classification(y, X)
    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))
    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end

results = infer(
    model = linear_classification(),
    data = (y = y, X = X_vector),
    iterations = 10,
)

# Compare the posterior mean of θ with the true weights
println(mean(results.posteriors[:θ][end]))
println(θ_true)
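As a rough illustration of what a prediction amounts to here: if the posterior over θ is Gaussian with mean μ and covariance Σ, then for a new input x the score dot(θ, x) is Gaussian with mean dot(x, μ) and variance dot(x, Σ * x), and the probit-Gaussian integral has the closed form Φ(m / sqrt(1 + v)). The sketch below computes this by hand. `probit_predict` is a hypothetical helper, not part of RxInfer or ReactiveMP, and it assumes the Distributions package is available.

```julia
using LinearAlgebra   # dot
using Distributions   # Normal, cdf

# Hypothetical helper (not an RxInfer API): P(y = 1 | x) when θ ~ N(μ, Σ)
# and y ~ Probit(dot(θ, x)), i.e. P(y = 1 | θ, x) = Φ(dot(θ, x)).
function probit_predict(x::AbstractVector, μ::AbstractVector, Σ::AbstractMatrix)
    m = dot(x, μ)        # mean of the linear score
    v = dot(x, Σ * x)    # variance of the linear score
    return cdf(Normal(), m / sqrt(1 + v))
end

# Possible usage with the posterior from the example above
# (assumes `mean` and `cov` are defined for the posterior type):
# q_θ = results.posteriors[:θ][end]
# probit_predict([0.5, -0.2], mean(q_θ), cov(q_θ))
```

This sidesteps the missing prediction rule by working directly with the posterior moments, at the cost of doing the prediction outside the inference call.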
If you want to get predictions out of your model with
@bvdmitri knows better, but the issue is that introducing predictvars enforces a different constraint (MF), which you don't do explicitly but it occurs behind the scenes, hence the error you're seeing makes sense. It's quite unfortunate, but it's not something that can be resolved easily.
You can wrap data in
I don't know what the rule for VMP would look like though. Maybe we can derive it tomorrow at the office.
Thanks for the quick response. There's no rush to this issue. I just wanted to discuss it. Ok, so ReactiveMP is looking for a rule with
But the current
Specifically,
As for your heuristic to send this message, there's nothing stopping us from using the EP message as a fallback, but I'd rather not have quick fixes in RMP without a proper derivation.
In 5SSD0, we have a simple binary classification model:
Requesting a prediction will throw:
ReactiveMP complains that the Probit node's rule for making output predictions does not exist. Upon inspection of source code, I find:
So, the rule does exist, but for m_in, not q_in. Seeing as binary classification is a pretty important use case, I think we need to have a rule for q_in. Can I just copy-paste, or is there some reason we don't have a q_in rule?
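One possible reason, offered as a sketch rather than a statement about ReactiveMP's design: under a mean-field factorisation, the message towards the output is built from an expected log-likelihood instead of an integral against an incoming message. Writing the Probit factor as p(y | x) = Φ(x)^y Φ(-x)^(1-y) with y ∈ {0, 1}, and q(x) for the marginal on the input (notation assumed here for illustration), the variational message would be:

```latex
\begin{align}
\nu(y) &\propto \exp\!\big( \mathbb{E}_{q(x)}[\log p(y \mid x)] \big) \\
       &= \exp\!\big( y\, \mathbb{E}_{q(x)}[\log \Phi(x)]
          + (1 - y)\, \mathbb{E}_{q(x)}[\log \Phi(-x)] \big)
\end{align}
```

For a Gaussian q(x), the expectations of log Φ(x) and log Φ(-x) have no closed form and would need quadrature or some other approximation, unlike the Gaussian-message case. So a q_in rule obtained by copying the m_in rule would compute a different quantity from the variational message above.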