Skip to content

Commit fe38cc0

Browse files
authored
new: Implement ModernBERT (#7)
* new: RoPE implementation and tests * make: apply linting with ruff * docs: add link to reference implementation for RoPE * make: fix type checking with mypy * new: add embedding and MLP layers for modernbert * new: add ModernBERTAttention * tests: add numerical tests for ModernBERTAttention, apply linting with Ruff * make: linting and proper type annotation * new: add ModerBERTEncoderLayer * make: fix type hints * make: better type annotations * new: add full ModernBERT model implementation * docs: update docstrings for ModernBERT * make: linting and type annotations * make: rm unnecessary SwiGLU * docs: include modernbert to docs
1 parent 36a4749 commit fe38cc0

File tree

5 files changed

+2339
-0
lines changed

5 files changed

+2339
-0
lines changed

docs/modules/index.rst

+9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ API Documentation
66

77
attention
88
functional
9+
models
910

1011
Attention Modules
1112
---------------
@@ -19,6 +20,14 @@ Functional Interfaces
1920
------------------
2021

2122
.. automodule:: jax_layers.functional
23+
:members:
24+
:undoc-members:
25+
:show-inheritance:
26+
27+
Models
28+
------------------
29+
30+
.. automodule:: jax_layers.models
2231
:members:
2332
:undoc-members:
2433
:show-inheritance:

jax_layers/models/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from jax_layers.models.modernbert import (
2+
ModernBertAttention,
3+
ModernBertEmbeddings,
4+
ModernBERTEncoder,
5+
ModernBERTForMaskedLM,
6+
ModernBertLayer,
7+
ModernBertMLP,
8+
)
9+
10+
__all__ = [
11+
"ModernBERTEncoder",
12+
"ModernBERTForMaskedLM",
13+
"ModernBertAttention",
14+
"ModernBertEmbeddings",
15+
"ModernBertLayer",
16+
"ModernBertMLP",
17+
]

0 commit comments

Comments
 (0)