Skip to content

Commit 8f896ce

Browse files
authored
[Feat] Added name attribute to HeadType (#86)
1 parent 0c8cd21 commit 8f896ce

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

luxonis_train/core/utils/archive_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def get_heads(
235235
@param nodes: Dictionary of nodes.
236236
"""
237237
heads = []
238-
238+
head_names = set()
239239
for node in cfg.model.nodes:
240240
node_name = node.name
241241
node_alias = node.alias or node_name
@@ -250,7 +250,13 @@ def get_heads(
250250
head_outputs = _get_head_outputs(
251251
outputs, node_alias, node_name
252252
)
253+
if node_alias in head_names:
254+
curr_head_name = f"{node_alias}_{len(head_names)}" # add suffix if name is already present
255+
else:
256+
curr_head_name = node_alias
257+
head_names.add(curr_head_name)
253258
head_dict = {
259+
"name": curr_head_name,
254260
"parser": parser,
255261
"metadata": {
256262
"classes": classes,

luxonis_train/models/luxonis_lightning.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,14 @@ def export_onnx(self, save_path: str, **kwargs) -> list[str]:
535535
for node_name, output_name, i in output_order
536536
]
537537

538+
if not self.cfg.exporter.output_names:
539+
idx = 1
540+
# Set to output names required by DAI
541+
for i, output_name in enumerate(output_names):
542+
if output_name.startswith("EfficientBBoxHead"):
543+
output_names[i] = f"output{idx}_yolov6r2"
544+
idx += 1
545+
538546
old_forward = self.forward
539547

540548
def export_forward(inputs) -> tuple[Tensor, ...]:

0 commit comments

Comments
 (0)