11# Pydantic validation for modules spec
2- from pydantic import BaseModel , Field , RootModel , field_validator
2+ from pydantic import BaseModel , ConfigDict , Field , RootModel , field_validator
33from torch import nn
44
55
@@ -47,7 +47,7 @@ def build(self) -> nn.Module:
4747 return cls (** kwargs )
4848
4949
50- class SequentialSpec (RootModel ):
50+ class SequentialSpec (BaseModel ):
5151 """
5252 Sequential spec where key = Sequential and value = list of NNSpecs.
5353 E.g.
@@ -61,7 +61,9 @@ class SequentialSpec(RootModel):
6161 out_features: 10
6262 """
6363
64- root : dict [str , list [NNSpec ]] = Field (
64+ model_config = ConfigDict (extra = "forbid" )
65+
66+ Sequential : list [NNSpec ] = Field (
6567 description = "Sequential module spec where value is a list of NNSpec." ,
6668 examples = [
6769 {
@@ -74,20 +76,9 @@ class SequentialSpec(RootModel):
7476 ],
7577 )
7678
77- @field_validator ("root" , mode = "before" )
78- def is_single_key_dict (value : dict ) -> dict :
79- return NNSpec .is_single_key_dict (value )
80-
81- @field_validator ("root" , mode = "before" )
82- def key_is_sequential (value : dict ) -> dict :
83- assert (
84- next (iter (value )) == "Sequential"
85- ), "Key must be 'Sequential' if using SequentialSpec."
86- return value
87-
8879 def build (self ) -> nn .Sequential :
8980 """Build nn.Sequential from sequential spec."""
90- nn_specs = next ( iter ( self .root . values ()))
81+ nn_specs = self .Sequential
9182 return nn .Sequential (* [nn_spec .build () for nn_spec in nn_specs ])
9283
9384
@@ -138,7 +129,7 @@ class CompactValueSpec(BaseModel):
138129 )
139130
140131
141- class CompactSpec (RootModel ):
132+ class CompactSpec (BaseModel ):
142133 """
143134 Higher level compact spec that expands into Sequential spec. This is useful for architecture search.
144135 Compact spec has the format:
@@ -174,7 +165,9 @@ class CompactSpec(RootModel):
174165 p: 0.1
175166 """
176167
177- root : dict [str , CompactValueSpec ] = Field (
168+ model_config = ConfigDict (extra = "forbid" )
169+
170+ compact : CompactValueSpec = Field (
178171 description = "Higher level compact spec that expands into Sequential spec." ,
179172 examples = [
180173 {
@@ -201,17 +194,6 @@ class CompactSpec(RootModel):
201194 ],
202195 )
203196
204- @field_validator ("root" , mode = "before" )
205- def is_single_key_dict (value : dict ) -> dict :
206- return NNSpec .is_single_key_dict (value )
207-
208- @field_validator ("root" , mode = "before" )
209- def key_is_compact (value : dict ) -> dict :
210- assert (
211- next (iter (value )) == "compact"
212- ), "Key must be 'compact' if using CompactSpec."
213- return value
214-
215197 def __expand_spec (self , compact_layer : dict ) -> list [dict ]:
216198 class_name = compact_layer ["type" ]
217199 keys = compact_layer ["keys" ]
@@ -223,7 +205,7 @@ def __expand_spec(self, compact_layer: dict) -> list[dict]:
223205 return nn_specs
224206
225207 def expand_to_sequential_spec (self ) -> SequentialSpec :
226- compact_spec = next ( iter ( self .root . values ())) .model_dump ()
208+ compact_spec = self .compact .model_dump ()
227209 prelayer = compact_spec .get ("prelayer" )
228210 postlayer = compact_spec .get ("postlayer" )
229211 nn_specs = []
0 commit comments