Skip to content

Commit 6a0284e

Browse files
authored
fix #274 (#275)
* fix #274 fix GPU preds * fix #274 fix GPU preds * fix #274 fix GPU preds * fix #274 fix GPU preds
1 parent a16368b commit 6a0284e

File tree

7 files changed

+48
-14
lines changed

7 files changed

+48
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EvoTrees"
22
uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
33
authors = ["jeremiedb <[email protected]>"]
4-
version = "0.16.8"
4+
version = "0.16.9"
55

66
[deps]
77
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"

ext/EvoTreesCUDAExt/init.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function EvoTrees.init_core(params::EvoTrees.EvoTypes{L}, ::Type{<:EvoTrees.GPU}
2323
elseif L == EvoTrees.MLogLoss
2424
if eltype(y_train) <: EvoTrees.CategoricalValue
2525
target_levels = EvoTrees.CategoricalArrays.levels(y_train)
26-
target_isordered = isordered(y_train)
26+
target_isordered = EvoTrees.isordered(y_train)
2727
y = UInt32.(EvoTrees.CategoricalArrays.levelcode.(y_train))
2828
elseif eltype(y_train) <: Integer || eltype(y_train) <: Bool || eltype(y_train) <: String || eltype(y_train) <: Char
2929
target_levels = sort(unique(y_train))

ext/EvoTreesCUDAExt/predict.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ function EvoTrees.predict!(
150150
end
151151

152152
# prediction for EvoTree model
153-
function predict(
154-
m::EvoTree{L,K},
153+
function EvoTrees._predict(
154+
m::EvoTrees.EvoTree{L,K},
155155
data,
156156
::Type{<:EvoTrees.GPU};
157157
ntree_limit=length(m.trees)) where {L,K}
158158

159-
Tables.istable(data) ? data = Tables.columntable(data) : nothing
159+
EvoTrees.Tables.istable(data) ? data = EvoTrees.Tables.columntable(data) : nothing
160160
ntrees = length(m.trees)
161161
ntree_limit > ntrees && error("ntree_limit is larger than number of trees $ntrees.")
162162
x_bin = CuArray(EvoTrees.binarize(data; fnames=m.info[:fnames], edges=m.info[:edges]))

src/models.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ function EvoTreeRegressor(; kwargs...)
7070
:rng => 123,
7171
)
7272

73+
args_ignored = setdiff(keys(kwargs), keys(args))
74+
length(args_ignored) > 0 &&
75+
@info "The following kwargs are not supported and will be ignored: $(args_ignored)."
76+
7377
args_override = intersect(keys(args), keys(kwargs))
7478
for arg in args_override
7579
args[arg] = kwargs[arg]
@@ -163,6 +167,10 @@ function EvoTreeCount(; kwargs...)
163167
:rng => 123,
164168
)
165169

170+
args_ignored = setdiff(keys(kwargs), keys(args))
171+
length(args_ignored) > 0 &&
172+
@info "The following kwargs are not supported and will be ignored: $(args_ignored)."
173+
166174
args_override = intersect(keys(args), keys(kwargs))
167175
for arg in args_override
168176
args[arg] = kwargs[arg]
@@ -231,6 +239,10 @@ function EvoTreeClassifier(; kwargs...)
231239
:rng => 123,
232240
)
233241

242+
args_ignored = setdiff(keys(kwargs), keys(args))
243+
length(args_ignored) > 0 &&
244+
@info "The following kwargs are not supported and will be ignored: $(args_ignored)."
245+
234246
args_override = intersect(keys(args), keys(kwargs))
235247
for arg in args_override
236248
args[arg] = kwargs[arg]
@@ -301,6 +313,10 @@ function EvoTreeMLE(; kwargs...)
301313
:rng => 123,
302314
)
303315

316+
args_ignored = setdiff(keys(kwargs), keys(args))
317+
length(args_ignored) > 0 &&
318+
@info "The following kwargs are not supported and will be ignored: $(args_ignored)."
319+
304320
args_override = intersect(keys(args), keys(kwargs))
305321
for arg in args_override
306322
args[arg] = kwargs[arg]
@@ -387,6 +403,10 @@ function EvoTreeGaussian(; kwargs...)
387403
:rng => 123,
388404
)
389405

406+
args_ignored = setdiff(keys(kwargs), keys(args))
407+
length(args_ignored) > 0 &&
408+
@info "The following kwargs are not supported and will be ignored: $(args_ignored)."
409+
390410
args_override = intersect(keys(args), keys(kwargs))
391411
for arg in args_override
392412
args[arg] = kwargs[arg]

src/predict.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,23 @@ function predict!(pred::Matrix{T}, tree::Tree{L,K}, x_bin::Matrix{UInt8}, featty
7474
return nothing
7575
end
7676

77+
7778
"""
78-
predict(model::EvoTree, X::AbstractMatrix; ntree_limit = length(model.trees))
79+
predict(m::EvoTree, data; ntree_limit=length(m.trees), device=:cpu)
7980
8081
Predictions from an EvoTree model - sums the predictions from all trees composing the model.
8182
Use `ntree_limit=N` to only predict with the first `N` trees.
8283
"""
83-
function predict(
84+
function predict(m::EvoTree, data; ntree_limit=length(m.trees), device=:cpu)
85+
@assert Symbol(device) [:cpu, :gpu]
86+
_device = Symbol(device) == :cpu ? CPU : GPU
87+
_predict(m, data, _device; ntree_limit)
88+
end
89+
90+
function _predict(
8491
m::EvoTree{L,K},
8592
data,
86-
::Type{<:Device}=CPU;
93+
::Type{<:CPU};
8794
ntree_limit=length(m.trees)) where {L,K}
8895

8996
Tables.istable(data) ? data = Tables.columntable(data) : nothing

src/structs.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,10 @@ struct EvoTree{L,K}
8787
trees::Vector{Tree{L,K}}
8888
info::Dict
8989
end
90-
# (m::EvoTree)(data, device::Type{D}=CPU; ntree_limit=length(m.trees)) where {D<:Device} =
91-
# predict(m, data, device; ntree_limit)
92-
function (m::EvoTree)(data; ntree_limit=length(m.trees), device="cpu")
93-
@assert string(device) ["cpu", "gpu"]
94-
_device = string(device) == "cpu" ? CPU : GPU
95-
return predict(m, data, _device; ntree_limit)
90+
function (m::EvoTree)(data; ntree_limit=length(m.trees), device=:cpu)
91+
@assert Symbol(device) [:cpu, :gpu]
92+
_device = Symbol(device) == :cpu ? CPU : GPU
93+
return _predict(m, data, _device; ntree_limit)
9694
end
9795

9896
_get_struct_loss(::EvoTree{L,K}) where {L,K} = L

test/issue-274.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using CUDA
2+
using EvoTrees
3+
using CategoricalArrays
4+
5+
x_train = x=rand(10, 2)
6+
y_train = categorical(rand("abc", 10), ordered=true)
7+
config = EvoTreeClassifier(; device=:gpu, L2=123)
8+
m = fit_evotree(config; x_train, y_train, device=:gpu)
9+
yhat = m(x_train; device=:gpu)

0 commit comments

Comments
 (0)