-
-
Notifications
You must be signed in to change notification settings - Fork 8
Benchmark/rf use case #294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 21 commits
2380913
86f87c8
400ed74
9e6acd8
fc4f2fa
81d1ded
cb03eb3
78b95a5
a365757
43a8ffb
6b9a845
d354b2c
b5b27b1
565456b
7c9f431
c6c9333
43e7396
ec5d8fc
f26a254
a86c946
3652fe6
92b4ffc
f821e09
5903001
869aba2
ab3bedf
31b3964
a489897
0073dcc
10f3448
b81c23b
00b272f
89a72f1
52af8ed
95f0a45
5c0a447
c384529
ee3f51d
c01c531
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ inst/doc | |
/doc/ | ||
/Meta/ | ||
CRAN-SUBMISSION | ||
benchmarks/data |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,12 @@ Authors@R: | |
family = "Pfisterer", | ||
role = "ctb", | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0001-8867-762X"))) | ||
comment = c(ORCID = "0000-0001-8867-762X")), | ||
person(given = "Carson", | ||
family = "Zhang", | ||
role = "ctb", | ||
email = "[email protected]") | ||
) | ||
Description: Deep Learning library that extends the mlr3 framework by building | ||
upon the 'torch' package. It allows to conveniently build, train, | ||
and evaluate deep learning models without having to worry about low level | ||
|
@@ -64,6 +69,7 @@ Suggests: | |
viridis, | ||
visNetwork, | ||
testthat (>= 3.0.0), | ||
tfevents, | ||
torchvision (>= 0.6.0), | ||
waldo | ||
Config/testthat/edition: 3 | ||
|
@@ -80,6 +86,7 @@ Collate: | |
'CallbackSetEarlyStopping.R' | ||
'CallbackSetHistory.R' | ||
'CallbackSetProgress.R' | ||
'CallbackSetTB.R' | ||
'ContextTorch.R' | ||
'DataBackendLazy.R' | ||
'utils.R' | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#' @title TensorBoard Logging Callback
#'
#' @name mlr_callback_set.tb
#'
#' @description
#' Logs training loss, training measures, and validation measures as events.
#' To view them, use TensorBoard with `tensorflow::tensorboard()` (requires `tensorflow`) or the CLI.
#' @details
#' Logs events at most every epoch.
#'
#' @param path (`character(1)`)\cr
#'   The path to a folder where the events are logged.
#'   Point TensorBoard to this folder to view them.
#' @param log_train_loss (`logical(1)`)\cr
#'   Whether we log the training loss.
#' @family Callback
#' @export
#' @include CallbackSet.R
CallbackSetTB = R6Class("CallbackSetTB",
  inherit = CallbackSet,
  lock_objects = FALSE,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function(path, log_train_loss) {
      self$path = assert_path_for_output(path)
      # tfevents needs an existing log directory to write event files into.
      if (!dir.exists(path)) {
        dir.create(path, recursive = TRUE)
      }
      # assert_flag() enforces a single non-NA logical, matching the
      # documented `logical(1)` contract (assert_logical() would silently
      # accept logical vectors of any length).
      self$log_train_loss = assert_flag(log_train_loss)
    },
    #' @description
    #' Logs the training loss, training measures, and validation measures as TensorFlow events.
    on_epoch_end = function() {
      if (self$log_train_loss) {
        private$.log_train_loss()
      }

      # Scores are only present when the corresponding measures were
      # configured, so guard explicitly on non-empty score lists.
      if (length(self$ctx$last_scores_train) > 0) {
        # walk() is the side-effect counterpart of map(): the return value
        # of the logging helpers is discarded anyway.
        walk(names(self$ctx$measures_train), private$.log_train_score)
      }

      if (length(self$ctx$last_scores_valid) > 0) {
        walk(names(self$ctx$measures_valid), private$.log_valid_score)
      }
    }
  ),
  private = list(
    # Writes one scalar event named "<prefix><measure_name>" at the current
    # epoch into the configured log directory.
    .log_score = function(prefix, measure_name, score) {
      event_list = list(score, self$ctx$epoch)
      names(event_list) = c(paste0(prefix, measure_name), "step")

      with_logdir(self$path, {
        do.call(log_event, event_list)
      })
    },
    # Logs the last validation score of a single measure ("valid." prefix).
    .log_valid_score = function(measure_name) {
      valid_score = self$ctx$last_scores_valid[[measure_name]]
      private$.log_score("valid.", measure_name, valid_score)
    },
    # Logs the last training score of a single measure ("train." prefix).
    .log_train_score = function(measure_name) {
      train_score = self$ctx$last_scores_train[[measure_name]]
      private$.log_score("train.", measure_name, train_score)
    },
    # Logs the most recent training loss under the fixed tag "train.loss".
    .log_train_loss = function() {
      with_logdir(self$path, {
        log_event(train.loss = self$ctx$last_loss)
      })
    }
  )
)
|
||
#' @include TorchCallback.R
mlr3torch_callbacks$add("tb", function() {
  # Hyperparameters forwarded to CallbackSetTB$new() at train time; both
  # must be provided by the user, hence the "required" tag.
  tb_param_set = ps(
    path = p_uty(tags = c("train", "required")),
    log_train_loss = p_lgl(tags = c("train", "required"))
  )

  TorchCallback$new(
    id = "tb",
    label = "TensorBoard",
    man = "mlr3torch::mlr_callback_set.tb",
    callback_generator = CallbackSetTB,
    param_set = tb_param_set
  )
})
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Download a small binary-classification subset of the OpenML CC-18
# benchmark suite and cache each dataset locally as CSV.
library(here)

library(mlr3oml)
library(data.table)
library(tidytable)

