Simple Predefined Models #255

Open · wants to merge 2 commits into base: feature/cleanup-nodes

Conversation

@kozlov721 (Collaborator) commented Jul 3, 2025

Purpose

Simplifies the definition of predefined models

Specification

  • Added a new base class SimplePredefinedModel with common logic from individual predefined models
    • Subclass of BasePredefinedModel, doesn't replace it
      • Complicated models can still subclass BasePredefinedModel and implement all the parts themselves
    • Models subclassing this class only need to implement __init__ with default parameters
  • New API for variants (following the new variants API for nodes from feature/cleanup-nodes)
    • Allows for better overriding of individual parts of the predefined model
    • Moves common variant logic to the base class
    • New class method from_variant that constructs the predefined model from a variant name
    • variant is now a special field under model.predefined_model instead of being under model.predefined_model.params (see the config sketch after this list)
      • The params way is still accepted, just marked as deprecated
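
For illustration, a minimal sketch of the two config placements (written as Python dicts for brevity; real configs are typically YAML with the same keys, and the model name and variant value are only examples):

# Hedged sketch of the config placements described above; everything other
# than the `variant` and `params` keys is an assumed, illustrative layout.
new_style = {
    "model": {
        "predefined_model": {
            "name": "ClassificationModel",
            "variant": "heavy",  # new: dedicated field on `predefined_model`
        }
    }
}

deprecated_style = {
    "model": {
        "predefined_model": {
            "name": "ClassificationModel",
            "params": {"variant": "heavy"},  # still accepted, but deprecated
        }
    }
}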

Example

Old Classification Model

from typing import Literal, TypeAlias

from luxonis_ml.typing import Params
from pydantic import BaseModel

from luxonis_train.config import (
    AttachedModuleConfig,
    LossModuleConfig,
    MetricModuleConfig,
    NodeConfig,
)

from .base_predefined_model import BasePredefinedModel

VariantLiteral: TypeAlias = Literal["light", "heavy"]


class ClassificationVariant(BaseModel):
    backbone: str
    backbone_params: Params


def get_variant(variant: VariantLiteral) -> ClassificationVariant:
    """Returns the specific variant configuration for the
    ClassificationModel."""
    variants = {
        "light": ClassificationVariant(
            backbone="ResNet", backbone_params={"variant": "18"}
        ),
        "heavy": ClassificationVariant(
            backbone="ResNet", backbone_params={"variant": "50"}
        ),
    }

    if variant not in variants:
        raise ValueError(
            f"Classification variant should be one of {list(variants.keys())}, got '{variant}'."
        )

    return variants[variant]


class ClassificationModel(BasePredefinedModel):
    def __init__(
        self,
        variant: VariantLiteral = "light",
        backbone: str | None = None,
        backbone_params: Params | None = None,
        head_params: Params | None = None,
        loss_params: Params | None = None,
        visualizer_params: Params | None = None,
        task: Literal["multiclass", "multilabel"] = "multiclass",
        task_name: str = "",
        enable_confusion_matrix: bool = True,
        confusion_matrix_params: Params | None = None,
    ):
        var_config = get_variant(variant)

        self.backbone = backbone or var_config.backbone
        self.backbone_params = (
            backbone_params
            if backbone is not None or backbone_params is not None
            else var_config.backbone_params
        ) or {}
        self.head_params = head_params or {}
        self.loss_params = loss_params or {}
        self.visualizer_params = visualizer_params or {}
        self.task = task
        self.task_name = task_name
        self.enable_confusion_matrix = enable_confusion_matrix
        self.confusion_matrix_params = confusion_matrix_params or {}

    @property
    def nodes(self) -> list[NodeConfig]:
        """Defines the model nodes, including backbone and head."""
        return [
            NodeConfig(
                name=self.backbone,
                freezing=self._get_freezing(self.backbone_params),
                params=self.backbone_params,
            ),
            NodeConfig(
                name="ClassificationHead",
                freezing=self._get_freezing(self.head_params),
                inputs=[self.backbone],
                params=self.head_params,
                task_name=self.task_name,
            ),
        ]

    @property
    def losses(self) -> list[LossModuleConfig]:
        """Defines the loss module for the classification task."""
        return [
            LossModuleConfig(
                name="CrossEntropyLoss",
                attached_to="ClassificationHead",
                params=self.loss_params,
                weight=1.0,
            )
        ]

    @property
    def metrics(self) -> list[MetricModuleConfig]:
        """Defines the metrics used for evaluation."""
        metrics = [
            MetricModuleConfig(
                name="F1Score",
                is_main_metric=True,
                attached_to="ClassificationHead",
                params={"task": self.task},
            ),
            MetricModuleConfig(
                name="Accuracy",
                attached_to="ClassificationHead",
                params={"task": self.task},
            ),
            MetricModuleConfig(
                name="Recall",
                attached_to="ClassificationHead",
                params={"task": self.task},
            ),
        ]
        if self.enable_confusion_matrix:
            metrics.append(
                MetricModuleConfig(
                    name="ConfusionMatrix",
                    attached_to="ClassificationHead",
                    params={**self.confusion_matrix_params},
                )
            )
        return metrics

    @property
    def visualizers(self) -> list[AttachedModuleConfig]:
        """Defines the visualizer used for the classification task."""
        return [
            AttachedModuleConfig(
                name="ClassificationVisualizer",
                attached_to="ClassificationHead",
                params=self.visualizer_params,
            )
        ]

Simple Classification Model

from luxonis_ml.typing import Params

from .base_predefined_model import SimplePredefinedModel


class ClassificationModel(SimplePredefinedModel):
    def __init__(self, **kwargs):
        # `__init__` kwargs need to be handled like a dictionary
        # to allow CLI overrides
        kwargs = {
            "backbone": "ResNet",
            "head": "ClassificationHead",
            "loss": "CrossEntropyLoss",
            "metrics": ["F1Score", "Accuracy", "Recall"],
            "confusion_matrix_available": True,
            "main_metric": "F1Score",
            "visualizer": "ClassificationVisualizer",
        } | kwargs
        super().__init__(**kwargs)

    @staticmethod
    def get_variants() -> tuple[str, dict[str, Params]]:
        return "light", {
            "light": {
                "backbone": "ResNet",
                "backbone_params": {"variant": "18"},
            },
            "heavy": {
                "backbone": "ResNet",
                "backbone_params": {"variant": "50"},
            },
        }
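
To connect get_variants with the new from_variant class method, here is a rough, hypothetical sketch of how the base class could resolve a variant name into constructor kwargs; the actual SimplePredefinedModel implementation in this PR may differ.

from luxonis_ml.typing import Params


class SimplePredefinedModelSketch:
    """Illustrative only; not the actual SimplePredefinedModel."""

    @staticmethod
    def get_variants() -> tuple[str, dict[str, Params]]:
        """Overridden by subclasses; returns the default variant name
        and a mapping from variant names to constructor kwargs."""
        raise NotImplementedError

    @classmethod
    def from_variant(cls, variant: str | None = None, **overrides):
        default_variant, variants = cls.get_variants()
        variant = variant or default_variant
        if variant not in variants:
            raise ValueError(
                f"Variant should be one of {list(variants)}, got '{variant}'."
            )
        # Explicit kwargs from the config take precedence over the variant
        # defaults, which is what allows overriding individual parts.
        return cls(**(variants[variant] | overrides))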

Dependencies & Potential Impact

None / not applicable

Deployment Plan

None / not applicable

Testing & Validation

None / not applicable

@kozlov721 kozlov721 requested a review from a team as a code owner July 3, 2025 08:58
@kozlov721 kozlov721 requested review from klemen1999, tersekmatija and conorsim and removed request for a team July 3, 2025 08:58
@github-actions bot added the enhancement label Jul 3, 2025
@klemen1999 (Collaborator) left a comment

LGTM
