feat/pipeop-transformer-layer #388
Draft: cxzhang4 wants to merge 67 commits into main from feat/pipeop-transformer-layer
Commits (67)
50bc8e9  double quotes (cxzhang4)
8aa4470  style (cxzhang4)
4b5fafe  copied in old attic code to test file, still need to try (cxzhang4)
258ea42  idrk (cxzhang4)
a8a8787  Merge branch 'main' into feat/pipeop-transformer-layer (cxzhang4)
b81e9dd  changed d_token in test (cxzhang4)
e7288f8  small cleanup (cxzhang4)
da059e2  cleanup (cxzhang4)
866a7ec  factored out d_token (cxzhang4)
e6e67f6  idk (cxzhang4)
9d37216  test passes for now (cxzhang4)
31ef199  intermrediate docs (cxzhang4)
ebcb3c0  TODO: implement custom checks for parameters that are nn_modules or n… (cxzhang4)
946dba0  more TODOs (cxzhang4)
00c91eb  docs (cxzhang4)
da21ff2  change title of nn_ft_transformer_layer module (cxzhang4)
7d65f09  removed is_first_layer param (cxzhang4)
3872ce0  some comments (cxzhang4)
ce4809b  a comment (cxzhang4)
41724f0  added back is_first_layer param (cxzhang4)
1c2ee1e  added back comment on prenormalization condition (cxzhang4)
6b4c34a  comment on parameters (cxzhang4)
24645c4  Merge branch 'main' into feat/pipeop-transformer-layer (cxzhang4)
8b671ee  some changes (sebffischer)
f812cb5  some notes (sebffischer)
14c98c5  some more changes (sebffischer)
8d5f641  ... (sebffischer)
c479d9c  factored out last_layer_query_idx from layer (cxzhang4)
0c020cf  query_idx should be -1L (last dim) for last transformer layer (cxzhang4)
0f17330  deleted file with old name (Layer, not Block) (cxzhang4)
916648d  formatting (cxzhang4)
40d5e44  check_nn_module_generator (cxzhang4)
44f003b  Merge branch 'main' into feat/pipeop-transformer-layer (cxzhang4)
726e8af  got rid of some defaults (cxzhang4)
5b12f99  idk (cxzhang4)
c69af36  reduce blocks in learner when there are multipleg (cxzhang4)
8f80d18  delete TODO (cxzhang4)
d0ec4fc  fix test (cxzhang4)
f6a6326  some comments (cxzhang4)
a11b217  small changes
0416ead  Merge branch 'main' into feat/pipeop-transformer-layer
3e609bd  added custom error messages (cxzhang4)
852d69c  x_residual (cxzhang4)
ec4e8ab  set block dependent default vals (cxzhang4)
d5c3cc4  some comments (cxzhang4)
e73100a  Merge branch 'main' into feat/pipeop-transformer-layer (cxzhang4)
92ced45  intermediate (cxzhang4)
7a23834  Merge branch 'main' into feat/pipeop-transformer-layer (cxzhang4)
31571ff  remove first_prenormalization (cxzhang4)
b5172a6  address TODOs (cxzhang4)
17d25cb  looks ok 2 me, still has compression (cxzhang4)
d50e1e0  removed kv compression (cxzhang4)
c9ad1e9  update docs for learner (cxzhang4)
5bf9ada  added block-dependent defaults, removed required tags from learner pa… (cxzhang4)
7707922  added activation-dependent ffn_d_hidden (cxzhang4)
b62f95c  a comment (cxzhang4)
abb2094  man (cxzhang4)
173b72c  defaults look ok (cxzhang4)
85b1d71  Merge branch 'main' into feat/pipeop-transformer-layer (cxzhang4)
5ecef42  removed comment (cxzhang4)
b977f56  removed test for mixed input: (cxzhang4)
552c6d1  removed browser()) (cxzhang4)
b98ab8f  removed print statemetn from test: (cxzhang4)
ac110dd  add test for only categorical input (cxzhang4)
358ef1c  simple script for benchmarking ft transformer (cxzhang4)
0ac5b92  init -> default, shouldn't have duplicated anyway (cxzhang4)
c2449c9  added query_idx = NULL in the constructor in the module generator (cxzhang4)
@@ -0,0 +1,220 @@
#' @title FT-Transformer
#' @templateVar name ft_transformer
#' @templateVar task_types classif, regr
#' @templateVar param_vals n_blocks = 2, d_token = 32
#' @template params_learner
#' @template learner
#' @template learner_example
#'
#' @description
#' Feature-Tokenizer Transformer for tabular data that can either work on [`lazy_tensor`] inputs
#' or on standard tabular features.
#'
#' Some differences from the paper implementation: no key-value compression in the attention and
#' no prenormalization in the first layer.
#'
#' @section Parameters:
#' Parameters from [`LearnerTorch`], as well as:
#' * `n_blocks` :: `integer(1)`\cr
#'   The number of transformer blocks.
#' * `d_token` :: `integer(1)`\cr
#'   The dimension of the embedding.
#' * `cardinalities` :: `integer(1)`\cr
#'   The number of categories for each feature.
#' * `init_token` :: `character(1)`\cr
#'   The initialization method for the embedding weights. Either "uniform" or "normal".
#' * `ingress_tokens` :: named `list()`\cr
#'   A named list of `TorchIngressToken`s with names "num.input" and/or "categ.input".
#'
#' @references
#' `r format_bib("gorishniy2021revisiting")`
#' @export
LearnerTorchFTTransformer = R6Class("LearnerTorchFTTransformer",
  inherit = LearnerTorch,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
      private$.block = PipeOpTorchFTTransformerBlock$new()

      check_ingress_tokens = crate(function(ingress_tokens, task) {
        if (is.null(ingress_tokens)) {
          return(TRUE)
        }
        msg = check_list(ingress_tokens, types = "TorchIngressToken", min.len = 1L, names = "unique")
        if (!isTRUE(msg)) {
          return(msg)
        }
        check_subset(names(ingress_tokens), c("num.input", "categ.input"))
      })

      private$.param_set_base = ps(
        n_blocks = p_int(lower = 0, default = 3, tags = "train"),
        d_token = p_int(lower = 1L, default = 192L, tags = "train"),
        cardinalities = p_int(lower = 1L, tags = "train"),
        init_token = p_fct(init = "uniform", levels = c("uniform", "normal"), tags = "train"),
        ingress_tokens = p_uty(tags = "train", custom_check = check_ingress_tokens)
      )
      param_set = alist(private$.block$param_set, private$.param_set_base)

      super$initialize(
        task_type = task_type,
        id = paste0(task_type, ".ft_transformer"),
        label = "FT-Transformer",
        param_set = param_set,
        optimizer = optimizer,
        callbacks = callbacks,
        loss = loss,
        man = "mlr3torch::mlr_learners.ft_transformer",
        feature_types = c("numeric", "integer", "logical", "factor", "ordered", "lazy_tensor"),
        # Because the CLS token does resizing that depends dynamically on the input shape,
        # specifically, the batch size
        jittable = FALSE
      )
    }
  ),
  private = list(
    .block = NULL,
    .ingress_tokens = function(task, param_vals) {
      if ("lazy_tensor" %in% task$feature_types$type) {
        if (!all(task$feature_types$type == "lazy_tensor")) {
          stopf("Learner '%s' received an input task '%s' that is mixing lazy_tensors with other feature types.", self$id, task$id) # nolint
        }
        if (task$n_features > 2L) {
          stopf("Learner '%s' received an input task '%s' that has more than two lazy tensors.", self$id, task$id) # nolint
        }
        if (is.null(param_vals$ingress_tokens)) {
          stopf("Learner '%s' received an input task '%s' with lazy tensors, but no parameter 'ingress_tokens' was specified.", self$id, task$id) # nolint
        }

        ingress_tokens = param_vals$ingress_tokens
        row = task$head(1L)
        for (i in seq_along(ingress_tokens)) {
          feat = ingress_tokens[[i]]$features(task)
          if (length(feat) != 1L) {
            stopf("Learner '%s' received an input task '%s' with lazy tensors, but the ingress token '%s' does not select exactly one feature.", self$id, task$id, names(ingress_tokens)[[i]]) # nolint
          }
          if (is.null(ingress_tokens[[i]]$shape)) {
            ingress_tokens[[i]]$shape = lazy_shape(row[[feat]])
          }
          if (is.null(ingress_tokens[[i]]$shape)) {
            stopf("Learner '%s' received an input task '%s' with lazy tensors, but neither the ingress token for '%s', nor the 'lazy_tensor' specify the shape, which makes it impossible to build the network.", self$id, task$id, feat) # nolint
          }
        }
        return(ingress_tokens)
      }
      num_features = n_num_features(task)
      categ_features = n_categ_features(task)
      output = list()
      if (num_features > 0L) {
        output$num.input = ingress_num(shape = c(NA, num_features))
      }
      if (categ_features > 0L) {
        output$categ.input = ingress_categ(shape = c(NA, categ_features))
      }
      output
    },
    .network = function(task, param_vals) {
      its = private$.ingress_tokens(task, param_vals)
      mds = list()
      path_num = if (!is.null(its$num.input)) {
        mds$tokenizer_num.input = ModelDescriptor(
          po("nop", id = "num"),
          its["num.input"],
          task$clone(deep = TRUE)$select(its[["num.input"]]$features(task)),
          pointer = c("num", "output"),
          pointer_shape = its[["num.input"]]$shape
        )
        nn("tokenizer_num",
          d_token = param_vals$d_token,
          bias = TRUE,
          initialization = param_vals$init_token
        )
      }
      path_categ = if (!is.null(its$categ.input)) {
        mds$tokenizer_categ.input = ModelDescriptor(
          po("nop", id = "categ"),
          its["categ.input"],
          task$clone(deep = TRUE)$select(its[["categ.input"]]$features(task)),
          pointer = c("categ", "output"),
          pointer_shape = its[["categ.input"]]$shape
        )
        nn("tokenizer_categ",
          d_token = param_vals$d_token,
          bias = TRUE,
          initialization = param_vals$init_token,
          param_vals = discard(param_vals["cardinalities"], is.null)
        )
      }

      input_paths = discard(list(path_num, path_categ), is.null)

      graph_tokenizer = if (length(input_paths) == 1L) {
        input_paths[[1L]]
      } else {
        gunion(input_paths) %>>%
          nn("merge_cat", param_vals = list(dim = 2))
      }

      # heuristically defined default parameters that depend on the number of blocks
      block_dependent_params = c("d_token", "attention_dropout", "ffn_dropout")
      block_dependent_defaults = list(
        d_token = c(96, 128, 192, 256, 320, 384),
        attention_dropout = c(0.1, 0.15, 0.2, 0.25, 0.3, 0.35),
        ffn_dropout = c(0.0, 0.05, 0.1, 0.15, 0.2, 0.25)
      )
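      # For example (values taken from the table above): with n_blocks = 3 and none of
      # these parameters set by the user, the block receives d_token = 192,
      # attention_dropout = 0.2 and ffn_dropout = 0.1.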
      if (param_vals$n_blocks >= 1 && param_vals$n_blocks <= 6) {
        null_block_dependent_params_idx = map_lgl(block_dependent_params, function(param_name) {
          is.null(param_vals[[param_name]])
        })
        null_block_dependent_params = block_dependent_params[null_block_dependent_params_idx]

        map(null_block_dependent_params, function(param_name) {
          private$.block$param_set$values[[param_name]] = block_dependent_defaults[[param_name]][param_vals$n_blocks]
        })
      }

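      # Activation-dependent default for ffn_d_hidden (cf. commit 7707922): 4 / 3 for the
      # GLU-style activations (reglu, geglu), 2.0 otherwise.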
      if (is.null(param_vals$ffn_d_hidden)) {
        if (class(param_vals$ffn_activation)[1] %in% c("nn_reglu", "nn_geglu")) {
          private$.block$param_set$values$ffn_d_hidden = 4 / 3
        } else {
          private$.block$param_set$values$ffn_d_hidden = 2.0
        }
      }

      blocks = map(seq_len(param_vals$n_blocks), function(i) {
        block = private$.block$clone(deep = TRUE)
        block$id = sprintf("block_%i", i)

        if (i == 1) {
          block$param_set$values$is_first_layer = TRUE
        } else {
          block$param_set$values$is_first_layer = FALSE
        }
        if (i == param_vals$n_blocks) {
          block$param_set$values$query_idx = -1L
        } else {
          block$param_set$values$query_idx = NULL
        }
        block
      })

      if (length(blocks) > 1L) {
        blocks = Reduce(`%>>%`, blocks)
      }

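      # Assemble the network: tokenizer(s) -> CLS embedding (appended as the last token) ->
      # transformer blocks -> select the CLS position (x[, -1]) -> layer norm -> ReLU -> head.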
      graph = graph_tokenizer %>>%
        nn("ft_cls", initialization = "uniform") %>>%
        blocks %>>%
        nn("fn", fn = function(x) x[, -1]) %>>%
        nn("layer_norm", dims = 1) %>>%
        nn("relu") %>>%
        nn("head")

      model_descriptor_to_module(graph$train(mds, FALSE)[[1L]])
    }
  )
)

register_learner("regr.ft_transformer", LearnerTorchFTTransformer)
register_learner("classif.ft_transformer", LearnerTorchFTTransformer)
Review comment: check that the default arguments are as defined in the paper
https://github.com/yandex-research/rtdl-revisiting-models/blob/e3ed46cac38568785289d8fa16b8cfa585bde27e/package/rtdl_revisiting_models.py#L752-L778
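A hedged follow-up sketch for this review point: it only pulls the declared defaults out of the learner so they can be compared by eye with the linked reference implementation. It assumes the standard paradox `$default` accessor and that the block parameters (attention_dropout, ffn_dropout) are merged into the learner's param set as in the diff above.

library(mlr3torch)

# Inspect the declared defaults of the FT-Transformer learner and keep only the
# parameters of interest; compare these values against the reference code linked above.
learner = lrn("classif.ft_transformer")
defaults = learner$param_set$default
defaults[intersect(names(defaults), c("n_blocks", "d_token", "init_token", "attention_dropout", "ffn_dropout"))]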