Skip to content

Commit 37aa2aa

Browse files
authored
Fix: Pathlib representer for yaml.safe_dump (#285)
1 parent f308a4c commit 37aa2aa

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

luxonis_ml/utils/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
from pathlib import PurePath
23
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
34

45
import yaml
@@ -78,6 +79,14 @@ def save_data(self, path: PathType) -> None:
7879
@type path: str
7980
@param path: Path to output yaml file.
8081
"""
82+
83+
def path_representer(
84+
dumper: yaml.SafeDumper, data: PurePath
85+
) -> yaml.ScalarNode:
86+
return dumper.represent_scalar("tag:yaml.org,2002:str", str(data))
87+
88+
yaml.SafeDumper.add_multi_representer(PurePath, path_representer)
89+
8190
with open(path, "w+") as f:
8291
yaml.safe_dump(self.model_dump(), f, default_flow_style=False)
8392

tests/test_utils/test_config.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tempfile
22
from copy import deepcopy
3-
from typing import Dict, List, Optional
3+
from pathlib import Path
4+
from typing import Dict, List, Optional, Tuple
45

56
import pytest
67
import yaml
@@ -25,7 +26,7 @@
2526

2627

2728
@pytest.fixture
28-
def config_file():
29+
def config_file() -> str:
2930
with tempfile.NamedTemporaryFile(delete=False) as f:
3031
f.write(yaml.dump(CONFIG_DATA).encode())
3132
return f.name
@@ -55,6 +56,11 @@ class ListConfig(BaseModel):
5556
str_list_param: Optional[str] = None
5657

5758

59+
class UnsafeConfig(BaseModel):
60+
tuple_param: Tuple[int, int] = (1, 2)
61+
path_param: Path = Path.cwd()
62+
63+
5864
class Config(LuxonisConfig):
5965
sub_config: SubConfig
6066
sub_config_default: SubConfigDefault = SubConfigDefault()
@@ -63,6 +69,7 @@ class Config(LuxonisConfig):
6369
list_config: List[ListConfig] = []
6470
nested_list_param: List[List[int]] = []
6571
nested_dict_param: Dict[str, Dict[str, int]] = {}
72+
unsafe_config: UnsafeConfig = UnsafeConfig()
6673

6774

6875
def test_invalid_config_path():
@@ -267,3 +274,12 @@ def test_get(config_file: str):
267274

268275
def test_environ():
269276
assert environ.model_dump() == {}
277+
278+
279+
def test_safe_load(config_file: str):
280+
cfg = Config.get_config(config_file)
281+
with tempfile.NamedTemporaryFile(delete=False) as f:
282+
cfg.save_data(f.name)
283+
284+
cfg2 = Config.get_config(f.name)
285+
assert cfg == cfg2

0 commit comments

Comments
 (0)