Skip to content

Commit 3b77638

Browse files
klemen1999kozlov721
authored andcommitted
Added name attribute to HeadType (#176)
1 parent c5cde8f commit 3b77638

File tree

5 files changed

+15
-3
lines changed

5 files changed

+15
-3
lines changed

luxonis_ml/nn_archive/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .model import Model
88

9-
CONFIG_VERSION = Literal["1.0"]
9+
CONFIG_VERSION = Literal["1.0", "1.1"]
1010

1111

1212
class Config(BaseModelExtraForbid):

luxonis_ml/nn_archive/config_building_blocks/base_models/head.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
class Head(BaseModel, ABC):
1717
"""Represents head of a model.
1818
19+
@type name: str | None
20+
@ivar name: Optional name of the head.
1921
@type parser: str
2022
@ivar parser: Name of the parser responsible for processing the models output.
2123
@type outputs: List[str] | None
@@ -27,6 +29,7 @@ class Head(BaseModel, ABC):
2729
@ivar metadata: Metadata of the parser.
2830
"""
2931

32+
name: Optional[str] = Field(None, description="Optional name of the head.")
3033
parser: str = Field(
3134
description="Name of the parser responsible for processing the models output."
3235
)

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
onnx>=1.14.0,<=1.16.2
12
pre-commit>=3.2.1
23
pytest-cov>=4.1.0
34
pytest-dependency>=0.6.0
45
pytest-subtests>=0.12.1
56
pytest-md>=0.2.0
67
gdown>=4.7.1
7-
coverage-badge>=1.1.0
8+
coverage-badge>=1.1.0

tests/test_nn_archive/heads.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@
156156

157157
classification_head = dict(
158158
Head(
159+
name="ClassificationHead",
159160
parser="Classification",
160161
outputs=["output"],
161162
metadata=head_classification_metadata,
@@ -164,6 +165,7 @@
164165

165166
ssd_object_detection_head = dict(
166167
Head(
168+
name="ObjectDetectionSSDHead",
167169
parser="ObjectDetectionSSD",
168170
outputs=["boxes"],
169171
metadata=head_object_detection_ssd_metadata,
@@ -172,6 +174,7 @@
172174

173175
yolo_object_detection_head = dict(
174176
Head(
177+
name="YoloDetectionHead",
175178
parser="YOLO",
176179
outputs=["output"],
177180
metadata=head_yolo_obb_det_metadata,
@@ -180,6 +183,7 @@
180183

181184
yolo_instance_segmentation_head = dict(
182185
Head(
186+
name="YoloInstanceSegHead",
183187
parser="YOLO",
184188
outputs=["output"],
185189
metadata=head_yolo_instance_seg_metadata,
@@ -188,6 +192,7 @@
188192

189193
yolo_keypoint_detection_head = dict(
190194
Head(
195+
name="YoloKeypointDetectionHead",
191196
parser="YOLO",
192197
outputs=["output"],
193198
metadata=head_yolo_keypoint_det_metadata,
@@ -196,6 +201,7 @@
196201

197202
yolo_obb_detection_head = dict(
198203
Head(
204+
name="YoloOBBHead",
199205
parser="YOLO",
200206
outputs=["output"],
201207
metadata=head_yolo_obb_det_metadata,
@@ -204,6 +210,7 @@
204210

205211
yolo_instance_seg_kpts_head = dict(
206212
Head(
213+
name="YoloInstaceSegKptHead",
207214
parser="YOLO",
208215
outputs=["outputs"],
209216
metadata=head_yolo_instance_seg_kpts_metadata,
@@ -212,6 +219,7 @@
212219

213220
custom_segmentation_head = dict(
214221
Head(
222+
name="SegmentationHead",
215223
parser="Segmentation",
216224
outputs=["output"],
217225
metadata=head_segmentation_metadata,

tests/test_nn_archive/test_nn_archive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_archive_generator(
8383
archive_name=archive_name,
8484
save_path="tests/data/test_nn_archive",
8585
cfg_dict={
86-
"config_version": "1.0",
86+
"config_version": "1.1",
8787
"model": {
8888
"metadata": {
8989
"name": "test_model",

0 commit comments

Comments
 (0)