diff --git a/src/tlo/notify.py b/src/tlo/notify.py new file mode 100644 index 0000000000..b1b4434ba9 --- /dev/null +++ b/src/tlo/notify.py @@ -0,0 +1,72 @@ +""" +A dead simple synchronous notification dispatcher. + +Usage +----- +# In the notifying class/module +from tlo.notify import notifier + +notifier.dispatch("simulation.on_start", data={"one": 1, "two": 2}) + +# In the listening class/module +from tlo.notify import notifier + +def on_notification(data): + print("Received notification:", data) + +notifier.add_listener("simulation.on_start", on_notification) +""" + + +class Notifier: + """ + A simple synchronous notification dispatcher supporting listeners. + """ + + def __init__(self): + self.listeners = {} + + def add_listener(self, notification_key, listener): + """ + Register a listener for a specific notification. + + :param notification_key: The identifier to listen for. + :param listener: A callable to be invoked when the notification is dispatched. + """ + if notification_key not in self.listeners: + self.listeners[notification_key] = [] + self.listeners[notification_key].append(listener) + + def remove_listener(self, notification_key, listener): + """ + Remove a previously registered listener for a notification. + + :param notification_key: The identifier. + :param listener: The listener callable to remove. + """ + if notification_key in self.listeners: + self.listeners[notification_key].remove(listener) + if not self.listeners[notification_key]: + del self.listeners[notification_key] + + def dispatch(self, notification_key, data=None): + """ + Dispatch a notification to all registered listeners. + + :param notification_key: The identifier. + :param data: Optional data to pass to each listener. + """ + if notification_key in self.listeners: + for listener in self.listeners[notification_key]: + listener(data) + + def clear_listeners(self): + """ + Clear all registered listeners. Essential because the notifier is a global singleton. + e.g. if you are running multiple tests or simulations in the same process. + """ + self.listeners.clear() + + +# Create a global notifier instance +notifier = Notifier() diff --git a/src/tlo/simulation.py b/src/tlo/simulation.py index d2560f92d9..b0bd733234 100644 --- a/src/tlo/simulation.py +++ b/src/tlo/simulation.py @@ -26,6 +26,7 @@ topologically_sort_modules, ) from tlo.events import Event, IndividualScopeEventMixin +from tlo.notify import notifier from tlo.progressbar import ProgressBar if TYPE_CHECKING: @@ -116,6 +117,8 @@ def __init__( self._custom_log_levels = None self._log_filepath = self._configure_logging(**log_config) + # clear notifier listeners from any previous simulation in this process + notifier.clear_listeners() # random number generator seed_from = "auto" if seed is None else "user" diff --git a/tests/test_notify.py b/tests/test_notify.py new file mode 100644 index 0000000000..ad5e828bbf --- /dev/null +++ b/tests/test_notify.py @@ -0,0 +1,23 @@ +from tlo.notify import notifier + + +def test_notifier(): + # in listening code + received_data = [] + + def callback(data): + received_data.append(data) + + notifier.add_listener("test.signal", callback) + + # in emitting code + notifier.dispatch("test.signal", data={"value": 42}) + + assert len(received_data) == 1 + assert received_data[0] == {"value": 42} + + # Unsubscribe and test no further calls + notifier.remove_listener("test.signal", callback) + notifier.dispatch("test.signal", data={"value": 100}) + + assert len(received_data) == 1 # No new data