Skip to content

Commit e65be37

Browse files
authored
Jdb/cat compat (#315)
* categorical init * test
1 parent f163c77 commit e65be37

File tree

4 files changed

+13
-12
lines changed

4 files changed

+13
-12
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EvoTrees"
22
uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
33
authors = ["jeremiedb <[email protected]>"]
4-
version = "0.18.1"
4+
version = "0.18.2"
55

66
[deps]
77
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
@@ -30,7 +30,7 @@ CUDA = "3.0, 4.0, 5.0"
3030
CategoricalArrays = "1"
3131
Distributions = "0.24, 0.25"
3232
KernelAbstractions = "0.9"
33-
MLJModelInterface = "0.3, 0.4, 1.0"
33+
MLJModelInterface = "1.2.1"
3434
NetworkLayout = "0.4"
3535
Random = "1"
3636
RecipesBase = "1"

benchmarks/softmax.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
using Revise
21
using Statistics
32
using StatsBase: sample
43
using XGBoost
54
using EvoTrees
5+
using EvoTrees: fit
66
using BenchmarkTools
77
import CUDA
88

@@ -54,22 +54,23 @@ params_evo = EvoTreeClassifier(;
5454
min_weight=1.0,
5555
rowsample=0.5,
5656
colsample=0.5,
57-
nbins=64)
57+
nbins=64
58+
)
5859

5960
@info "EvoTrees CPU"
6061
params_evo.device = :cpu
6162
@info "train - eval"
62-
@time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, print_every_n=100);
63-
@time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, print_every_n=100);
63+
@time m_evo = fit(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, print_every_n=100);
64+
@time m_evo = fit(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, print_every_n=100);
6465
@info "evotrees predict CPU:"
6566
@time pred_evo = m_evo(x_train);
6667
@btime m_evo($x_train);
6768

6869
@info "evotrees train GPU:"
6970
params_evo.device = :gpu
7071
@info "train - eval"
71-
@time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, print_every_n=100);
72-
@time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, print_every_n=100);
72+
@time m_evo = fit(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, print_every_n=100);
73+
@time m_evo = fit(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, print_every_n=100);
7374
# @btime fit_evotree($params_evo; x_train=$x_train, y_train=$y_train, x_eval=$x_train, y_eval=$y_train, metric=metric_evo);
7475
@info "evotrees predict GPU:"
7576
@time pred_evo = m_evo(x_train; device=:gpu);

ext/EvoTreesCUDAExt/init.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ function EvoTrees.init_core(params::EvoTrees.EvoTypes, ::Type{<:EvoTrees.GPU}, d
2727
target_isordered = EvoTrees.isordered(y_train)
2828
y = UInt32.(EvoTrees.CategoricalArrays.levelcode.(y_train))
2929
elseif eltype(y_train) <: Integer || eltype(y_train) <: Bool || eltype(y_train) <: String || eltype(y_train) <: Char
30-
target_levels = sort(unique(y_train))
31-
yc = EvoTrees.CategoricalVector(y_train, levels=target_levels)
30+
yc = EvoTrees.CategoricalArrays.categorical(y_train, levels=sort(unique(y_train)), ordered=false)
31+
target_levels = EvoTrees.CategoricalArrays.levels(yc)
3232
y = UInt32.(EvoTrees.CategoricalArrays.levelcode.(yc))
3333
else
3434
@error "Invalid target eltype: $(eltype(y_train))"

src/init.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ function init_core(params::EvoTypes, ::Type{CPU}, data, feature_names, y_train,
3030
target_isordered = isordered(y_train)
3131
y = UInt32.(CategoricalArrays.levelcode.(y_train))
3232
elseif eltype(y_train) <: Integer || eltype(y_train) <: Bool || eltype(y_train) <: String || eltype(y_train) <: Char
33-
target_levels = sort(unique(y_train))
34-
yc = CategoricalVector(y_train, levels=target_levels)
33+
yc = categorical(y_train, levels=sort(unique(y_train)), ordered=false)
34+
target_levels = CategoricalArrays.levels(yc)
3535
y = UInt32.(CategoricalArrays.levelcode.(yc))
3636
else
3737
@error "Invalid target eltype: $(eltype(y_train))"

0 commit comments

Comments
 (0)