
Probit node :out marginalisation not defined for q_in, which is needed for binary linear classification. #425

Open
wmkouw opened this issue Oct 30, 2024 · 6 comments
Assignees
Labels
enhancement New feature or request question Further information is requested

Comments

@wmkouw
Member

wmkouw commented Oct 30, 2024

In 5SSD0, we have a simple binary classification model:

@model function linear_classification(y,X)
    
    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))
    
    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end

results = infer(
    model       = linear_classification(),
    data        = (y = y, X = X),
    returnvars  = (θ = KeepLast()),
    predictvars = (y = KeepLast()),
    iterations  = 10,
)

Requesting a prediction will throw:

RuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@rule Probit(:out, Marginalisation) (q_in::NormalMeanVariance, meta::ProbitMeta) = begin 
    return ...
end

ReactiveMP complains that the Probit node's rule for making output predictions does not exist. Upon inspection of the source code, I find:

@rule Probit(:out, Marginalisation) (m_in::UnivariateNormalDistributionsFamily, meta::Union{ProbitMeta, Nothing}) = begin
    p = normcdf(mean(m_in) / sqrt(1 + var(m_in)))
    return Bernoulli(p)
end

So the rule does exist, but for m_in, not q_in.

Seeing as binary classification is a pretty important use case, I think we need a rule for q_in. Can I just copy-paste the m_in rule, or is there some reason we don't have a q_in rule?
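As a side note for readers of this issue: the closed form in the m_in rule above comes from the standard Gaussian identity ∫ Φ(x) N(x; m, v) dx = Φ(m / √(1 + v)). A quick numerical sanity check of that identity (standalone Python, not part of ReactiveMP; function names are my own):

```python
import math

def normcdf(x):
    # standard normal CDF Φ(x), via erfc for left-tail stability
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def gauss_probit_numeric(m, v, n=200_000, k=10.0):
    # trapezoidal integration of Φ(x) · N(x; m, v) over m ± k·sqrt(v)
    s = math.sqrt(v)
    lo, hi = m - k * s, m + k * s
    h = (hi - lo) / n
    total = 0.0
    for i in range(n + 1):
        x = lo + i * h
        w = 0.5 if i in (0, n) else 1.0
        pdf = math.exp(-(x - m) ** 2 / (2.0 * v)) / math.sqrt(2.0 * math.pi * v)
        total += w * normcdf(x) * pdf
    return total * h

m, v = 0.7, 2.3
closed_form = normcdf(m / math.sqrt(1.0 + v))   # what the m_in rule computes
numeric = gauss_probit_numeric(m, v)
assert abs(numeric - closed_form) < 1e-6
print(closed_form)
```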

@wmkouw wmkouw added enhancement New feature or request question Further information is requested labels Oct 30, 2024
@wmkouw wmkouw self-assigned this Oct 30, 2024
@albertpod
Member

I think it again has something to do with missings in your data (see ReactiveBayes/RxInfer.jl#201). The Probit node shouldn't be used in a mean-field context.

This works:

using RxInfer
using Random

Random.seed!(123)

N = 100  
D = 2   

θ_true = [1.5, -1.0]  

X = randn(N, D)  
X_vector = [vec(X[i, :]) for i in 1:N]
z = X * θ_true + randn(N) * 0.1     
y = Float64.(z .> 0)  

@model function linear_classification(y, X)
    
    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))
    
    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end

results = infer(
    model       = linear_classification(),
    data        = (y = y, X = X_vector),
    iterations  = 10,
)

println(mean(results.posteriors[:θ][end]))
println(θ_true)

@albertpod
Member

albertpod commented Oct 30, 2024

If you want to get predictions out of your model with Probit, I suggest writing a separate function for computing predictions. Besides, don't forget the trailing comma that makes returnvars and predictvars named tuples, i.e. predictvars = (y = KeepLast(),)
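For reference, a prediction function outside the model would just apply the same Gaussian-CDF identity to the posterior: given q(θ) = N(μ, Σ), the probit predictive for a new input x is p(y = 1 | x) = Φ(xᵀμ / √(1 + xᵀΣx)). A minimal standalone sketch (Python with plain lists; `predict_probit` is a hypothetical name, not RxInfer API):

```python
import math

def normcdf(x):
    # standard normal CDF Φ(x)
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def predict_probit(mu, Sigma, x):
    # p(y = 1 | x) = Φ(xᵀμ / sqrt(1 + xᵀΣx)) for a Gaussian posterior N(μ, Σ)
    mean_z = sum(xi * mi for xi, mi in zip(x, mu))                # xᵀμ
    var_z = sum(xi * sum(sij * xj for sij, xj in zip(row, x))     # xᵀΣx
                for xi, row in zip(x, Sigma))
    return normcdf(mean_z / math.sqrt(1.0 + var_z))

# illustrative numbers only (posterior mean/covariance would come from infer)
mu = [1.5, -1.0]
Sigma = [[0.1, 0.0], [0.0, 0.1]]
print(predict_probit(mu, Sigma, [1.0, 1.0]))
```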

@albertpod
Member

albertpod commented Oct 30, 2024

@bvdmitri knows better, but the issue is that introducing predictvars enforces a different factorization constraint (mean-field), which you didn't specify explicitly but which is applied behind the scenes, hence the error you're seeing makes sense. It's quite unfortunate, but it's not something that can be resolved easily.

@wouterwln
Member

You can wrap data in UnfactorizedData to avoid enforcing this MF constraint. So you'd get data = (y = UnfactorizedData(y), X = X_vector), and this should do SP message passing around that node.

I don't know what the rule for VMP would look like though. Maybe we can derive it tomorrow at the office.
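One candidate worth checking at the whiteboard (my own rough attempt, not a vetted derivation): the generic VMP message towards :out is ν(y) ∝ exp(E_q[log p(y | x)]), which for Probit gives a Bernoulli with p = e^A / (e^A + e^B), where A = E_q[log Φ(x)] and B = E_q[log Φ(−x)]. A numeric sketch comparing it against the current SP/EP form Φ(m / √(1 + v)):

```python
import math

def normcdf(x):
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def log_normcdf(x):
    # erfc-based form stays finite far into the left tail (down to x ≈ -37)
    return math.log(0.5 * math.erfc(-x / math.sqrt(2.0)))

def gauss_expect(f, m, v, n=100_000, k=8.0):
    # trapezoidal E_{N(m, v)}[f(x)] over m ± k·sqrt(v)
    s = math.sqrt(v)
    lo, hi = m - k * s, m + k * s
    h = (hi - lo) / n
    total = 0.0
    for i in range(n + 1):
        x = lo + i * h
        w = 0.5 if i in (0, n) else 1.0
        pdf = math.exp(-(x - m) ** 2 / (2.0 * v)) / math.sqrt(2.0 * math.pi * v)
        total += w * f(x) * pdf
    return total * h

m, v = 0.7, 2.3
A = gauss_expect(log_normcdf, m, v)                  # E_q[log Φ(x)]
B = gauss_expect(lambda x: log_normcdf(-x), m, v)    # E_q[log Φ(-x)]
p_vmp = math.exp(A) / (math.exp(A) + math.exp(B))    # candidate VMP message
p_sp = normcdf(m / math.sqrt(1.0 + v))               # current m_in rule
print(p_vmp, p_sp)
```

The two generally disagree for wide q_in (the VMP candidate is the geometric-mean-style average, not the moment average), but both should collapse to Φ(m) as v → 0, which is an easy consistency check.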

@wmkouw
Member Author

wmkouw commented Oct 30, 2024

Thanks for the quick response. There's no rush to this issue. I just wanted to discuss it.

Ok, so ReactiveMP is looking for a rule with q_in because predictvars enforces MeanField() even though I didn't specify MeanField(). That explains that.

But the current m_in rule is not SP, right? It's EP. Since EP is also variational, can't we just copy the m_in rule over to q_in? I know it's not technically the correct variational rule, but we can report that (maybe via the ProbitMeta).

@wouterwln
Copy link
Member

wouterwln commented Oct 31, 2024

Specifically, ReactiveMP assumes that data is always factorized out of the joint distribution. This is because if we actually supply data, we know that the posterior marginal for that datapoint is fixed and independent of any other posterior distributions. (predictvars is considered data as well, since it is one of the interfaces that is unknown at model construction time; because GraphPPL doesn't know at that point whether you're going to pass data into these nodes or pass missing, it assumes data and factorizes it out.) This is one of the implicit assumptions we always make.

For prediction, however, this might not actually be the case, and we don't always have to make this assumption. To override it, I added UnfactorizedData as a wrapper struct for any kind of data that won't automatically factorize the posterior marginal out of the rest of the joint posterior. So for prediction it might make more sense to wrap the data in UnfactorizedData, since then you may be able to send an SP message to the "data" instead of a VMP message.

As for your heuristic to send this message, there's nothing stopping us from using the EP message as a fallback, but I'd rather not have quick fixes in RMP without a proper derivation.
