Skip to content

Commit f4fb946

Browse files
authored
Jdb/api (#279)
* minor fit init simplification * fix tutorials reproducibility
1 parent a40c9dd commit f4fb946

File tree

6 files changed

+10
-15
lines changed

6 files changed

+10
-15
lines changed

docs/src/tutorials/classification-iris.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ config = EvoTreeClassifier(
5050
rowsample=0.8,
5151
colsample=0.8)
5252

53-
model = fit(config, dtrain;
53+
model = EvoTrees.fit(config, dtrain;
5454
target_name,
5555
feature_names,
5656
deval,

docs/src/tutorials/examples-API.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Minimal example to fit a noisy sinus wave.
1010

1111
```julia
1212
using EvoTrees
13+
using EvoTrees: fit
1314
using EvoTrees: sigmoid, logit
1415
using StatsBase: sample
1516

@@ -48,7 +49,7 @@ config = EvoTreeRegressor(
4849
max_depth=6, min_weight=1.0,
4950
rowsample=0.5, colsample=1.0)
5051

51-
model = fit_evotree(config; x_train, y_train, x_eval, y_eval, print_every_n=25)
52+
model = fit(config; x_train, y_train, x_eval, y_eval, print_every_n=25)
5253
pred_eval_logistic = model(x_eval)
5354

5455
# L1

docs/src/tutorials/logistic-regression-titanic.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ To begin, we will load the required packages and the dataset:
1010
using EvoTrees
1111
using MLDatasets
1212
using DataFrames
13-
using Statistics: mean
13+
using Statistics: mean, median
1414
using CategoricalArrays
1515
using Random
1616

@@ -71,7 +71,7 @@ config = EvoTreeRegressor(
7171
rowsample=0.5,
7272
colsample=0.9)
7373

74-
model = fit(
74+
model = EvoTrees.fit(
7575
config, dtrain;
7676
deval,
7777
target_name,

docs/src/tutorials/ranking-LTRC.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ config = EvoTreeRegressor(
7676
colsample=0.9,
7777
)
7878

79-
m_mse, logger_mse = fit(
79+
m_mse, logger_mse = EvoTrees.fit(
8080
config;
8181
x_train=x_train,
8282
y_train=y_train,
@@ -146,7 +146,7 @@ config = EvoTreeRegressor(
146146
colsample=0.9,
147147
)
148148

149-
m_logloss, logger_logloss = fit_evotree(
149+
m_logloss, logger_logloss = EvoTrees.fit(
150150
config;
151151
x_train=x_train,
152152
y_train=y_train,

docs/src/tutorials/regression-boston.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ config = EvoTreeRegressor(
4949
rowsample=0.9,
5050
colsample=0.9)
5151

52-
model = fit(config;
52+
model = EvoTrees.fit(config;
5353
x_train, y_train,
5454
x_eval, y_eval,
5555
print_every_n=10)

src/fit.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,8 @@ function fit(
358358
_device = params.device == :gpu ? GPU : CPU
359359
m, cache = init(params, dtrain, _device; target_name, feature_names, weight_name, offset_name)
360360

361-
# initialize callback and logger if tracking eval data
362-
metric = params.metric
363-
logging_flag = !isnothing(deval)
364-
any_flag = !isnothing(deval)
365-
if !logging_flag && any_flag
366-
@warn "To track eval metric in logger, `deval` must be provided."
367-
end
368-
if logging_flag
361+
# initialize callback and logger if deval is provided
362+
if !isnothing(deval)
369363
deval = Tables.columntable(deval)
370364
cb = CallBack(params, m, deval, _device; target_name, weight_name, offset_name)
371365
logger = init_logger(; metric=params.metric, maximise=is_maximise(cb.feval), params.early_stopping_rounds)

0 commit comments

Comments
 (0)