Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions luxonis_train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
__version__ = "0.4.0"
import pathlib
import sys
from typing import Final

from pydantic_extra_types.semantic_version import SemanticVersion

__version__: Final[str] = "0.4.0"
__semver__: Final[SemanticVersion] = SemanticVersion.parse(__version__)


# Do not run imports when first importing from within the CLI
# This is to make the CLI more responsive
Expand All @@ -9,6 +14,7 @@
or "--source" in sys.argv
or not sys.argv[0].endswith("/luxonis_train")
):
import pathlib
import warnings

try:
Expand Down
156 changes: 117 additions & 39 deletions luxonis_train/__main__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import importlib
import importlib.util
import subprocess
import json
import sys
from collections.abc import Iterator
from functools import lru_cache
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Literal

import requests
from cyclopts import App, Group, Parameter
import yaml
from cyclopts import App, Group, Parameter, validators
from loguru import logger
from luxonis_ml.typing import PathType

from luxonis_train.config import Config
from luxonis_train.upgrade import upgrade_config, upgrade_installation

if TYPE_CHECKING:
import numpy as np
Expand All @@ -24,6 +29,8 @@
app["--help"].group = app.meta.group_parameters
app["--version"].group = app.meta.group_parameters

upgrade_app = app.command(App(name="upgrade"))

training_group = Group.create_ordered("Training")
evaluation_group = Group.create_ordered("Evaluation")
export_group = Group.create_ordered("Export")
Expand All @@ -32,18 +39,46 @@


def create_model(
config: str | None,
opts: list[str] | None,
weights: str | None = None,
**kwargs,
config: PathType | None,
opts: list[str] | None = None,
weights: PathType | None = None,
debug_mode: bool = False,
load_dataset_metadata: bool = True,
) -> "LuxonisModel":
importlib.reload(sys.modules["luxonis_train"])
import torch

from luxonis_train import LuxonisModel
from luxonis_train.utils.dataset_metadata import DatasetMetadata

if weights is not None and config is None:
return LuxonisModel.from_checkpoint(weights, opts, **kwargs)
ckpt = torch.load(weights, map_location="cpu") # nosemgre

Check failure on line 55 in luxonis_train/__main__.py

View workflow job for this annotation

GitHub Actions / semgrep/ci

Semgrep Issue

Functions reliant on pickle can result in arbitrary code execution. Consider loading from `state_dict`, using fickling, or switching to a safer serialization method like ONNX
if "config" not in ckpt: # pragma: no cover
raise ValueError(
f"Checkpoint '{weights}' does not contain the 'config' key. "
"Cannot restore `LuxonisModel` from checkpoint."
)
cfg = Config.get_config(upgrade_config(ckpt["config"]), opts)
dataset_metadata = None
if load_dataset_metadata:
if "dataset_metadata" not in ckpt:
logger.error("Checkpoint does not contain dataset metadata.")
else:
try:
dataset_metadata = DatasetMetadata(
**ckpt["dataset_metadata"]
)
except Exception as e: # pragma: no cover
logger.error(
"Failed to load dataset metadata from the checkpoint. "
f"Error: {e}"
)

return LuxonisModel(
cfg, debug_mode=debug_mode, dataset_metadata=dataset_metadata
)

return LuxonisModel(config, opts, **kwargs)
return LuxonisModel(config, opts, debug_mode=debug_mode)


@app.command(group=training_group, sort_key=1)
Expand Down Expand Up @@ -366,37 +401,80 @@
)


@app.command(group=management_group)
def upgrade():
"""Update LuxonisTrain to the latest stable version."""

def get_latest_version() -> str | None:
url = "https://pypi.org/pypi/luxonis_train/json"
response = requests.get(url, timeout=5)
if response.status_code == 200:
data = response.json()
versions = list(data["releases"].keys())
versions.sort(key=lambda s: [int(u) for u in s.split(".")])
return versions[-1]
return None

