From f62bbfc0b3a724018a4f5c5b92dca0855ab52678 Mon Sep 17 00:00:00 2001 From: "Johannes.Hennecke" Date: Thu, 16 Nov 2023 14:55:29 +0100 Subject: [PATCH] repeat subscribe when reconnected to MQTT broker; added comments --- mqtt_io/events.py | 12 ++ mqtt_io/mqtt/asyncio_mqtt.py | 86 ++++++++++- mqtt_io/server.py | 266 +++++++++++++++++++++++++++++------ 3 files changed, 318 insertions(+), 46 deletions(-) diff --git a/mqtt_io/events.py b/mqtt_io/events.py index e58e4729..9889d34c 100644 --- a/mqtt_io/events.py +++ b/mqtt_io/events.py @@ -74,6 +74,18 @@ class StreamDataSentEvent(Event): stream_name: str data: bytes +@dataclass +class StreamDataSubscribeEvent(Event): + """ + Trigger MQTT subscribe + """ + +@dataclass +class DigitalSubscribeEvent(Event): + """ + Trigger MQTT subscribe + """ + class EventBus: """ diff --git a/mqtt_io/mqtt/asyncio_mqtt.py b/mqtt_io/mqtt/asyncio_mqtt.py index 47f406e6..e1c3fff5 100644 --- a/mqtt_io/mqtt/asyncio_mqtt.py +++ b/mqtt_io/mqtt/asyncio_mqtt.py @@ -26,8 +26,25 @@ def _map_exception(func: Func) -> Func: + """ + Creates a decorator that wraps a function and maps any raised `MqttError` exception to a `MQTTException`. + + :param func: The function to be wrapped. + :type func: Func + :return: The wrapped function. + :rtype: Func + """ @wraps(func) async def inner(*args: Any, **kwargs: Any) -> Any: + """ + Decorator for asynchronous functions that catches `MqttError` exceptions and raises `MQTTException` instead. + + Parameters: + func (Callable): The function to be decorated. + + Returns: + Callable: The decorated function. + """ try: await func(*args, **kwargs) except MqttError as exc: @@ -42,6 +59,15 @@ class MQTTClient(AbstractMQTTClient): """ def __init__(self, options: MQTTClientOptions): + """ + Initializes a new instance of the MQTTClient class. + + Args: + options (MQTTClientOptions): The options for the MQTT client. + + Returns: + None + """ super().__init__(options) protocol_map = { MQTTProtocol.V31: paho.MQTTv31, @@ -66,7 +92,7 @@ def __init__(self, options: MQTTClientOptions): username=options.username, password=options.password, client_id=options.client_id, - # keepalive=options.keepalive, # This isn't implemented yet on 0.8.1 + keepalive=options.keepalive, tls_context=tls_context, protocol=protocol_map[options.protocol], will=will, @@ -76,10 +102,28 @@ def __init__(self, options: MQTTClientOptions): @_map_exception async def connect(self, timeout: int = 10) -> None: + """ + Connects to the client asynchronously. + + Args: + timeout (int): The timeout value in seconds (default: 10). + + Returns: + None: This function does not return anything. + """ await self._client.connect(timeout=timeout) @_map_exception async def disconnect(self) -> None: + """ + This function is an asynchronous method that handles the disconnection of the client. + + Parameters: + self: The current instance of the class. + + Returns: + None + """ try: await self._client.disconnect() except TimeoutError: @@ -87,10 +131,33 @@ async def disconnect(self) -> None: @_map_exception async def subscribe(self, topics: List[Tuple[str, int]]) -> None: + """ + Subscribe to the given list of topics. + + Args: + topics (List[Tuple[str, int]]): A list of tuples representing the topics to subscribe to. + Each tuple should contain a string representing the topic name and an integer representing the QoS level. + + Returns: + None: This function does not return anything. + + Raises: + Exception: If there is an error while subscribing to the topics. + + """ await self._client.subscribe(topics) @_map_exception async def publish(self, msg: MQTTMessageSend) -> None: + """ + Publishes an MQTT message to the specified topic. + + Args: + msg (MQTTMessageSend): The MQTT message to be published. + + Returns: + None: This function does not return anything. + """ await self._client.publish( topic=msg.topic, payload=msg.payload, qos=msg.qos, retain=msg.retain ) @@ -98,6 +165,17 @@ async def publish(self, msg: MQTTMessageSend) -> None: def _on_message( self, client: paho.Client, userdata: Any, msg: paho.MQTTMessage ) -> None: + """ + Callback function that is called when a message is received through MQTT. + + Args: + client (paho.Client): The MQTT client instance. + userdata (Any): The user data associated with the client. + msg (paho.MQTTMessage): The received MQTT message. + + Returns: + None: This function does not return anything. + """ if self._message_queue is None: _LOG.warning("Discarding MQTT message because queue is not initialised") return @@ -111,6 +189,12 @@ def _on_message( @property def message_queue(self) -> "asyncio.Queue[MQTTMessage]": + """ + Returns the message queue for receiving MQTT messages. + + :return: The message queue for receiving MQTT messages. + :rtype: asyncio.Queue[MQTTMessage] + """ if self._message_queue is None: self._message_queue = asyncio.Queue(self._options.message_queue_size) # pylint: disable=protected-access diff --git a/mqtt_io/server.py b/mqtt_io/server.py index f026cefc..36402083 100644 --- a/mqtt_io/server.py +++ b/mqtt_io/server.py @@ -51,6 +51,8 @@ SensorReadEvent, StreamDataReadEvent, StreamDataSentEvent, + StreamDataSubscribeEvent, + DigitalSubscribeEvent, ) from .home_assistant import ( hass_announce_digital_input, @@ -106,7 +108,7 @@ def _init_module( module_config: Dict[str, Dict[str, Any]], module_type: str, install_requirements: bool ) -> Union[GenericGPIO, GenericSensor, GenericStream]: """ - Initialise a GPIO module by: + Initialise a module by: - Importing it - Validating its config - Installing any missing requirements for it @@ -147,6 +149,16 @@ class MqttIo: # pylint: disable=too-many-instance-attributes def __init__( self, config: Dict[str, Any], loop: Optional[asyncio.AbstractEventLoop] = None ) -> None: + """ + Initializes the class with the given configuration and event loop. + + Parameters: + config (Dict[str, Any]): The configuration for the class. + loop (Optional[asyncio.AbstractEventLoop]): The event loop to use. If not provided, the default event loop will be used. + + Returns: + None + """ self.config = config self._init_mqtt_config() @@ -197,6 +209,21 @@ async def create_loop_resources() -> None: self.loop.run_until_complete(create_loop_resources()) def _init_mqtt_config(self) -> None: + """ + Initializes the MQTT configuration. + + This function retrieves the MQTT configuration from the application's + main configuration file and performs the necessary setup tasks. + It sets the topic prefix for MQTT messages, replaces the '' placeholder + in the topic prefix with the IP address of the 'wlan0' interface, + generates a client ID if not provided, and configures TLS options if enabled. + + Parameters: + None + + Returns: + None + """ config: ConfigType = self.config["mqtt"] topic_prefix: str = config["topic_prefix"] @@ -289,7 +316,6 @@ async def publish_stream_data_callback(event: StreamDataReadEvent) -> None: self.stream_configs = {x["name"]: x for x in self.config["stream_modules"]} self.stream_modules = {} - sub_topics: List[str] = [] for stream_conf in self.config["stream_modules"]: stream_module = _init_module( stream_conf, "stream", self.config["options"]["install_requirements"] @@ -324,22 +350,27 @@ async def create_stream_output_queue( ) ) - sub_topics.append( - "/".join( - ( - self.config["mqtt"]["topic_prefix"], - STREAM_TOPIC, - stream_conf["name"], - SEND_SUFFIX, + # Subscribe call back funktion: Subscribe to stream send topics + async def subscribe_callback(event: StreamDataSubscribeEvent) -> None: + sub_topics: List[str] = [] + for stream_conf in self.config["stream_modules"]: + sub_topics.append( + "/".join( + ( + self.config["mqtt"]["topic_prefix"], + STREAM_TOPIC, + stream_conf["name"], + SEND_SUFFIX, + ) ) ) - ) + + if sub_topics: + self.mqtt_task_queue.put_nowait( + PriorityCoro(self._mqtt_subscribe(sub_topics), MQTT_SUB_PRIORITY) + ) - # Subscribe to stream send topics - if sub_topics: - self.mqtt_task_queue.put_nowait( - PriorityCoro(self._mqtt_subscribe(sub_topics), MQTT_SUB_PRIORITY) - ) + self.event_bus.subscribe(StreamDataSubscribeEvent, subscribe_callback) def _init_digital_inputs(self) -> None: """ @@ -420,8 +451,33 @@ async def publish_callback(event: DigitalInputChangedEvent) -> None: ) def _init_digital_outputs(self) -> None: + """ + Initializes the digital outputs. + + This function sets up the MQTT publish callback for the output event. + It creates a digital output queue on the right loop for each module. + It subscribes to outputs when MQTT is initialized. + It also fires DigitalOutputChangedEvents for the initial values of + outputs if required, and reads and publishes the actual pin state + if no publish_initial is requested. + + Parameters: + None + + Returns: + None + """ # Set up MQTT publish callback for output event async def publish_callback(event: DigitalOutputChangedEvent) -> None: + """ + Publishes a callback function for the given DigitalOutputChangedEvent. + + Args: + event (DigitalOutputChangedEvent): The event object containing the details of the digital output change. + + Returns: + None + """ out_conf = self.digital_output_configs[event.output_name] val = out_conf["on_payload"] if event.to_value else out_conf["off_payload"] self.mqtt_task_queue.put_nowait( @@ -477,23 +533,6 @@ async def create_digital_output_queue( ) ) - # Add tasks to subscribe to outputs when MQTT is initialised - topics = [] - for suffix in (SET_SUFFIX, SET_ON_MS_SUFFIX, SET_OFF_MS_SUFFIX): - topics.append( - "/".join( - ( - self.config["mqtt"]["topic_prefix"], - OUTPUT_TOPIC, - out_conf["name"], - suffix, - ) - ) - ) - self.mqtt_task_queue.put_nowait( - PriorityCoro(self._mqtt_subscribe(topics), MQTT_SUB_PRIORITY) - ) - # Fire DigitalOutputChangedEvents for initial values of outputs if required if out_conf["publish_initial"]: self.event_bus.fire( @@ -515,8 +554,47 @@ async def create_digital_output_queue( ) self.event_bus.fire(DigitalOutputChangedEvent(out_conf["name"], value)) + # Subscribe call back funktion: Add tasks to subscribe to outputs when MQTT is initialised + async def subscribe_callback(event: DigitalSubscribeEvent) -> None: + for out_conf in self.config["digital_outputs"]: + topics = [] + for suffix in (SET_SUFFIX, SET_ON_MS_SUFFIX, SET_OFF_MS_SUFFIX): + topics.append( + "/".join( + ( + self.config["mqtt"]["topic_prefix"], + OUTPUT_TOPIC, + out_conf["name"], + suffix, + ) + ) + ) + self.mqtt_task_queue.put_nowait( + PriorityCoro(self._mqtt_subscribe(topics), MQTT_SUB_PRIORITY) + ) + + self.event_bus.subscribe(DigitalSubscribeEvent, subscribe_callback) + def _init_sensor_inputs(self) -> None: + """ + Initializes the sensor inputs for the class. + + Parameters: + None + + Returns: + None + """ async def publish_sensor_callback(event: SensorReadEvent) -> None: + """ + Publishes a sensor callback event to the MQTT broker. + + Args: + event (SensorReadEvent): The event object containing the sensor data. + + Returns: + None + """ sens_conf = self.sensor_input_configs[event.sensor_name] digits: int = sens_conf["digits"] self.mqtt_task_queue.put_nowait( @@ -554,6 +632,17 @@ async def poll_sensor( sensor_module: GenericSensor = sensor_module, sens_conf: ConfigType = sens_conf, ) -> None: + """ + Asynchronously polls a sensor to retrieve its value at regular intervals. + + Args: + sensor_module (Optional[GenericSensor]): The sensor module to use. Defaults to the sensor_module provided during function call. + sens_conf (Optional[ConfigType]): The configuration for the sensor. Defaults to the sens_conf provided during function call. + + Returns: + None + + """ @backoff.on_exception( # type: ignore backoff.expo, Exception, max_time=sens_conf["interval"] ) @@ -564,6 +653,16 @@ async def get_sensor_value( sensor_module: GenericSensor = sensor_module, sens_conf: ConfigType = sens_conf, ) -> SensorValueType: + """ + A decorator that applies exponential backoff to the function `get_sensor_value`. + + Parameters: + sensor_module (GenericSensor): The sensor module to use for getting the sensor value. + sens_conf (ConfigType): The configuration for the sensor. + + Returns: + SensorValueType: The value retrieved from the sensor. + """ return await sensor_module.async_get_value(sens_conf) while True: @@ -586,6 +685,12 @@ async def get_sensor_value( self.transient_tasks.append(self.loop.create_task(poll_sensor())) async def _connect_mqtt(self) -> None: + """ + Connects to the MQTT broker and sets up the necessary configurations. + + Returns: + None + """ config: ConfigType = self.config["mqtt"] topic_prefix: str = config["topic_prefix"] self.mqtt = AbstractMQTTClient.get_implementation(config["client_module"])( @@ -610,6 +715,9 @@ async def _connect_mqtt(self) -> None: ) ) self.mqtt_connected.set() + self.event_bus.fire(StreamDataSubscribeEvent()) + self.event_bus.fire(DigitalSubscribeEvent()) + def _ha_discovery_announce(self) -> None: """ @@ -656,6 +764,19 @@ async def _mqtt_subscribe(self, topics: List[str]) -> None: _LOG.info("Subscribed to topic: %r", topic) async def _mqtt_publish(self, msg: MQTTMessageSend, wait: bool = True) -> None: + """ + Publishes an MQTT message. + + Args: + msg (MQTTMessageSend): The MQTT message to publish. + wait (bool, optional): Whether to wait for MQTT connection before publishing. Defaults to True. + + Raises: + RuntimeError: If the MQTT client is None. + + Returns: + None + """ if not self.mqtt_connected.is_set(): if wait: _LOG.debug("_mqtt_publish awaiting MQTT connection") @@ -666,18 +787,22 @@ async def _mqtt_publish(self, msg: MQTTMessageSend, wait: bool = True) -> None: if msg.payload is None: _LOG.debug("Publishing MQTT message on topic %r with no payload", msg.topic) + elif type(msg.payload) == bytes or type(msg.payload) == bytearray: + try: + payload = msg.payload.decode("utf8") + except UnicodeDecodeError: + _LOG.debug( + "Publishing MQTT message on topic %r with non-unicode payload", + msg.topic, + ) + else: + _LOG.debug( + "Publishing MQTT message on topic %r: %r", msg.topic, payload + ) else: - try: - payload_str = msg.payload.decode("utf8") - except UnicodeDecodeError: - _LOG.debug( - "Publishing MQTT message on topic %r with non-unicode payload", - msg.topic, - ) - else: - _LOG.debug( - "Publishing MQTT message on topic %r: %r", msg.topic, payload_str - ) + _LOG.debug( + f"Publishing MQTT message on topic {msg.topic}: {msg.payload}" + ) await self.mqtt.publish(msg) @@ -1077,6 +1202,27 @@ async def _mqtt_keep_alive_loop(self) -> None: await asyncio.sleep(config["keepalive"]) async def _mqtt_rx_loop(self) -> None: + """ + Asynchronous function that runs a loop to receive MQTT messages. + + The function first checks if the MQTT connection is established. If not, it awaits + the connection before proceeding. Once the connection is established, the function + enters an infinite loop. + + Within the loop, the function checks if the MQTT client is initialized. If not, it + logs an error message and waits for the client to be initialized before proceeding. + If the client is initialized, the function retrieves a message from the MQTT + message queue. + + If the message payload is `None`, the function logs a warning message and continues + to the next message. Otherwise, it decodes the payload as a UTF-8 string and logs + the received message and topic. + + Finally, the function calls the `_handle_mqtt_msg` method to handle the received + message. + + This function does not take any parameters and does not return any value. + """ if not self.mqtt_connected.is_set(): _LOG.debug("_mqtt_rx_loop awaiting MQTT connection") await self.mqtt_connected.wait() @@ -1102,6 +1248,21 @@ async def _mqtt_rx_loop(self) -> None: await self._handle_mqtt_msg(msg.topic, msg.payload) async def _remove_finished_transient_tasks(self) -> None: + """ + Remove any finished transient tasks from the list of transient tasks. + + This function runs in an infinite loop, sleeping for 1 second in each iteration. + It checks for any finished tasks in the list of transient tasks and removes them. + Once the finished tasks are removed, it gathers the results of those tasks and checks + for any exceptions. If an exception is found, it raises the exception and logs an error + message with the task information. + + Parameters: + None + + Returns: + None + """ while True: await asyncio.sleep(1) finished_tasks = [x for x in self.transient_tasks if x.done()] @@ -1191,6 +1352,19 @@ async def stream_output_loop( self.event_bus.fire(StreamDataSentEvent(stream_conf["name"], data)) async def _main_loop(self) -> None: + """ + Asynchronous main loop function. + + This function is responsible for running the main event loop of the MQTT client. + It handles reconnecting to the MQTT broker if the connection is lost and + manages the execution of critical tasks. + + Parameters: + None + + Returns: + None + """ reconnect = True reconnect_delay = self.config["mqtt"]["reconnect_delay"] reconnects_remaining = None @@ -1219,14 +1393,16 @@ async def _main_loop(self) -> None: except asyncio.CancelledError: break except Exception: # pylint: disable=broad-except - _LOG.exception("Exception in critical task:") + #_LOG.exception("Exception in critical task:") + _LOG.error("Exception in critical task") except asyncio.CancelledError: break except MQTTException: if reconnects_remaining is not None: reconnect = reconnects_remaining > 0 reconnects_remaining -= 1 - _LOG.exception("Connection to MQTT broker failed") + #_LOG.exception("Connection to MQTT broker failed") + _LOG.error("Connection to MQTT broker failed") finally: _LOG.debug("Clearing events and cancelling 'critical_tasks'")