diff --git a/src/yosys_mau/task_loop/context.py b/src/yosys_mau/task_loop/context.py index 711a053..eb40288 100644 --- a/src/yosys_mau/task_loop/context.py +++ b/src/yosys_mau/task_loop/context.py @@ -205,6 +205,7 @@ def __getitem__(self, key: K) -> V: else: if isinstance(value, _MISSING_TYPE): raise KeyError(repr(key)) + return value return self.__default[key] def __setitem__(self, key: K, value: V) -> None: diff --git a/tests/task_loop/test_context_vars.py b/tests/task_loop/test_context_vars.py index 4d806aa..a71a472 100644 --- a/tests/task_loop/test_context_vars.py +++ b/tests/task_loop/test_context_vars.py @@ -2,6 +2,7 @@ import pytest import yosys_mau.task_loop as tl +from yosys_mau.task_loop.context import TaskContextDict def test_local_override_stays_local(): @@ -205,4 +206,96 @@ def on_task3(): assert order == [1, 3, 3] +def test_TaskContextDict_with_default(): + @tl.task_context + class SomeContext: + some_var: TaskContextDict[str, str] = TaskContextDict() + + def main(): + # iterate default values + for _, _ in SomeContext.some_var.items(): + pass + + # iterate non-default values + SomeContext.some_var["a"] = "b" + for _, _ in SomeContext.some_var.items(): + pass + + assert SomeContext.some_var["a"] == "b" + + tl.run_task_loop(main) + + +def test_override_TaskContextDict(): + order: list[dict[str, str]] = [] + + @tl.task_context + class SomeContext: + some_var: TaskContextDict[str, str] = TaskContextDict() + + def main(): + def on_task1(): + SomeContext.some_var["a"] = "b" + order.append(SomeContext.some_var.as_dict()) + + def on_task2(): + order.append(SomeContext.some_var.as_dict()) + + with tl.root_task().as_current_task(): + SomeContext.some_var["b"] = "d" + + def on_task3(): + order.append(SomeContext.some_var.as_dict()) + + task1 = tl.Task(on_run=on_task1) + task2 = tl.Task(on_run=on_task2) + task3 = tl.Task(on_run=on_task3) + + task2.depends_on(task1) + task3.depends_on(task2) + + SomeContext.some_var["b"] = "c" + + tl.run_task_loop(main) + + assert order == [ + {"a": "b", "b": "c"}, + {"b": "c"}, + {"b": "d"}, + ] + + +def test_child_TaskContextDict(): + order: list[dict[str, str]] = [] + + @tl.task_context + class SomeContext: + some_var: TaskContextDict[str, str] = TaskContextDict() + + def main(): + async def on_task1(): + SomeContext.some_var["a"] = "b" + order.append(SomeContext.some_var.as_dict()) + t2 = tl.Task(on_run=on_task2) + await t2.finished + order.append(SomeContext.some_var.as_dict()) + + def on_task2(): + SomeContext.some_var["b"] = "d" + del SomeContext.some_var["a"] + order.append(SomeContext.some_var.as_dict()) + + tl.Task(on_run=on_task1) + + SomeContext.some_var["b"] = "c" + + tl.run_task_loop(main) + + assert order == [ + {"a": "b", "b": "c"}, + {"b": "d"}, + {"a": "b", "b": "c"}, + ] + + # TODO tests