current_version = version("luxonis_train")
latest_version = get_latest_version()
if latest_version is None:
print("Failed to check for updates. Try again later.")
return
if current_version == latest_version:
print(f"LuxonisTrain is up-to-date (v{current_version}).")
@upgrade_app.command()
def config(
config: Annotated[
Path,
Parameter(validator=validators.Path(exists=True)),
Parameter(validator=validators.Path(ext={"yaml", "yml", "json"})),
],
output: Annotated[
Path | None,
Parameter(validator=validators.Path(ext={"yaml", "yml", "json"})),
] = None,
):
"""Upgrade luxonis-train configuration file.

@type path: Path
@param old: Path to configuration file to be upgraded.
@type output: Path | None
@param new: Where to save the upgraded config. If left empty, the
old file will be overriden.
"""
if config.suffix == "json":
cfg = json.loads(config.read_text(encoding="utf-8"))
else:
subprocess.check_output(
f"{sys.executable} -m pip install -U pip".split()
)
subprocess.check_output(
f"{sys.executable} -m pip install -U luxonis_train".split()
)
print(
f"LuxonisTrain updated from v{current_version} to v{latest_version}."
)
cfg = yaml.safe_load(config.read_text(encoding="utf-8"))

new_cfg = upgrade_config(cfg)

output = output or config
if output.suffix == "json":
output.write_text(json.dumps(new_cfg, indent=2))
else:
with open(output, "w") as f:
yaml.safe_dump(
new_cfg, f, sort_keys=False, default_flow_style=False
)


@upgrade_app.command(name=["checkpoint", "ckpt"])
def checkpoint(
path: Annotated[
Path,
Parameter(validator=validators.Path(exists=True)),
],
output: Path | None = None,
):
"""Upgrade luxonis-train checkpoint file.

@type path: Path
@param old: Path to the checkpoint
@type output: Path | None
@param new: Where to save the upgraded checkpoint. If left empty,
the old file will be overriden.
"""
model = create_model(config=None, weights=path)
model.lightning_module.load_checkpoint(path)

# Needs to be called in order to attach the model to the trainer
model.pl_trainer.validate(
model.lightning_module,
model.pytorch_loaders["val"],
verbose=False,
)
model.pl_trainer.save_checkpoint(output or path, weights_only=False)
logger.info(f"Saved upgraded checkpoint to '{output}'")


@upgrade_app.default()
def upgrade():
"""Upgrade luxonis-train installation and user files.

Usage without a subcommand will trigger an upgrade of `luxonis-
train` PyPI package.
"""
upgrade_installation()


