Skip to content

Commit

Permalink
convert to Float32 in predict
Browse files Browse the repository at this point in the history
  • Loading branch information
tiemvanderdeure committed Oct 31, 2024
1 parent 68b401a commit 3315da4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/classifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function MLJModelInterface.predict(
)
chain, levels, ordinal_mappings, _ = fitresult
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) # what if Xnew is a matrix
X = reformat(Xnew)
X = _f32(reformat(Xnew), 0)
probs = vcat([chain(tomat(X[:, i]))' for i in 1:size(X, 2)]...)
return MLJModelInterface.UnivariateFinite(levels, probs)
end
Expand Down Expand Up @@ -69,7 +69,7 @@ function MLJModelInterface.predict(
)
chain, levels, ordinal_mappings, _ = fitresult
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
X = reformat(Xnew)
X = _f32(reformat(Xnew), 0)
probs = vec(chain(X))
return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
end
Expand Down
4 changes: 2 additions & 2 deletions src/regressor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function MLJModelInterface.predict(model::NeuralNetworkRegressor,
Xnew)
chain, ordinal_mappings = fitresult[1], fitresult[3]
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
Xnew_ = reformat(Xnew)
Xnew_ = _f32(reformat(Xnew), 0)
return [chain(values.(tomat(Xnew_[:, i])))[1]
for i in 1:size(Xnew_, 2)]
end
Expand Down Expand Up @@ -74,7 +74,7 @@ function MLJModelInterface.predict(model::MultitargetNeuralNetworkRegressor,
fitresult, Xnew)
chain, target_column_names, ordinal_mappings, _ = fitresult
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
X = reformat(Xnew)
X = _f32(reformat(Xnew), 0)
ypred = [chain(values.(tomat(X[:, i])))
for i in 1:size(X, 2)]
output =
Expand Down

0 comments on commit 3315da4

Please sign in to comment.