Skip to content

Commit 8ea28dd

Browse files
authored
Config list append in overrides (#359)
1 parent 9dc6c97 commit 8ea28dd

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

luxonis_ml/utils/config.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -170,37 +170,44 @@ def _merge_recursive(
170170
else:
171171
data[index] = parsed_value
172172
elif isinstance(data, list):
173-
raise ValueError(
174-
"Only int keys are allowed for list values"
175-
)
173+
if key == "+":
174+
data.append(parsed_value)
175+
else:
176+
raise ValueError(
177+
"Only int keys are allowed for list values"
178+
)
176179
else:
177180
data[key] = parsed_value
178181

179182
return
180183

181184
key_tail = ".".join(tail)
182185

183-
if key.isdecimal():
184-
index = int(key)
186+
if key.isdecimal() or key == "+":
185187
if not isinstance(data, list):
186188
raise ValueError(
187189
"int keys are not allowed for non-list values"
188190
)
189-
if index >= len(data):
190-
index = len(data)
191-
if data:
192-
data.append(type(data[0])())
193-
_merge_recursive(data[index], key_tail, value)
194-
else:
195-
# Try to guess type, backtrack if fails
196-
data.append([])
197-
try:
198-
_merge_recursive(data[index], key_tail, value)
199-
except Exception:
200-
data[index] = {}
201-
_merge_recursive(data[index], key_tail, value)
191+
if key == "+":
192+
data.append(type(data[0])())
193+
_merge_recursive(data[-1], key_tail, value)
202194
else:
203-
_merge_recursive(data[index], key_tail, value)
195+
index = int(key)
196+
if index >= len(data):
197+
index = len(data)
198+
if data:
199+
data.append(type(data[0])())
200+
_merge_recursive(data[index], key_tail, value)
201+
else:
202+
# Try to guess type, backtrack if fails
203+
data.append([])
204+
try:
205+
_merge_recursive(data[index], key_tail, value)
206+
except Exception:
207+
data[index] = {}
208+
_merge_recursive(data[index], key_tail, value)
209+
else:
210+
_merge_recursive(data[index], key_tail, value)
204211
else:
205212
if not isinstance(data, dict):
206213
raise ValueError(

tests/test_utils/test_config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_config_simple_override(config_file: str):
114114
)
115115

116116

117-
def test_config_list_override(config_file: str):
117+
def test_config_override_as_list(config_file: str):
118118
overrides = ["sub_config.str_sub_param", "sub_param_override"]
119119
cfg = Config.get_config(config_file, overrides)
120120
assert cfg.sub_config.str_sub_param == overrides[1]
@@ -153,20 +153,24 @@ def test_config_override_list(config_file: str):
153153
"list_config.0.float_list_param": 2.5,
154154
"list_config.0.str_list_param": "test",
155155
"list_config.1.int_list_param": 20,
156+
"list_config.+.int_list_param": 30,
156157
"nested_list_param.0": [30],
157158
"nested_list_param.0.1": 40,
159+
"nested_list_param.0.+": 50,
158160
}
159161
cfg = Config.get_config(config_file, overrides)
160162
# Testing list configurations
161-
assert len(cfg.list_config) == 2
163+
assert len(cfg.list_config) == 3
162164
assert cfg.list_config[0].int_list_param == 10
163165
assert cfg.list_config[0].float_list_param == 2.5
164166
assert cfg.list_config[0].str_list_param == "test"
165167
assert cfg.list_config[1].int_list_param == 20
166168
assert cfg.list_config[1].float_list_param == 1.0
167169
assert cfg.list_config[1].str_list_param is None
170+
assert cfg.list_config[2].int_list_param == 30
168171
assert cfg.nested_list_param[0][0] == 30
169172
assert cfg.nested_list_param[0][1] == 40
173+
assert cfg.nested_list_param[0][2] == 50
170174

171175

172176
def test_config_list_override_json(config_file: str):

0 commit comments

Comments
 (0)