@app.meta.default
Expand Down
2 changes: 0 additions & 2 deletions luxonis_train/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .config import (
CONFIG_VERSION,
AttachedModuleConfig,
Config,
ExportConfig,
Expand All @@ -14,7 +13,6 @@
from .predefined_models.base_predefined_model import BasePredefinedModel

__all__ = [
"CONFIG_VERSION",
"AttachedModuleConfig",
"BasePredefinedModel",
"Config",
Expand Down
13 changes: 5 additions & 8 deletions luxonis_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Mapping
from contextlib import suppress
from pathlib import Path
from typing import Annotated, Any, Final, Literal, NamedTuple
from typing import Annotated, Any, Literal, NamedTuple

from loguru import logger
from luxonis_ml.enums import DatasetType
Expand Down Expand Up @@ -39,12 +39,9 @@
from pydantic_extra_types.semantic_version import SemanticVersion
from typing_extensions import Self, override

import luxonis_train as lxt
from luxonis_train.registry import MODELS, NODES, from_registry

CONFIG_VERSION: Final[SemanticVersion] = SemanticVersion.parse(
"2.0", optional_minor_and_patch=True
)


class ImageSize(NamedTuple):
height: int
Expand Down Expand Up @@ -633,9 +630,9 @@ class Config(LuxonisConfig):
archiver: ArchiveConfig = Field(default_factory=ArchiveConfig)
tuner: TunerConfig = Field(default_factory=TunerConfig)

config_version: Annotated[SemanticVersion, PlainSerializer(str)] = (
CONFIG_VERSION
)
version: Annotated[
SemanticVersion, Field(frozen=True), PlainSerializer(str)
] = lxt.__semver__

ENVIRON: Environ = Field(exclude=True, default_factory=Environ)

Expand Down
65 changes: 6 additions & 59 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,64 +285,6 @@ def __init__(

self._exported_models: dict[str, Path] = {}

@classmethod
def from_checkpoint(
cls,
path: PathType,
opts: Params | list[str] | tuple[str, ...] | None = None,
*,
debug_mode: bool = False,
load_dataset_metadata: bool = True,
) -> "LuxonisModel":
"""Creates a LuxonisModel instance from a checkpoint file.

@type path: PathType
@param path: Path to the checkpoint file.
@type opts: Params | list[str] | tuple[str, ...] | None
@param opts: Argument dict provided through command line, used to config overriding.
@type debug_mode: bool
@param debug_mode: If set to True, enables debug mode which ignores some
normaly unrecovarable exceptions and allows to test the model
without it being fully functional.
@type load_dataset_metadata: bool
@param load_dataset_metadata: If set to True, attempts to load dataset
metadata from the checkpoint. If the metadata is not found in the
checkpoint, it will be loaded from the training data.
"""
ckpt = torch.load(path, map_location="cpu")
if "config" not in ckpt: # pragma: no cover
raise ValueError(
f"Checkpoint '{path}' does not contain the 'config' key. "
"Cannot restore `LuxonisModel` from checkpoint."
)
try:
cfg = Config.get_config(ckpt["config"], opts)
except Exception as e: # pragma: no cover
raise ValueError(
"Failed to load config from the checkpoint. "
"This can happen if the config schema changed "
"between the version used to create the checkpoint "
"and the current version of luxonis-train."
) from e
dataset_metadata = None
if load_dataset_metadata:
if "dataset_metadata" not in ckpt:
logger.error("Checkpoint does not contain dataset metadata.")
else:
try:
dataset_metadata = DatasetMetadata(
**ckpt["dataset_metadata"]
)
except Exception as e: # pragma: no cover
logger.error(
"Failed to load dataset metadata from the checkpoint. "
f"Error: {e}"
)

return cls(
cfg, debug_mode=debug_mode, dataset_metadata=dataset_metadata
)

def _train(self, resume: PathType | None, *args, **kwargs) -> None:
status = "success"
try:
Expand Down Expand Up @@ -594,7 +536,12 @@ def export(
}

with open(export_path.with_suffix(".yaml"), "w") as f:
yaml.safe_dump(modelconverter_config, f)
yaml.safe_dump(
modelconverter_config,
f,
sort_keys=False,
default_flow_style=False,
)
if self.cfg.exporter.upload_to_run:
self.tracker.upload_artifact(f.name, name=f.name, typ="export")
if self.cfg.exporter.upload_url is not None: # pragma: no cover
Expand Down
25 changes: 13 additions & 12 deletions luxonis_train/lightning/luxonis_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,29 +543,30 @@
self.cfg, self.parameters(), self.nodes.main_metric, self.nodes
)

def load_checkpoint(self, path: PathType | None) -> None:
def load_checkpoint(self, ckpt: PathType | dict[str, Any] | None) -> None:
"""Loads checkpoint weights from provided path.

Loads the checkpoints gracefully, ignoring keys that are not
found in the model state dict or in the checkpoint.

@type path: PathType | None
@param path: Path to the checkpoint. If C{None}, no checkpoint
will be loaded.
@type ckpt: PathType | dict | None
@param path: Either a path to or a loaded checkpoint. If
C{None}, no checkpoint will be loaded.
"""
if path is None:
if ckpt is None:
return

path = str(path)
if isinstance(ckpt, str | Path):
ckpt = cast(
dict[str, Any], torch.load(ckpt, map_location=self.device)

Check failure on line 561 in luxonis_train/lightning/luxonis_lightning.py

View workflow job for this annotation

GitHub Actions / semgrep/ci

Semgrep Issue

Functions reliant on pickle can result in arbitrary code execution. Consider loading from `state_dict`, using fickling, or switching to a safer serialization method like ONNX
) # nosemgrep

checkpoint = torch.load(path, map_location=self.device) # nosemgrep

if "state_dict" not in checkpoint:
if "state_dict" not in ckpt:
raise ValueError("Checkpoint does not contain state_dict.")

state_dict = checkpoint["state_dict"]
order_mapping = self._load_execution_order_mapping(checkpoint)
ver = version.parse(checkpoint.get("version", "0.3.0"))
state_dict = ckpt["state_dict"]
order_mapping = self._load_execution_order_mapping(ckpt)
ver = version.parse(ckpt.get("version", "0.3.0"))

for node_name, node in self.nodes.items():
sub_state_dict = {
Expand Down
Loading
Loading