
feat/pipeop-transformer-layer #388


Draft
wants to merge 67 commits into main from feat/pipeop-transformer-layer
Changes from all commits

Commits (67)
50bc8e9
double quotes
cxzhang4 Apr 15, 2025
8aa4470
style
cxzhang4 Apr 15, 2025
4b5fafe
copied in old attic code to test file, still need to try
cxzhang4 Apr 15, 2025
258ea42
idrk
cxzhang4 Apr 21, 2025
a8a8787
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 Apr 21, 2025
b81e9dd
changed d_token in test
cxzhang4 Apr 21, 2025
e7288f8
small cleanup
cxzhang4 Apr 21, 2025
da059e2
cleanup
cxzhang4 Apr 21, 2025
866a7ec
factored out d_token
cxzhang4 Apr 21, 2025
e6e67f6
idk
cxzhang4 Apr 21, 2025
9d37216
test passes for now
cxzhang4 Apr 22, 2025
31ef199
intermrediate docs
cxzhang4 Apr 22, 2025
ebcb3c0
TODO: implement custom checks for parameters that are nn_modules or n…
cxzhang4 Apr 22, 2025
946dba0
more TODOs
cxzhang4 Apr 22, 2025
00c91eb
docs
cxzhang4 Apr 22, 2025
da21ff2
change title of nn_ft_transformer_layer module
cxzhang4 Apr 22, 2025
7d65f09
removed is_first_layer param
cxzhang4 Apr 22, 2025
3872ce0
some comments
cxzhang4 Apr 22, 2025
ce4809b
a comment
cxzhang4 Apr 23, 2025
41724f0
added back is_first_layer param
cxzhang4 Apr 23, 2025
1c2ee1e
added back comment on prenormalization condition
cxzhang4 Apr 23, 2025
6b4c34a
comment on parameters
cxzhang4 Apr 24, 2025
24645c4
Merge branch 'main' into feat/pipeop-transformer-layer
sebffischer Apr 24, 2025
8b671ee
some changes
sebffischer Apr 24, 2025
f812cb5
some notes
sebffischer Apr 24, 2025
14c98c5
some more changes
sebffischer Apr 25, 2025
8d5f641
...
sebffischer Apr 25, 2025
c479d9c
factored out last_layer_query_idx from layer
cxzhang4 Apr 25, 2025
0c020cf
query_idx should be -1L (last dim) for last transformer layer
cxzhang4 Apr 25, 2025
0f17330
deleted file with old name (Layer, not Block)
cxzhang4 Apr 25, 2025
916648d
formatting
cxzhang4 Apr 25, 2025
40d5e44
check_nn_module_generator
cxzhang4 Apr 25, 2025
44f003b
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 Apr 25, 2025
726e8af
got rid of some defaults
cxzhang4 Apr 26, 2025
5b12f99
idk
cxzhang4 Apr 27, 2025
c69af36
reduce blocks in learner when there are multipleg
cxzhang4 Apr 27, 2025
8f80d18
delete TODO
cxzhang4 Apr 27, 2025
d0ec4fc
fix test
cxzhang4 Apr 27, 2025
f6a6326
some comments
cxzhang4 Apr 27, 2025
a11b217
small changes
Apr 28, 2025
0416ead
Merge branch 'main' into feat/pipeop-transformer-layer
Apr 28, 2025
3e609bd
added custom error messages
cxzhang4 Apr 28, 2025
852d69c
x_residual
cxzhang4 Apr 29, 2025
ec4e8ab
set block dependent default vals
cxzhang4 Apr 29, 2025
d5c3cc4
some comments
cxzhang4 Apr 29, 2025
e73100a
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 Apr 29, 2025
92ced45
intermediate
cxzhang4 May 1, 2025
7a23834
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 May 1, 2025
31571ff
remove first_prenormalization
cxzhang4 May 1, 2025
b5172a6
address TODOs
cxzhang4 May 1, 2025
17d25cb
looks ok 2 me, still has compression
cxzhang4 May 1, 2025
d50e1e0
removed kv compression
cxzhang4 May 1, 2025
c9ad1e9
update docs for learner
cxzhang4 May 1, 2025
5bf9ada
added block-dependent defaults, removed required tags from learner pa…
cxzhang4 May 1, 2025
7707922
added activation-dependent ffn_d_hidden
cxzhang4 May 1, 2025
b62f95c
a comment
cxzhang4 May 1, 2025
abb2094
man
cxzhang4 May 2, 2025
173b72c
defaults look ok
cxzhang4 May 5, 2025
85b1d71
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 May 8, 2025
5ecef42
removed comment
cxzhang4 May 9, 2025
b977f56
removed test for mixed input:
cxzhang4 May 9, 2025
552c6d1
removed browser())
cxzhang4 May 9, 2025
b98ab8f
removed print statemetn from test:
cxzhang4 May 9, 2025
ac110dd
add test for only categorical input
cxzhang4 May 9, 2025
358ef1c
simple script for benchmarking ft transformer
cxzhang4 May 11, 2025
0ac5b92
init -> default, shouldn't have duplicated anyway
cxzhang4 May 12, 2025
c2449c9
added query_idx = NULL in the constructor in the module generator
cxzhang4 May 12, 2025
2 changes: 2 additions & 0 deletions DESCRIPTION
@@ -93,6 +93,7 @@ Collate:
'DataBackendLazy.R'
'utils.R'
'DataDescriptor.R'
'LearnerFTTransformer.R'
'LearnerTorch.R'
'LearnerTorchFeatureless.R'
'LearnerTorchImage.R'
@@ -118,6 +119,7 @@ Collate:
'PipeOpTorchConvTranspose.R'
'PipeOpTorchDropout.R'
'PipeOpTorchFTCLS.R'
'PipeOpTorchFTTransformerBlock.R'
'PipeOpTorchFn.R'
'PipeOpTorchHead.R'
'PipeOpTorchIdentity.R'
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -73,6 +73,7 @@ export(ContextTorch)
export(DataBackendLazy)
export(DataDescriptor)
export(LearnerTorch)
export(LearnerTorchFTTransformer)
export(LearnerTorchFeatureless)
export(LearnerTorchImage)
export(LearnerTorchMLP)
@@ -105,6 +106,7 @@ export(PipeOpTorchConvTranspose3D)
export(PipeOpTorchDropout)
export(PipeOpTorchELU)
export(PipeOpTorchFTCLS)
export(PipeOpTorchFTTransformerBlock)
export(PipeOpTorchFlatten)
export(PipeOpTorchFn)
export(PipeOpTorchGELU)
@@ -187,6 +189,7 @@ export(model_descriptor_to_module)
export(model_descriptor_union)
export(nn)
export(nn_ft_cls)
export(nn_ft_transformer_block)
export(nn_geglu)
export(nn_graph)
export(nn_merge_cat)
220 changes: 220 additions & 0 deletions R/LearnerFTTransformer.R
@@ -0,0 +1,220 @@

#' @title FT-Transformer
#' @templateVar name ft_transformer
#' @templateVar task_types classif, regr
#' @templateVar param_vals n_blocks = 2, d_token = 32, attention_dropout = 0.2, ffn_dropout = 0.1
#' @template params_learner
#' @template learner
#' @template learner_example
#'
#' @description
#' Feature-Tokenizer Transformer for tabular data that can either work on [`lazy_tensor`] inputs
#' or on standard tabular features.
#'
#' This implementation differs from the paper in two ways: it does not use key-value compression in the attention layers, and it applies no prenormalization in the first block.
#'
#' @section Parameters:
#' Parameters from [`LearnerTorch`] and [`PipeOpTorchFTTransformerBlock`], as well as:
#' * `n_blocks` :: `integer(1)`\cr
#' The number of transformer blocks.
#' * `d_token` :: `integer(1)`\cr
#' The dimension of the embedding.
#' * `cardinalities` :: `integer()`\cr
#' The number of categories for each feature.
#' * `init_token` :: `character(1)`\cr
#' The initialization method for the embedding weights. Either "uniform" or "normal".
#' * `ingress_tokens` :: named `list()` of [`TorchIngressToken`]s or `NULL`\cr
#' A named list of ingress tokens whose names must be a subset of `"num.input"` and `"categ.input"`.
#' Only needs to be specified when the task contains `lazy_tensor` features.
#'
#' @references
#' `r format_bib("gorishniy2021revisiting")`
#' @export
LearnerTorchFTTransformer = R6Class("LearnerTorchFTTransformer",
inherit = LearnerTorch,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
private$.block = PipeOpTorchFTTransformerBlock$new()
Member comment: check that the default arguments are as defined in the paper


check_ingress_tokens = crate(function(ingress_tokens, task) {
if (is.null(ingress_tokens)) {
return(TRUE)
}
msg = check_list(ingress_tokens, types = "TorchIngressToken", min.len = 1L, names = "unique")
if (!isTRUE(msg)) {
return(msg)
}
check_subset(names(ingress_tokens), c("num.input", "categ.input"))
})

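# learner-level parameters; the transformer block's parameters are pulled in
# via private$.block$param_set below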
private$.param_set_base = ps(
n_blocks = p_int(lower = 0, default = 3, tags = "train"),
d_token = p_int(lower = 1L, default = 192L, tags = "train"),
cardinalities = p_int(lower = 1L, tags = "train"),
init_token = p_fct(init = "uniform", levels = c("uniform", "normal"), tags = "train"),
ingress_tokens = p_uty(tags = "train", custom_check = check_ingress_tokens)
)
param_set = alist(private$.block$param_set, private$.param_set_base)

super$initialize(
task_type = task_type,
id = paste0(task_type, ".ft_transformer"),
label = "FT-Transformer",
param_set = param_set,
optimizer = optimizer,
callbacks = callbacks,
loss = loss,
man = "mlr3torch::mlr_learners.ft_transformer",
feature_types = c("numeric", "integer", "logical", "factor", "ordered", "lazy_tensor"),
# The CLS token is resized dynamically to match the input's batch size,
# so the network cannot be jit-traced.
jittable = FALSE
)
}
),
private = list(
.block = NULL,
.ingress_tokens = function(task, param_vals) {
if ("lazy_tensor" %in% task$feature_types$type) {
if (!all(task$feature_types$type == "lazy_tensor")) {
stopf("Learner '%s' received an input task '%s' that is mixing lazy_tensors with other feature types.", self$id, task$id) # nolint
}
if (task$n_features > 2L) {
stopf("Learner '%s' received an input task '%s' that has more than two lazy tensors.", self$id, task$id) # nolint
}
if (is.null(param_vals$ingress_tokens)) {
stopf("Learner '%s' received an input task '%s' with lazy tensors, but no parameter 'ingress_tokens' was specified.", self$id, task$id) # nolint
}

ingress_tokens = param_vals$ingress_tokens
row = task$head(1L)
for (i in seq_along(ingress_tokens)) {
feat = ingress_tokens[[i]]$features(task)
if (length(feat) != 1L) {
stopf("Learner '%s' received an input task '%s' with lazy tensors, but the ingress token '%s' does not select exactly one feature.", self$id, task$id, names(ingress_tokens)[[i]]) # nolint
}
if (is.null(ingress_tokens[[i]]$shape)) {
ingress_tokens[[i]]$shape = lazy_shape(row[[feat]])
}
if (is.null(ingress_tokens[[i]]$shape)) {
stopf("Learner '%s' received an input task '%s' with lazy tensors, but neither the ingress token for '%s', nor the 'lazy_tensor' specify the shape, which makes it impossible to build the network.", self$id, task$id, feat) # nolint
}
}
return(ingress_tokens)
}
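# standard tabular path: build default ingress tokens for the numeric and
# categorical features present in the task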
num_features = n_num_features(task)
categ_features = n_categ_features(task)
output = list()
if (num_features > 0L) {
output$num.input = ingress_num(shape = c(NA, num_features))
}
if (categ_features > 0L) {
output$categ.input = ingress_categ(shape = c(NA, categ_features))
}
output
},
.network = function(task, param_vals) {
its = private$.ingress_tokens(task, param_vals)
mds = list()
path_num = if (!is.null(its$num.input)) {
mds$tokenizer_num.input = ModelDescriptor(
po("nop", id = "num"),
its["num.input"],
task$clone(deep = TRUE)$select(its[["num.input"]]$features(task)),
pointer = c("num", "output"),
pointer_shape = its[["num.input"]]$shape
)
nn("tokenizer_num",
d_token = param_vals$d_token,
bias = TRUE,
initialization = param_vals$init_token
)
}
path_categ = if (!is.null(its$categ.input)) {
mds$tokenizer_categ.input = ModelDescriptor(
po("nop", id = "categ"),
its["categ.input"],
task$clone(deep = TRUE)$select(its[["categ.input"]]$features(task)),
pointer = c("categ", "output"),
pointer_shape = its[["categ.input"]]$shape
)
nn("tokenizer_categ",
d_token = param_vals$d_token,
bias = TRUE,
initialization = param_vals$init_token,
param_vals = discard(param_vals["cardinalities"], is.null)
)
}

input_paths = discard(list(path_num, path_categ), is.null)

graph_tokenizer = if (length(input_paths) == 1L) {
input_paths[[1L]]
} else {
gunion(input_paths) %>>%
nn("merge_cat", param_vals = list(dim = 2))
}

# heuristically defined default parameters that depend on the number of blocks
block_dependent_params = c("d_token", "attention_dropout", "ffn_dropout")
block_dependent_defaults = list(
d_token = c(96, 128, 192, 256, 320, 384),
attention_dropout = c(0.1, 0.15, 0.2, 0.25, 0.3, 0.35),
ffn_dropout = c(0.0, 0.05, 0.1, 0.15, 0.2, 0.25)
)
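# e.g. with n_blocks = 3 and none of these parameters set by the user,
# the defaults resolve to d_token = 192, attention_dropout = 0.2, ffn_dropout = 0.1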
if (param_vals$n_blocks >= 1 && param_vals$n_blocks <= 6) {
null_block_dependent_params_idx = map_lgl(block_dependent_params, function(param_name) {
is.null(param_vals[[param_name]])
})
null_block_dependent_params = block_dependent_params[null_block_dependent_params_idx]

map(null_block_dependent_params, function(param_name) {
private$.block$param_set$values[[param_name]] = block_dependent_defaults[[param_name]][param_vals$n_blocks]
})
}

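# activation-dependent default for ffn_d_hidden:
# 4/3 for the gated reglu/geglu activations, 2 otherwise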
if (is.null(param_vals$ffn_d_hidden)) {
if (class(param_vals$ffn_activation)[1] %in% c("nn_reglu", "nn_geglu")) {
private$.block$param_set$values$ffn_d_hidden = 4 / 3
} else {
private$.block$param_set$values$ffn_d_hidden = 2.0
}
}

blocks = map(seq_len(param_vals$n_blocks), function(i) {
block = private$.block$clone(deep = TRUE)
block$id = sprintf("block_%i", i)

if (i == 1) {
block$param_set$values$is_first_layer = TRUE
} else {
block$param_set$values$is_first_layer = FALSE
}
if (i == param_vals$n_blocks) {
block$param_set$values$query_idx = -1L
} else {
block$param_set$values$query_idx = NULL
}
block
})

if (length(blocks) > 1L) {
blocks = Reduce(`%>>%`, blocks)
}

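# assemble the full network: tokenizer(s) -> CLS token -> transformer blocks ->
# select the CLS position -> layer norm -> ReLU -> prediction head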
graph = graph_tokenizer %>>%
nn("ft_cls", initialization = "uniform") %>>%
blocks %>>%
nn("fn", fn = function(x) x[, -1]) %>>%
nn("layer_norm", dims = 1) %>>%
nn("relu") %>>%
nn("head")

model_descriptor_to_module(graph$train(mds, FALSE)[[1L]])
}
)
)

register_learner("regr.ft_transformer", LearnerTorchFTTransformer)
register_learner("classif.ft_transformer", LearnerTorchFTTransformer)
12 changes: 9 additions & 3 deletions R/LearnerTorch.R
@@ -72,6 +72,8 @@
#' @param callbacks (`list()` of [`TorchCallback`]s)\cr
#' The callbacks to use for training.
#' Defaults to an empty `list()`, i.e. no callbacks.
#' @param jittable (`logical(1)`)\cr
#' Whether the model can be jit-traced. Default is `FALSE`.
#'
#' @section Model:
#' The Model is a list of class `"learner_torch_model"` with the following elements:
@@ -155,8 +157,9 @@ LearnerTorch = R6Class("LearnerTorch",
inherit = Learner,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, task_type, param_set, properties, man, label, feature_types,
optimizer = NULL, loss = NULL, packages = character(), predict_types = NULL, callbacks = list()) {
initialize = function(id, task_type, param_set, properties = character(), man, label, feature_types,
optimizer = NULL, loss = NULL, packages = character(), predict_types = NULL, callbacks = list(),
jittable = FALSE) {
assert_choice(task_type, c("regr", "classif"))

predict_types = predict_types %??% switch(task_type,
@@ -166,11 +169,14 @@ LearnerTorch = R6Class("LearnerTorch",

assert_subset(properties, mlr_reflections$learner_properties[[task_type]])
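# all torch learners support marshaling, validation and internal tuning;
# classification learners additionally get the twoclass and multiclass properties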
properties = union(properties, c("marshal", "validation", "internal_tuning"))
if (task_type == "classif") {
properties = union(properties, c("twoclass", "multiclass"))
}
assert_subset(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]))
packages = assert_character(packages, any.missing = FALSE, min.chars = 1L)
packages = union(c("mlr3", "mlr3torch"), packages)

private$.param_set_torch = paramset_torchlearner(task_type)
private$.param_set_torch = paramset_torchlearner(task_type, jittable = jittable)

check_ps = function(param_set) {
assert_param_set(param_set)
6 changes: 1 addition & 5 deletions R/LearnerTorchFeatureless.R
@@ -21,16 +21,12 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless",
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
properties = switch(task_type,
classif = c("twoclass", "multiclass", "missings", "featureless", "marshal"),
regr = c("missings", "featureless", "marshal")
)
super$initialize(
id = paste0(task_type, ".torch_featureless"),
task_type = task_type,
label = "Featureless Torch Learner",
param_set = ps(),
properties = properties,
properties = c("missings", "featureless"),
feature_types = unname(mlr_reflections$task_feature_types),
man = "mlr3torch::mlr_learners.torch_featureless",
optimizer = optimizer,
13 changes: 5 additions & 8 deletions R/LearnerTorchImage.R
@@ -17,6 +17,8 @@
#' @template param_properties
#' @template param_label
#' @template param_predict_types
#' @param jittable (`logical(1)`)\cr
#' Whether the model can be jit-traced. Default is `FALSE`.
#'
#' @section Parameters:
#' Parameters include those inherited from [`LearnerTorch`] and the `param_set` construction argument.
@@ -31,25 +33,20 @@ LearnerTorchImage = R6Class("LearnerTorchImage",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, task_type, param_set = ps(), label, optimizer = NULL, loss = NULL,
callbacks = list(), packages, man, properties = NULL,
predict_types = NULL) {
properties = properties %??% switch(task_type,
regr = c(),
classif = c("twoclass", "multiclass")
)
callbacks = list(), packages, man, properties = NULL, predict_types = NULL, jittable = FALSE) {
super$initialize(
id = id,
task_type = task_type,
label = label,
optimizer = optimizer,
properties = properties,
loss = loss,
param_set = param_set,
packages = packages,
callbacks = callbacks,
predict_types = predict_types,
feature_types = "lazy_tensor",
man = man
man = man,
jittable = jittable
)
}
),
8 changes: 2 additions & 6 deletions R/LearnerTorchMLP.R
@@ -61,22 +61,18 @@ LearnerTorchMLP = R6Class("LearnerTorchMLP",
neurons = integer(0),
p = 0.5
)
properties = switch(task_type,
regr = character(0),
classif = c("twoclass", "multiclass")
)

super$initialize(
task_type = task_type,
id = paste0(task_type, ".mlp"),
properties = properties,
label = "Multi Layer Perceptron",
param_set = param_set,
optimizer = optimizer,
callbacks = callbacks,
loss = loss,
man = "mlr3torch::mlr_learners.mlp",
feature_types = c("numeric", "integer", "lazy_tensor")
feature_types = c("numeric", "integer", "lazy_tensor"),
jittable = TRUE
)
}
),