Skip to content

Commit ff8a47a

Browse files
authored
Merge pull request #18 from kengz/refactor
refactor: use simpler BaseModel for Sequential, compact
2 parents 81fa559 + 5ddfad0 commit ff8a47a

File tree

3 files changed

+13
-31
lines changed

3 files changed

+13
-31
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "torcharc"
3-
version = "2.1.1"
3+
version = "2.1.2"
44
description = "Build PyTorch models by specifying architectures."
55
readme = "README.md"
66
requires-python = ">=3.12"

torcharc/validator/modules.py

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Pydantic validation for modules spec
2-
from pydantic import BaseModel, Field, RootModel, field_validator
2+
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator
33
from torch import nn
44

55

@@ -47,7 +47,7 @@ def build(self) -> nn.Module:
4747
return cls(**kwargs)
4848

4949

50-
class SequentialSpec(RootModel):
50+
class SequentialSpec(BaseModel):
5151
"""
5252
Sequential spec where key = Sequential and value = list of NNSpecs.
5353
E.g.
@@ -61,7 +61,9 @@ class SequentialSpec(RootModel):
6161
out_features: 10
6262
"""
6363

64-
root: dict[str, list[NNSpec]] = Field(
64+
model_config = ConfigDict(extra="forbid")
65+
66+
Sequential: list[NNSpec] = Field(
6567
description="Sequential module spec where value is a list of NNSpec.",
6668
examples=[
6769
{
@@ -74,20 +76,9 @@ class SequentialSpec(RootModel):
7476
],
7577
)
7678

77-
@field_validator("root", mode="before")
78-
def is_single_key_dict(value: dict) -> dict:
79-
return NNSpec.is_single_key_dict(value)
80-
81-
@field_validator("root", mode="before")
82-
def key_is_sequential(value: dict) -> dict:
83-
assert (
84-
next(iter(value)) == "Sequential"
85-
), "Key must be 'Sequential' if using SequentialSpec."
86-
return value
87-
8879
def build(self) -> nn.Sequential:
8980
"""Build nn.Sequential from sequential spec."""
90-
nn_specs = next(iter(self.root.values()))
81+
nn_specs = self.Sequential
9182
return nn.Sequential(*[nn_spec.build() for nn_spec in nn_specs])
9283

9384

@@ -138,7 +129,7 @@ class CompactValueSpec(BaseModel):
138129
)
139130

140131

141-
class CompactSpec(RootModel):
132+
class CompactSpec(BaseModel):
142133
"""
143134
Higher level compact spec that expands into Sequential spec. This is useful for architecture search.
144135
Compact spec has the format:
@@ -174,7 +165,9 @@ class CompactSpec(RootModel):
174165
p: 0.1
175166
"""
176167

177-
root: dict[str, CompactValueSpec] = Field(
168+
model_config = ConfigDict(extra="forbid")
169+
170+
compact: CompactValueSpec = Field(
178171
description="Higher level compact spec that expands into Sequential spec.",
179172
examples=[
180173
{
@@ -201,17 +194,6 @@ class CompactSpec(RootModel):
201194
],
202195
)
203196

204-
@field_validator("root", mode="before")
205-
def is_single_key_dict(value: dict) -> dict:
206-
return NNSpec.is_single_key_dict(value)
207-
208-
@field_validator("root", mode="before")
209-
def key_is_compact(value: dict) -> dict:
210-
assert (
211-
next(iter(value)) == "compact"
212-
), "Key must be 'compact' if using CompactSpec."
213-
return value
214-
215197
def __expand_spec(self, compact_layer: dict) -> list[dict]:
216198
class_name = compact_layer["type"]
217199
keys = compact_layer["keys"]
@@ -223,7 +205,7 @@ def __expand_spec(self, compact_layer: dict) -> list[dict]:
223205
return nn_specs
224206

225207
def expand_to_sequential_spec(self) -> SequentialSpec:
226-
compact_spec = next(iter(self.root.values())).model_dump()
208+
compact_spec = self.compact.model_dump()
227209
prelayer = compact_spec.get("prelayer")
228210
postlayer = compact_spec.get("postlayer")
229211
nn_specs = []

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)