
Commit f9210c7

Jdb/bagging (#286)

* WIP bagging
* up
* up
* cred blog
* fix base level for logloss
* include L2 in cred predict
* up
* up
* up
* fix cred gain
* up
* update blog
* update blog
* update blog
* update blog
* update blog
* update blog
* fix MAE
* refresh benchmarks cred_std
* benchmarks
* fix
* up docs
* gpu support for cred
* benchmark results
* up
* bump version
* cleanup assets

1 parent 2435c1c commit f9210c7

100 files changed: +1643 −232 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "EvoTrees"
 uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
 authors = ["jeremiedb <[email protected]>"]
-version = "0.17.2"
+version = "0.18.0"

 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Lines changed: 23 additions & 32 deletions

@@ -1,15 +1,15 @@
-using Revise
 using CSV
 using DataFrames
-using EvoTrees
 using StatsBase: sample, tiedrank
 using Statistics
 using Random: seed!
+using EvoTrees
+using EvoTrees: fit

 using AWS: AWSCredentials, AWSConfig, @service
 @service S3
 aws_creds = AWSCredentials(ENV["AWS_ACCESS_KEY_ID_JDB"], ENV["AWS_SECRET_ACCESS_KEY_JDB"])
-aws_config = AWSConfig(; creds = aws_creds, region = "ca-central-1")
+aws_config = AWSConfig(; creds=aws_creds, region="ca-central-1")

 path = "share/data/year/year.csv"
 raw = S3.get_object(
@@ -18,7 +18,7 @@ raw = S3.get_object(
     Dict("response-content-type" => "application/octet-stream");
     aws_config,
 )
-df = DataFrame(CSV.File(raw, header = false))
+df = DataFrame(CSV.File(raw, header=false))

 path = "share/data/year/year-train-idx.txt"
 raw = S3.get_object(
@@ -27,7 +27,7 @@ raw = S3.get_object(
     Dict("response-content-type" => "application/octet-stream");
     aws_config,
 )
-train_idx = DataFrame(CSV.File(raw, header = false))[:, 1] .+ 1
+train_idx = DataFrame(CSV.File(raw, header=false))[:, 1] .+ 1

 path = "share/data/year/year-eval-idx.txt"
 raw = S3.get_object(
@@ -36,50 +36,41 @@ raw = S3.get_object(
     Dict("response-content-type" => "application/octet-stream");
     aws_config,
 )
-eval_idx = DataFrame(CSV.File(raw, header = false))[:, 1] .+ 1
+eval_idx = DataFrame(CSV.File(raw, header=false))[:, 1] .+ 1

 X = df[:, 2:end]
 Y_raw = Float64.(df[:, 1])
 Y = (Y_raw .- mean(Y_raw)) ./ std(Y_raw)

-function percent_rank(x::AbstractVector{T}) where {T}
-    return tiedrank(x) / (length(x) + 1)
-end
-
-transform!(X, names(X) .=> percent_rank .=> names(X))
-X = collect(Matrix{Float32}(X))
-Y = Float32.(Y)
-
 x_tot, y_tot = X[1:(end-51630), :], Y[1:(end-51630)]
-x_test, y_test = X[(end-51630+1):end, :], Y[(end-51630+1):end]
-x_train, x_eval = x_tot[train_idx, :], x_tot[eval_idx, :]
+x_test, y_test = Matrix(X[(end-51630+1):end, :]), Y[(end-51630+1):end]
+x_train, x_eval = Matrix(x_tot[train_idx, :]), Matrix(x_tot[eval_idx, :])
 y_train, y_eval = y_tot[train_idx], y_tot[eval_idx]

 config = EvoTreeRegressor(
-    T = Float32,
-    nrounds = 1200,
-    loss = :linear,
-    eta = 0.1,
-    nbins = 128,
-    min_weight = 4,
-    max_depth = 7,
-    lambda = 0,
-    gamma = 0,
-    rowsample = 0.8,
-    colsample=0.8,
+    nrounds=3000,
+    loss=:cred_std,
+    metric=:mse,
+    eta=0.1,
+    nbins=32,
+    min_weight=1,
+    max_depth=7,
+    lambda=0,
+    L2=0,
+    gamma=0,
+    rowsample=0.5,
+    colsample=0.9,
+    early_stopping_rounds=50,
 )

 # @time m = fit_evotree(config; x_train, y_train, print_every_n=25);
-@time m, logger = fit_evotree(
+@time m = fit(
     config;
     x_train,
     y_train,
     x_eval,
     y_eval,
-    early_stopping_rounds = 100,
-    print_every_n = 10,
-    metric = :mse,
-    return_logger = true,
+    print_every_n=100,
 );
 p_evo = m(x_test);
 mean((p_evo .- y_test) .^ 2) * std(Y_raw)^2

benchmarks/boston.jl

Lines changed: 43 additions & 0 deletions

using EvoTrees
using MLDatasets
using DataFrames
using Statistics: mean
using CategoricalArrays
using Random

df = MLDatasets.BostonHousing().dataframe
Random.seed!(123)

train_ratio = 0.8
train_indices = randperm(nrow(df))[1:Int(round(train_ratio * nrow(df)))]

train_data = df[train_indices, :]
eval_data = df[setdiff(1:nrow(df), train_indices), :]

x_train, y_train = Matrix(train_data[:, Not(:MEDV)]), train_data[:, :MEDV]
x_eval, y_eval = Matrix(eval_data[:, Not(:MEDV)]), eval_data[:, :MEDV]

config = EvoTreeRegressor(
    loss=:mse,
    metric=:mse,
    nrounds=1,
    early_stopping_rounds=10,
    eta=0.1,
    max_depth=2,
    lambda=0.0,
    L2=0.0,
    rowsample=0.9,
    colsample=0.9)

model_mse = EvoTrees.fit(config;
    x_train, y_train,
    x_eval, y_eval,
    print_every_n=1)

pred_train = model_mse(x_train)
pred_eval = model_mse(x_eval)

mean(abs.(pred_train .- y_train))
mean(abs.(pred_eval .- y_eval))
Lines changed: 12 additions & 12 deletions

@@ -1,13 +1,13 @@
 device,nobs,nfeats,max_depth,train_evo,train_xgb,infer_evo,infer_xgb
-cpu,100000,10,6,0.339105894,0.642794662,0.045756982,0.02628743
-cpu,100000,10,11,1.279537507,1.073844805,0.085824892,0.06088243
-cpu,100000,100,6,0.825591006,1.521080299,0.068610545,0.14405631
-cpu,100000,100,11,4.875921204,3.93826447,0.11966227,0.168092254
-cpu,1000000,10,6,2.310057563,6.782713955,0.295144245,0.283113086
-cpu,1000000,10,11,5.079728577,8.015394605,0.802753618,0.60897193
-cpu,1000000,100,6,5.724386557,13.513077202,0.688739903,1.272025185
-cpu,1000000,100,11,18.003480355,21.454809233,1.247011838,1.657717943
-cpu,10000000,10,6,27.055199606,85.252937661,2.921450187,2.888122252
-cpu,10000000,10,11,52.143569851,111.505335039,6.18255143,6.202632593
-cpu,10000000,100,6,83.326695985,144.605970885,6.047807335,14.620566726
-cpu,10000000,100,11,194.955017106,182.237153757,11.50293827,17.660819455
+cpu,100000,10,6,0.330517856,0.62067627,0.045004986,0.044798794
+cpu,100000,10,11,1.337956436,1.105285991,0.086839406,0.061570974
+cpu,100000,100,6,0.828781594,1.363081129,0.106595691,0.119703591
+cpu,100000,100,11,4.941785747,3.435747012,0.122107048,0.166959767
+cpu,1000000,10,6,2.314222299,6.57856163,0.3913027,0.364978734
+cpu,1000000,10,11,5.170780341,8.45243535,0.611472906,0.612859723
+cpu,1000000,100,6,5.6716359,14.418231971,0.721386688,1.295978265
+cpu,1000000,100,11,18.040254949,18.531543442,1.360270762,1.75281179
+cpu,10000000,10,6,25.582933653,78.774728198,2.748972478,2.744420644
+cpu,10000000,10,11,51.265034576,112.372616748,6.100088688,6.337031412
+cpu,10000000,100,6,81.37971803,146.266650929,5.952960637,14.103855381
+cpu,10000000,100,11,190.794016029,189.25733363,11.847299848,18.792787023
Lines changed: 12 additions & 12 deletions

@@ -1,13 +1,13 @@
 device,nobs,nfeats,max_depth,train_evo,train_xgb,infer_evo,infer_xgb
-gpu,100000,10,6,1.270319267,0.285433309,0.010683488,0.012045765
-gpu,100000,10,11,15.401192763,1.308309359,0.011901549,0.016735511
-gpu,100000,100,6,1.656181006,0.617481512,0.034038548,0.106893283
-gpu,100000,100,11,19.647128327,3.209314346,0.038993915,0.162476102
-gpu,1000000,10,6,2.033918292,0.955982504,0.051210817,0.131093957
-gpu,1000000,10,11,23.490242119,2.714125028,0.059398531,0.144796451
-gpu,1000000,100,6,3.424353046,2.866238074,0.307580376,1.342028138
-gpu,1000000,100,11,30.398456011,7.88853843,0.352188155,1.651248449
-gpu,10000000,10,6,7.552837802,7.424535739,0.457648127,1.604205225
-gpu,10000000,10,11,39.834112089,13.51496456,0.577825194,1.763046
-gpu,10000000,100,6,21.76585393,28.380947932,3.282138258,14.587210604
-gpu,10000000,100,11,66.83786749,53.762559553,3.620799047,17.43789161
+gpu,100000,10,6,1.262735022,0.319433525,0.010457942,0.012211455
+gpu,100000,10,11,15.66936406,1.494096649,0.01337129,0.017512698
+gpu,100000,100,6,1.756429163,0.675154915,0.034648695,0.148641248
+gpu,100000,100,11,20.447355358,3.821046349,0.038901134,0.162287052
+gpu,1000000,10,6,2.215962749,1.049401344,0.05478112,0.134286529
+gpu,1000000,10,11,24.254557497,3.112351903,0.061374392,0.157115342
+gpu,1000000,100,6,3.635739525,3.228633649,0.307356514,1.361765356
+gpu,1000000,100,11,31.102936915,8.530664774,0.312302753,1.61460587
+gpu,10000000,10,6,8.384827155,7.755134961,0.457555156,1.626715379
+gpu,10000000,10,11,42.062736926,13.615783394,0.58097717,1.723417395
+gpu,10000000,100,6,21.687369289,28.868658021,3.237098864,14.680655122
+gpu,10000000,100,11,68.618695095,57.845418449,3.440632989,16.870168538

blog/cred/.gitignore

Lines changed: 1 addition & 0 deletions

# assets/

blog/cred/README.jl

Lines changed: 182 additions & 0 deletions

# # Exploring a credibility-based approach for tree-gain estimation

include(joinpath(@__DIR__, "utils.jl")); #hide

#=
> The motivation for this experiment was to explore an alternative to the gradient-based gain measure by integrating the volatility of split candidates into the identification of the best node split.
=#

#=

## Review of key gradient-based MSE characteristics

The figures below illustrate the behavior of the vanilla gradient-based approach using a mean-squared error (MSE) loss.
The two colors represent the observations belonging to the left and right children.

Key observations (a sketch illustrating them follows the figures):
- **the gain is invariant to the volatility**: the top vs bottom figures differ only by the std dev of the observations. The associated gain is identical, which is aligned with the gradient-based approach to gain: the gain matches the reduction in the MSE, which is identical regardless of the dispersion since it is driven strictly by the children's means.
- **the gain scales linearly with the number of observations**: the left vs right figures contrast different numbers of observations (100 vs 1,000) and show that the gain is directly proportional to the count.
- **the gain scales quadratically with the spread**: moving from a spread of 1.0 to 0.1 between the 2nd and 3rd rows results in a 100x drop in the gain: from 50.0 to 0.5.
=#

loss = :mse#hide
f = get_dist_figure(; loss, nobs=100, spread=1.0, sd=1.0)#hide
save(joinpath(@__DIR__, "assets", "dist-mse-1A.png"), f)#hide
f = get_dist_figure(; loss, nobs=1_000, spread=1.0, sd=1.0)#hide
save(joinpath(@__DIR__, "assets", "dist-mse-1B.png"), f)#hide
f = get_dist_figure(; loss, nobs=100, spread=1.0, sd=0.1)#hide
save(joinpath(@__DIR__, "assets", "dist-mse-2A.png"), f)#hide
f = get_dist_figure(; loss, nobs=1_000, spread=1.0, sd=0.1)#hide
save(joinpath(@__DIR__, "assets", "dist-mse-2B.png"), f);#hide
f = get_dist_figure(; loss, nobs=100, spread=0.1, sd=0.1)#hide
save(joinpath(@__DIR__, "assets", "dist-mse-3A.png"), f)#hide
f = get_dist_figure(; loss, nobs=1_000, spread=0.1, sd=0.1)#hide
save(joinpath(@__DIR__, "assets", "dist-mse-3B.png"), f);#hide

#=
| ![](assets/dist-mse-1A.png) | ![](assets/dist-mse-1B.png) |
|:----------------------:|:----------------------:|
| ![](assets/dist-mse-2A.png) | ![](assets/dist-mse-3A.png) |
=#
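
#=
As a complement to the figures, here is a minimal sketch of these three properties.
It is an illustration, not the EvoTrees internals: it assumes a node's gain
contribution is `sum(residuals)^2 / nobs`, i.e. the second-order MSE gain with unit
hessians and no regularization.
=#

## gain contribution of a node and resulting split gain (illustrative formulas)
node_gain(res) = sum(res)^2 / length(res)
split_gain(res_L, res_R) = node_gain(res_L) + node_gain(res_R) - node_gain(vcat(res_L, res_R))

res_L, res_R = fill(0.5, 100), fill(-0.5, 100)   ## spread of 1.0 between the children
split_gain(res_L, res_R)                         ## 50.0: unchanged by any noise that averages to zero within each child
split_gain(repeat(res_L, 10), repeat(res_R, 10)) ## 500.0: linear in the number of observations
split_gain(res_L ./ 10, res_R ./ 10)             ## 0.5: quadratic in the spread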

#=
## Credibility-based gains
=#

#=
The idea is for the *gain* to reflect the varying uncertainty levels of the observations associated with each tree-split candidate.
For tree-split candidates with an identical spread, the intuition is that the candidate with the lower volatility, all other things being equal, should be preferred.
The original inspiration comes from credibility theory, a foundational notion in actuarial science with direct connections to mixed-effects models and Bayesian theory.
The key concept is that the credibility associated with a set of observations is driven by the relative effect of 2 components:
- **Variance of the Hypothetical Means (VHM)**: if large differences between candidate means are expected, a greater credibility is assigned.
- **Expected Value of the Process Variance (EVPV)**: if the data-generating process of a given candidate has a large volatility, a smaller credibility is assigned.
Bühlmann credibility theory states that the optimal linear posterior estimator of a group mean is (a sketch follows):
- `Z * X̄ + (1 - Z) * μ`, where `X̄` is the group mean and `μ` the population mean.
=#
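
#=
As a sketch of the formula above, here is the Bühlmann estimator in its standard
actuarial form (an assumed illustration, not the EvoTrees implementation): with
`k = EVPV / VHM`, the credibility of a group of `n` observations is `Z = n / (n + k)`.
=#

using Statistics: mean

## credibility-weighted estimate of a group mean (standard Bühlmann form)
function buhlmann_estimate(x::AbstractVector, μ; vhm, evpv)
    n = length(x)
    Z = n / (n + evpv / vhm) ## credibility grows with n, shrinks as process variance dominates
    return Z * mean(x) + (1 - Z) * μ
end

buhlmann_estimate([1.8, 2.1, 2.4], 0.0; vhm=1.0, evpv=3.0) ## shrinks the group mean 2.1 halfway toward μ = 0.0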

#=
This approach results in a shift of perspective in how the gain is derived.
The classical gradient-based approach derives a second-order approximation of the loss curve for a tree-split candidate.
The gain corresponds to the reduction in this approximated loss obtained by taking the prediction that minimizes the quadratic curve.
The credibility-based approach is loss-function agnostic, and views the gain as the total absolute change in the credibility-adjusted predicted value.
For example, if a child has a mean residual of *2.0*, a credibility of 0.5, and 100 observations, the resulting gain is `2.0 * 0.5 * 100 = 100.0`, where `2.0 * 0.5` corresponds to the credibility-adjusted prediction.

VHM is estimated as the square of the mean spread between observed values and predictions:
- `VHM = E[X]² = mean(y - p)²`

EVPV is estimated as the variance of the observations. This value can be derived from the aggregation of the first and second moments of the individual observations (a sketch of both estimates follows):
- `EVPV = E[(X - μ)²] = E[X²] - E[X]²`
=#
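
#=
A sketch of these moment-based estimates, using hypothetical residuals standing in
for `y .- p`:
=#

using Statistics: mean

res = 0.3 .+ 0.5 .* randn(1_000) ## hypothetical residuals y .- p
E1 = mean(res)      ## first moment
E2 = mean(res .^ 2) ## second moment
VHM = E1^2          ## square of the mean spread
EVPV = E2 - E1^2    ## variance from aggregated moments: E[X²] - E[X]²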

#=
## Credibility-based losses in EvoTrees
Two credibility-based losses are supported with `EvoTreeRegressor`, each defining the credibility weight applied to a candidate (a sketch of their behavior follows):
- **cred_var**: `VHM / (VHM + EVPV)`
- **cred_std**: `sqrt(VHM) / (sqrt(VHM) + sqrt(EVPV))`
=#
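
#=
The sketch below shows how these two weights behave, reusing the VHM and EVPV
estimators described above and taking the gain as `|Z * mean(res)| * nobs` per the
earlier example. This is an illustration of the description, not the library
internals: for the same mean spread, the noisier candidate receives a smaller
credibility and hence a smaller gain.
=#

using Statistics: mean

## credibility weight of a split candidate, given its residuals
function cred(res; loss=:cred_std)
    vhm = mean(res)^2
    evpv = mean(res .^ 2) - mean(res)^2
    return loss == :cred_var ? vhm / (vhm + evpv) :
           sqrt(vhm) / (sqrt(vhm) + sqrt(evpv))
end

## gain as the absolute credibility-adjusted prediction, scaled by the observation count
cred_gain(res; kw...) = abs(cred(res; kw...) * mean(res)) * length(res)

res_lo = 0.5 .+ 0.1 .* randn(1_000) ## low-volatility candidate
res_hi = 0.5 .+ 1.0 .* randn(1_000) ## high-volatility candidate, same mean spread
cred_gain(res_lo) > cred_gain(res_hi) ## true: lower volatility yields a higher gain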

#=
Just like the gradient-based MSE error, the gain grows linearly with the number of observations, all other things being equal.
However, a smaller volatility results in an increased gain, as shown in the 2nd vs the 1st row.
=#

loss = :cred_std#hide
f = get_dist_figure(; loss, nobs=100, spread=1.0, sd=1.0)#hide
save(joinpath(@__DIR__, "assets", "dist-cred_std-1A.png"), f);#hide
f = get_dist_figure(; loss, nobs=1_000, spread=1.0, sd=1.0)#hide
save(joinpath(@__DIR__, "assets", "dist-cred_std-1B.png"), f);#hide
f = get_dist_figure(; loss, nobs=100, spread=1.0, sd=0.1)#hide
save(joinpath(@__DIR__, "assets", "dist-cred_std-2A.png"), f);#hide
f = get_dist_figure(; loss, nobs=1_000, spread=1.0, sd=0.1)#hide
save(joinpath(@__DIR__, "assets", "dist-cred_std-2B.png"), f);#hide
f = get_dist_figure(; loss, nobs=100, spread=0.1, sd=0.1)#hide
save(joinpath(@__DIR__, "assets", "dist-cred_std-3A.png"), f);#hide
f = get_dist_figure(; loss, nobs=1_000, spread=0.1, sd=0.1)#hide
save(joinpath(@__DIR__, "assets", "dist-cred_std-3B.png"), f);#hide

#=
| ![](assets/dist-cred_std-1A.png) | ![](assets/dist-cred_std-1B.png) |
|:----------------------:|:----------------------:|
| ![](assets/dist-cred_std-2A.png) | ![](assets/dist-cred_std-3A.png) |
=#

# ### Simulation grid

#=
The charts below show the associated credibility and gain for a given node-split candidate across various spreads and standard deviations.
=#

nobs = 1000
sd_list = [0.01, 0.05, 0.1, 0.2, 0.5, 1, 2, 5]
spread_list = [0.01, 0.05, 0.1, 0.2, 0.5, 1]
metric_name = "cred"#hide
f = get_cred_figureB(; metric_name, loss=:cred_std, nobs, sd_list, spread_list)#hide
save(joinpath(@__DIR__, "assets", "heatmap-$metric_name-cred_std.png"), f);#hide
metric_name = "gain"#hide
f = get_cred_figureB(; metric_name, loss=:cred_std, nobs, sd_list, spread_list)#hide
save(joinpath(@__DIR__, "assets", "heatmap-$metric_name-cred_std.png"), f);#hide

#=
| ![](assets/heatmap-cred-cred_std.png) | ![](assets/heatmap-gain-cred_std.png) |
|:----------------------:|:----------------------:|
=#

# ### Illustration of a differing split decision between `cred_std` and `mse`

#=
Although both `mse` and `cred_std` result in the same prediction, which matches the mean of the observations, the associated gains differ due to the volatility penalty.

The following illustrates a minimal scenario with 2 features, each with only 2 levels.
=#

#=
| ![](assets/dist-mse-cred-x1.png) | ![](assets/dist-mse-cred-x2.png) |
|:----------------------:|:----------------------:|
=#

#=
```julia
config = EvoTreeRegressor(loss=:mse, nrounds=1, max_depth=2)
model_mse = EvoTrees.fit(config, dtrain; target_name="y")

EvoTrees.Tree{EvoTrees.MSE, 1}
 - feat: [2, 0, 0]
 - cond_bin: UInt8[0x01, 0x00, 0x00]
 - gain: Float32[12113.845, 0.0, 0.0]
 - pred: Float32[0.0 -0.017858343 0.3391479]
 - split: Bool[1, 0, 0]
```
=#

#=
```julia
config = EvoTreeRegressor(loss=:cred_std, nrounds=1, max_depth=2)
model_std = EvoTrees.fit(config, dtrain; target_name="y")

EvoTrees.Tree{EvoTrees.CredStd, 1}
 - feat: [1, 0, 0]
 - cond_bin: UInt8[0x02, 0x00, 0x00]
 - gain: Float32[8859.706, 0.0, 0.0]
 - pred: Float32[0.0 0.07375729 -0.07375729]
 - split: Bool[1, 0, 0]
```
=#
165+
166+
#=
167+
## Benchmarks
168+
169+
From [MLBenchmarks.jl](https://github.com/Evovest/MLBenchmarks.jl).
170+
171+
| **model** | **metric** | **mse** | **cred_var** | **cred_std** |
172+
|:---------:|:----------:|:-------:|:------------:|:------------:|
173+
| boston | mse | 6.3 | 5.95 | 5.43 |
174+
| boston | gini | 0.945 | 0.947 | 0.952 |
175+
| year | mse | 74.9 | 74.6 | 74.2 |
176+
| year | gini | 0.662 | 0.664 | 0.661 |
177+
| msrank | mse | 0.55 | 0.551 | 0.549 |
178+
| msrank | ndcg | 0.511 | 0.509 | 0.51 |
179+
| yahoo | mse | 0.565 | 0.589 | 0.568 |
180+
| yahoo | ndcg | 0.795 | 0.787 | 0.794 |
181+
182+
=#
