Skip to content

Commit a5af24c

Browse files
author
Flax Authors
committed
Merge pull request #5094 from thijs-vanweezel:filters
PiperOrigin-RevId: 839359310
2 parents 6ac5a78 + a7bb109 commit a5af24c

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

flax/nnx/filterlib.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,16 @@ def __repr__(self):
8989

9090
@dataclasses.dataclass(frozen=True)
9191
class PathContains:
92-
key: Key
92+
key: Key | str
93+
exact: bool = True
9394

9495
def __call__(self, path: PathParts, x: tp.Any):
95-
return self.key in path
96+
if self.exact:
97+
return self.key in path
98+
return any(str(self.key) in str(part) for part in path)
9699

97100
def __repr__(self):
98-
return f'PathContains({self.key!r})'
101+
return f'PathContains({self.key!r}, exact={self.exact})'
99102

100103

101104
class PathIn:

tests/nnx/filters_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,20 @@ class TestFilters(absltest.TestCase):
2121
def test_path_contains(self):
2222
class Model(nnx.Module):
2323
def __init__(self, rngs):
24-
self.backbone = nnx.Linear(2, 3, rngs=rngs)
24+
self.backbone1 = nnx.Linear(2, 3, rngs=rngs)
25+
self.backbone2 = nnx.Linear(3, 3, rngs=rngs)
2526
self.head = nnx.Linear(3, 10, rngs=rngs)
2627

2728
model = Model(nnx.Rngs(0))
2829

2930
head_state = nnx.state(model, nnx.PathContains('head'))
31+
backbones_state = nnx.state(model, nnx.PathContains('backbone', exact=False))
3032

3133
self.assertIn('head', head_state)
3234
self.assertNotIn('backbone', head_state)
35+
self.assertIn('backbone1', backbones_state)
36+
self.assertIn('backbone2', backbones_state)
37+
self.assertNotIn('head', backbones_state)
3338

3439
if __name__ == '__main__':
3540
absltest.main()

0 commit comments

Comments
 (0)