# OpenML collection 99 is the curated CC-18 benchmark suite.
cc18_collection = ocl(99)

# Restrict to binary tasks without missing values.
cc18_simple = list_oml_data(
  data_id = cc18_collection$data_ids,
  number_classes = 2,
  number_missing_values = 0
)

# Keep a handful of small datasets where only the target is symbolic
# (NumberOfSymbolicFeatures == 1 means all features are numeric).
cc18_small = cc18_simple |>
  filter(NumberOfSymbolicFeatures == 1) |>
  select(data_id, name, NumberOfFeatures, NumberOfInstances) |>
  filter(name %in% c(
    "qsar-biodeg", "madelon", "kc1",
    "blood-transfusion-service-center", "climate-model-simulation-crashes"
  ))

# kc1_1067 = odt(1067)

# fwrite() does not create parent directories, so make sure the cache
# directory exists before writing (fixes a failure on a fresh checkout).
out_dir = here("data", "oml")
if (!dir.exists(out_dir)) {
  dir.create(out_dir, recursive = TRUE)
}

# Save each dataset locally as <name>_<data_id>.csv.
mlr3misc::pmap(cc18_small, function(data_id, name, NumberOfFeatures, NumberOfInstances) {
  dt = odt(data_id)$data
  dt_name = here("data", "oml", paste0(name, "_", data_id, ".csv"))
  fwrite(dt, file = dt_name)
})

# Save the metadata table itself for downstream scripts.
fwrite(cc18_small, here("data", "oml", "cc18_small.csv"))
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,105 @@ | ||||||||||
# Benchmark a tuned MLP against a random forest on the small CC-18 subset
# prepared by the download script.
library(mlr3)
library(data.table)
library(mlr3torch)
library(paradox)

library(here)

# define the tasks
# NOTE(review): building tasks from the cached CSVs ignores the OpenML
# resampling; that is deliberate — we use our own resamplings below.
cc18_small = fread(here("data", "oml", "cc18_small.csv"))
cc18_small_datasets = mlr3misc::pmap(cc18_small, function(data_id, name, NumberOfFeatures, NumberOfInstances) {
  dt_name = here("data", "oml", paste0(name, "_", data_id, ".csv"))
  fread(dt_name)
})

# TODO: determine whether we can use OML tasks "directly"
# (didn't at first because they come with resamplings and we want our own)
kc1_1067 = as_task_classif(cc18_small_datasets[[1]], target = "defects")
blood_1464 = as_task_classif(cc18_small_datasets[[2]], target = "Class")

tasks = list(kc1_1067, blood_1464)

# define the learners
mlp = lrn("classif.mlp",
  activation = nn_relu,
  # Tune over explicit architectures. The previous
  # to_tune(c(10, 20, c(10, 10), ...)) was a bug: c() flattens nested
  # vectors, destroying the multi-layer candidates. A p_fct over a list
  # keeps each architecture intact (paradox adds the back-trafo itself).
  neurons = to_tune(p_fct(list(
    n10 = 10L, n20 = 20L,
    n10_10 = c(10L, 10L), n10_20 = c(10L, 20L),
    n20_10 = c(20L, 10L), n20_20 = c(20L, 20L)
  ))),
  # to_tune() only takes lower/upper positionally, so to_tune(16, 32, 64)
  # was invalid; a discrete set of batch sizes needs a factor parameter.
  batch_size = to_tune(p_fct(c(16L, 32L, 64L))),
  p = to_tune(0.1, 0.9),
  epochs = to_tune(upper = 1000L, internal = TRUE),
  validate = 0.3,
  measures_valid = msr("classif.acc"),
  patience = 10,
  device = "cpu"
)

# define an AutoTuner that wraps the classif.mlp
at = auto_tuner(
  learner = mlp,
  tuner = tnr("grid_search"),
  resampling = rsmp("cv"),
  # fixed typo: "clasif.acc" is not a registered measure key
  measure = msr("classif.acc"),
  term_evals = 10
)

future::plan("multisession")

# fixed misplaced parenthesis: the resampling was previously passed inside
# the learners list, so benchmark_grid() never received a resampling at all.
design = benchmark_grid(
  tasks,
  learners = list(at, lrn("classif.ranger")),
  resamplings = rsmp("cv", folds = 10)
)

bmr = benchmark(design)

bmrdt = as.data.table(bmr)

fwrite(bmrdt, here("R", "rf_Use_case", "results", "bmrdt.csv"))

# Planning notes kept from the original draft:
# - optimization strategy: grid search
# - search space: neurons, batch size, dropout rate, epochs
# - tuning measure: something standard (e.g. accuracy)
# - resampling: k-fold cross validation
# - set a number of evaluations for the tuner

# TODO: set up the tuning space for the neurons and layers

# layers_search_space <- 1:5
# neurons_search_space <- seq(10, 50, by = 10)

# generate_permutations <- function(layers_search_space, neurons_search_space) {
#   result <- list()
#   for (layers in layers_search_space) {
#     # Generate all permutations with replacement
#     perms <- expand.grid(replicate(layers, list(neurons_search_space), simplify = FALSE))
#     # Convert each row to a vector and add to the result
#     result <- c(result, apply(perms, 1, as.numeric))
#   }
#   return(result)
# }

# permutations <- generate_permutations(layers_search_space, neurons_search_space)

# head(permutations)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Load and display the benchmark results written by the use-case script.
library(data.table)
library(mlr3)

library(here)

# Same path the benchmark script writes to.
results_path = here("R", "rf_Use_case", "results", "bmrdt.csv")
bmrdt = fread(results_path)

bmrdt
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can also add this to your
.Rprofile