@@ -347,6 +347,36 @@ development skills and test-driven development of a large code base is required.
**Reviewers**: Chris Rackauckas
## Add support for TabM architecture in NeuroTabModels.jl and remove Zygote.jl dependency (\$1800)
[NeuroTabModels.jl](https://github.com/Evovest/NeuroTabModels.jl) is a library for training neural networks on tabular data. It is built on top of Flux.jl and Zygote.jl and currently supports a limited set of architectures: MLP, ResNet and NeuroTrees.

This project aims to put the library on a stronger foundation in two ways: migrating automatic differentiation from Zygote.jl to Enzyme.jl, which opens the door to improved performance through Reactant, and demonstrating how easily the library can be extended by adding support for a new architecture, TabM.

**Information to Get Started**:
- TabM:
  - Official implementation: https://github.com/yandex-research/tabm
  - Paper: https://arxiv.org/abs/2410.24210
- Numerical embeddings: these are a dependency of TabM but can be useful for any other model. The objective is thus to implement them as new neural operators accessible to any model. The implementation is to follow the Yandex reference:
- A Numerical Embeddings module added as a preprocessing layer accessible to all models (TabM, MLP, NeuroTrees...)
- Removal of the dependency on Zygote.jl in favor of Enzyme.jl for automatic differentiation.
  This notably involves handling or replacing the existing custom differentiation rules used for `leaf_weights` in NeuroTrees. See https://enzyme.mit.edu/julia/stable/#Importing-ChainRules.
- Performance comparison with the original TabM implementation.
- Correctness of the implementation, verified by assessing the similarity of its predictions to those of the original implementation.
- Documentation of the model and minimal tests within the package test suite.
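
The Numerical Embeddings item above does not say which embedding schemes are in scope. As one illustration, piecewise-linear encoding, a scheme described in the numerical-embeddings literature, maps a scalar feature to one value per bin. The sketch below is plain Python with hypothetical bin edges and a hypothetical function name; it is an assumption about what such a layer could compute, not code taken from the Yandex reference or from NeuroTabModels.jl.

```python
def piecewise_linear_encode(x, edges):
    """Encode scalar x against sorted bin edges b_0 < ... < b_T.

    Returns T values: bins entirely below x map to 1.0, bins entirely
    above x map to 0.0, and the bin containing x maps to the fraction
    of the bin covered. The first and last bins extrapolate linearly
    for values outside [b_0, b_T].
    """
    n_bins = len(edges) - 1
    if n_bins < 1:
        raise ValueError("at least two bin edges are required")
    encoded = []
    for t in range(1, n_bins + 1):
        lo, hi = edges[t - 1], edges[t]
        if x < lo and t > 1:
            encoded.append(0.0)  # x lies below this bin
        elif x >= hi and t < n_bins:
            encoded.append(1.0)  # x lies above this bin
        else:
            encoded.append((x - lo) / (hi - lo))  # position within the bin
    return encoded


# Hypothetical bin edges, chosen only for illustration.
EDGES = [0.0, 1.0, 2.0, 4.0]
print(piecewise_linear_encode(1.5, EDGES))  # [1.0, 0.5, 0.0]
print(piecewise_linear_encode(3.0, EDGES))  # [1.0, 1.0, 0.5]
```

In that literature, bin edges are commonly derived from training-set quantiles of each feature; how edges would be chosen in this project is left open here.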

The TabM model is also expected to be assessed on basic regression benchmarks from [MLBenchmarks.jl](https://github.com/Evovest/MLBenchmarks.jl/tree/openml), using the `year` and `msrank` datasets.
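
Both the benchmark assessment and the prediction-similarity check against the original implementation come down to comparing vectors of predictions. The following is a minimal sketch in plain Python; the choice of metrics (mean squared error, maximum absolute difference, Pearson correlation) and all function names are assumptions made for illustration, not requirements stated by the project.

```python
import math


def mse(y_true, y_pred):
    """Mean squared error between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def max_abs_diff(a, b):
    """Largest element-wise gap between two prediction vectors."""
    return max(abs(x - y) for x, y in zip(a, b))


def pearson(a, b):
    """Pearson correlation between two prediction vectors."""
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0 or var_b == 0:
        raise ValueError("correlation is undefined for constant predictions")
    return cov / math.sqrt(var_a * var_b)


# Hypothetical predictions from the reference and the ported model.
reference_preds = [0.10, 0.50, 0.90, 1.30]
ported_preds = [0.11, 0.49, 0.92, 1.29]
print(max_abs_diff(reference_preds, ported_preds))
print(pearson(reference_preds, ported_preds))
```

Exact agreement should not be expected when two implementations are trained independently, since initialization and data ordering differ; comparing forward passes with identical weights, if that is feasible, gives a stricter check.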

**Recommended Skills**: Familiarity with deep learning frameworks such as Flux.jl or Lux.jl and with their underlying autodiff systems (Enzyme.jl).