File tree Expand file tree Collapse file tree 2 files changed +12
-4
lines changed
Expand file tree Collapse file tree 2 files changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -89,13 +89,16 @@ def __repr__(self):
8989
9090@dataclasses .dataclass (frozen = True )
9191class PathContains :
92- key : Key
92+ key : Key | str
93+ exact : bool = True
9394
9495 def __call__ (self , path : PathParts , x : tp .Any ):
95- return self .key in path
96+ if self .exact :
97+ return self .key in path
98+ return any (str (self .key ) in str (part ) for part in path )
9699
97100 def __repr__ (self ):
98- return f'PathContains({ self .key !r} )'
101+ return f'PathContains({ self .key !r} , exact= { self . exact } )'
99102
100103
101104class PathIn :
Original file line number Diff line number Diff line change @@ -21,15 +21,20 @@ class TestFilters(absltest.TestCase):
2121 def test_path_contains (self ):
2222 class Model (nnx .Module ):
2323 def __init__ (self , rngs ):
24- self .backbone = nnx .Linear (2 , 3 , rngs = rngs )
24+ self .backbone1 = nnx .Linear (2 , 3 , rngs = rngs )
25+ self .backbone2 = nnx .Linear (3 , 3 , rngs = rngs )
2526 self .head = nnx .Linear (3 , 10 , rngs = rngs )
2627
2728 model = Model (nnx .Rngs (0 ))
2829
2930 head_state = nnx .state (model , nnx .PathContains ('head' ))
31+ backbones_state = nnx .state (model , nnx .PathContains ('backbone' , exact = False ))
3032
3133 self .assertIn ('head' , head_state )
3234 self .assertNotIn ('backbone' , head_state )
35+ self .assertIn ('backbone1' , backbones_state )
36+ self .assertIn ('backbone2' , backbones_state )
37+ self .assertNotIn ('head' , backbones_state )
3338
3439if __name__ == '__main__' :
3540 absltest .main ()
You can’t perform that action at this time.
0 commit comments