diff --git a/custom_components/xiaomi_home/miot/common.py b/custom_components/xiaomi_home/miot/common.py index 714f158..0ee4f1d 100644 --- a/custom_components/xiaomi_home/miot/common.py +++ b/custom_components/xiaomi_home/miot/common.py @@ -83,6 +83,9 @@ def randomize_int(value: int, ratio: float) -> int: """Randomize an integer value.""" return int(value * (1 - ratio + random.random()*2*ratio)) +def randomize_float(value: float, ratio: float) -> float: + """Randomize a float value.""" + return value * (1 - ratio + random.random()*2*ratio) class MIoTMatcher(MQTTMatcher): """MIoT Pub/Sub topic matcher.""" diff --git a/custom_components/xiaomi_home/miot/miot_client.py b/custom_components/xiaomi_home/miot/miot_client.py index 9c57c34..18a88f5 100644 --- a/custom_components/xiaomi_home/miot/miot_client.py +++ b/custom_components/xiaomi_home/miot/miot_client.py @@ -1089,7 +1089,7 @@ async def __on_miot_lan_state_change(self, state: bool) -> None: handler=self.__on_lan_device_state_changed) for did, info in ( await self._miot_lan.get_dev_list_async()).items(): - self.__on_lan_device_state_changed( + await self.__on_lan_device_state_changed( did=did, state=info, ctx=None) _LOGGER.info('lan device list, %s', self._device_list_lan) self._miot_lan.update_devices(devices={ diff --git a/custom_components/xiaomi_home/miot/miot_lan.py b/custom_components/xiaomi_home/miot/miot_lan.py index 6679328..600afc1 100644 --- a/custom_components/xiaomi_home/miot/miot_lan.py +++ b/custom_components/xiaomi_home/miot/miot_lan.py @@ -53,14 +53,12 @@ from dataclasses import dataclass from enum import Enum, auto import logging -import os -import queue import random import secrets import socket import struct import threading -from typing import Callable, Optional, final +from typing import Any, Callable, Coroutine, Optional, final from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives import padding from cryptography.hazmat.backends import default_backend @@ -68,100 +66,61 @@ # pylint: disable=relative-beyond-top-level from .miot_error import MIoTErrorCode -from .miot_ev import MIoTEventLoop, TimeoutHandle from .miot_network import InterfaceStatus, MIoTNetwork, NetworkInfo from .miot_mdns import MipsService, MipsServiceState from .common import ( - randomize_int, load_yaml_file, gen_absolute_path, MIoTMatcher) + randomize_float, load_yaml_file, gen_absolute_path, MIoTMatcher) _LOGGER = logging.getLogger(__name__) -class MIoTLanCmdType(Enum): - """MIoT lan command.""" - DEINIT = 0 - CALL_API = auto() - SUB_DEVICE_STATE = auto() - UNSUB_DEVICE_STATE = auto() - REG_BROADCAST = auto() - UNREG_BROADCAST = auto() - GET_DEV_LIST = auto() - DEVICE_UPDATE = auto() - DEVICE_DELETE = auto() - NET_INFO_UPDATE = auto() - NET_IFS_UPDATE = auto() - OPTIONS_UPDATE = auto() - - -@dataclass -class MIoTLanCmd: - """MIoT lan command.""" - type_: MIoTLanCmdType - data: any - - @dataclass -class MIoTLanCmdData: - handler: Callable[[dict, any], None] - handler_ctx: any +class _MIoTLanGetDevListData: + handler: Callable[[dict, Any], None] + handler_ctx: Any timeout_ms: int @dataclass -class MIoTLanGetDevListData(MIoTLanCmdData): - ... - - -@dataclass -class MIoTLanCallApiData(MIoTLanCmdData): - did: str - msg: dict - - -class MIoTLanSendBroadcastData(MIoTLanCallApiData): - ... - - -@dataclass -class MIoTLanUnregisterBroadcastData: +class _MIoTLanUnregisterBroadcastData: key: str @dataclass -class MIoTLanRegisterBroadcastData: +class _MIoTLanRegisterBroadcastData: key: str - handler: Callable[[dict, any], None] - handler_ctx: any + handler: Callable[[dict, Any], None] + handler_ctx: Any @dataclass -class MIoTLanUnsubDeviceState: +class _MIoTLanUnsubDeviceData: key: str @dataclass -class MIoTLanSubDeviceState: +class _MIoTLanSubDeviceData: key: str - handler: Callable[[str, dict, any], None] - handler_ctx: any + handler: Callable[[str, dict, Any], Coroutine] + handler_ctx: Any @dataclass -class MIoTLanNetworkUpdateData: +class _MIoTLanNetworkUpdateData: status: InterfaceStatus if_name: str @dataclass -class MIoTLanRequestData: +class _MIoTLanRequestData: msg_id: int - handler: Callable[[dict, any], None] - handler_ctx: any - timeout: TimeoutHandle + handler: Optional[Callable[[dict, Any], None]] + handler_ctx: Any + timeout: Optional[asyncio.TimerHandle] -class MIoTLanDeviceState(Enum): +class _MIoTLanDeviceState(Enum): FRESH = 0 PING1 = auto() PING2 = auto() @@ -169,18 +128,18 @@ class MIoTLanDeviceState(Enum): DEAD = auto() -class MIoTLanDevice: +class _MIoTLanDevice: """MIoT lan device.""" # pylint: disable=unused-argument OT_HEADER: int = 0x2131 OT_HEADER_LEN: int = 32 NETWORK_UNSTABLE_CNT_TH: int = 10 - NETWORK_UNSTABLE_TIME_TH: int = 120000 - NETWORK_UNSTABLE_RESUME_TH: int = 300000 - FAST_PING_INTERVAL: int = 5000 - CONSTRUCT_STATE_PENDING: int = 15000 - KA_INTERVAL_MIN = 10000 - KA_INTERVAL_MAX = 50000 + NETWORK_UNSTABLE_TIME_TH: float = 120 + NETWORK_UNSTABLE_RESUME_TH: float = 300 + FAST_PING_INTERVAL: float = 5 + CONSTRUCT_STATE_PENDING: float = 15 + KA_INTERVAL_MIN: float = 10 + KA_INTERVAL_MAX: float = 50 did: str token: bytes @@ -192,19 +151,25 @@ class MIoTLanDevice: sub_ts: int supported_wildcard_sub: bool - _manager: any + _manager: 'MIoTLan' _if_name: Optional[str] _sub_locked: bool - _state: MIoTLanDeviceState + _state: _MIoTLanDeviceState _online: bool - _online_offline_history: list[dict[str, any]] - _online_offline_timer: Optional[TimeoutHandle] + _online_offline_history: list[dict[str, Any]] + _online_offline_timer: Optional[asyncio.TimerHandle] - _ka_timer: TimeoutHandle - _ka_internal: int + _ka_timer: Optional[asyncio.TimerHandle] + _ka_internal: float + +# All functions SHOULD be called from the internal loop def __init__( - self, manager: any, did: str, token: str, ip: Optional[str] = None + self, + manager: 'MIoTLan', + did: str, + token: str, + ip: Optional[str] = None ) -> None: self._manager: MIoTLan = manager self.did = did @@ -220,17 +185,17 @@ def __init__( self.supported_wildcard_sub = False self._if_name = None self._sub_locked = False - self._state = MIoTLanDeviceState.DEAD + self._state = _MIoTLanDeviceState.DEAD self._online = False self._online_offline_history = [] self._online_offline_timer = None - def ka_init_handler(ctx: any) -> None: + def ka_init_handler() -> None: self._ka_internal = self.KA_INTERVAL_MIN - self.__update_keep_alive(state=MIoTLanDeviceState.DEAD) - self._ka_timer = self._manager.mev.set_timeout( - randomize_int(self.CONSTRUCT_STATE_PENDING, 0.5), - ka_init_handler, None) + self.__update_keep_alive(state=_MIoTLanDeviceState.DEAD) + self._ka_timer = self._manager.internal_loop.call_later( + randomize_float(self.CONSTRUCT_STATE_PENDING, 0.5), + ka_init_handler,) _LOGGER.debug('miot lan device add, %s', self.did) def keep_alive(self, ip: str, if_name: str) -> None: @@ -239,7 +204,7 @@ def keep_alive(self, ip: str, if_name: str) -> None: self._if_name = if_name _LOGGER.info( 'device if_name change, %s, %s', self._if_name, self.did) - self.__update_keep_alive(state=MIoTLanDeviceState.FRESH) + self.__update_keep_alive(state=_MIoTLanDeviceState.FRESH) @property def online(self) -> bool: @@ -342,11 +307,11 @@ def unsubscribe(self) -> None: def on_delete(self) -> None: if self._ka_timer: - self._manager.mev.clear_timeout(self._ka_timer) + self._ka_timer.cancel() + self._ka_timer = None if self._online_offline_timer: - self._manager.mev.clear_timeout(self._online_offline_timer) - self._manager = None - self.cipher = None + self._online_offline_timer.cancel() + self._online_offline_timer = None _LOGGER.debug('miot lan device delete, %s', self.did) def update_info(self, info: dict) -> None: @@ -379,7 +344,7 @@ def __subscribe_handler(self, msg: dict, sub_ts: int) -> None: 'online': self._online, 'push_available': self.subscribed}) _LOGGER.info('subscribe success, %s, %s', self._if_name, self.did) - def __unsubscribe_handler(self, msg: dict, ctx: any) -> None: + def __unsubscribe_handler(self, msg: dict, ctx: Any) -> None: if ( 'result' not in msg or 'code' not in msg['result'] @@ -389,42 +354,49 @@ def __unsubscribe_handler(self, msg: dict, ctx: any) -> None: return _LOGGER.info('unsubscribe success, %s, %s', self._if_name, self.did) - def __update_keep_alive(self, state: MIoTLanDeviceState) -> None: - last_state: MIoTLanDeviceState = self._state + def __update_keep_alive(self, state: _MIoTLanDeviceState) -> None: + last_state: _MIoTLanDeviceState = self._state self._state = state - if self._state != MIoTLanDeviceState.FRESH: + if self._state != _MIoTLanDeviceState.FRESH: _LOGGER.debug('device status, %s, %s', self.did, self._state) if self._ka_timer: - self._manager.mev.clear_timeout(self._ka_timer) + self._ka_timer.cancel() self._ka_timer = None match state: - case MIoTLanDeviceState.FRESH: - if last_state == MIoTLanDeviceState.DEAD: + case _MIoTLanDeviceState.FRESH: + if last_state == _MIoTLanDeviceState.DEAD: self._ka_internal = self.KA_INTERVAL_MIN self.__change_online(True) - self._ka_timer = self._manager.mev.set_timeout( + self._ka_timer = self._manager.internal_loop.call_later( self.__get_next_ka_timeout(), self.__update_keep_alive, - MIoTLanDeviceState.PING1) + _MIoTLanDeviceState.PING1) case ( - MIoTLanDeviceState.PING1 - | MIoTLanDeviceState.PING2 - | MIoTLanDeviceState.PING3 + _MIoTLanDeviceState.PING1 + | _MIoTLanDeviceState.PING2 + | _MIoTLanDeviceState.PING3 ): - self._manager.ping(if_name=self._if_name, target_ip=self.ip) - # Fast ping - self._ka_timer = self._manager.mev.set_timeout( + # Set the timer first to avoid Any early returns + self._ka_timer = self._manager.internal_loop.call_later( self.FAST_PING_INTERVAL, self.__update_keep_alive, - MIoTLanDeviceState(state.value+1)) - case MIoTLanDeviceState.DEAD: - if last_state == MIoTLanDeviceState.PING3: + _MIoTLanDeviceState(state.value+1)) + # Fast ping + if self._if_name is None: + _LOGGER.error('if_name is Not set for device, %s', self.did) + return + if self.ip is None: + _LOGGER.error('ip is Not set for device, %s', self.did) + return + self._manager.ping(if_name=self._if_name, target_ip=self.ip) + case _MIoTLanDeviceState.DEAD: + if last_state == _MIoTLanDeviceState.PING3: self._ka_internal = self.KA_INTERVAL_MIN self.__change_online(False) case _: _LOGGER.error('invalid state, %s', state) - def __get_next_ka_timeout(self) -> int: + def __get_next_ka_timeout(self) -> float: self._ka_internal = min(self._ka_internal*2, self.KA_INTERVAL_MAX) - return randomize_int(self._ka_internal, 0.1) + return randomize_float(self._ka_internal, 0.1) def __change_online(self, online: bool) -> None: _LOGGER.info('change online, %s, %s', self.did, online) @@ -433,7 +405,8 @@ def __change_online(self, online: bool) -> None: if len(self._online_offline_history) > self.NETWORK_UNSTABLE_CNT_TH: self._online_offline_history.pop(0) if self._online_offline_timer: - self._manager.mev.clear_timeout(self._online_offline_timer) + self._online_offline_timer.cancel() + self._online_offline_timer = None if not online: self.online = False else: @@ -446,11 +419,12 @@ def __change_online(self, online: bool) -> None: self.online = True else: _LOGGER.info('unstable device detected, %s', self.did) - self._online_offline_timer = self._manager.mev.set_timeout( - self.NETWORK_UNSTABLE_RESUME_TH, - self.__online_resume_handler, None) + self._online_offline_timer = \ + self._manager.internal_loop.call_later( + self.NETWORK_UNSTABLE_RESUME_TH, + self.__online_resume_handler) - def __online_resume_handler(self, ctx: any) -> None: + def __online_resume_handler(self) -> None: _LOGGER.info('unstable resume threshold past, %s', self.did) self.online = True @@ -470,8 +444,8 @@ class MIoTLan: OT_MSG_LEN: int = 1400 OT_SUPPORT_WILDCARD_SUB: int = 0xFE - OT_PROBE_INTERVAL_MIN: int = 5000 - OT_PROBE_INTERVAL_MAX: int = 45000 + OT_PROBE_INTERVAL_MIN: float = 5 + OT_PROBE_INTERVAL_MAX: float = 45 PROFILE_MODELS_FILE: str = 'lan/profile_models.yaml' @@ -480,43 +454,44 @@ class MIoTLan: _network: MIoTNetwork _mips_service: MipsService _enable_subscribe: bool - _lan_devices: dict[str, MIoTLanDevice] + _lan_devices: dict[str, _MIoTLanDevice] _virtual_did: str _probe_msg: bytes _write_buffer: bytearray _read_buffer: bytearray - _mev: MIoTEventLoop + _internal_loop: asyncio.AbstractEventLoop _thread: threading.Thread - _queue: queue.Queue - _cmd_event_fd: os.eventfd _available_net_ifs: set[str] _broadcast_socks: dict[str, socket.socket] _local_port: Optional[int] - _scan_timer: TimeoutHandle - _last_scan_interval: Optional[int] + _scan_timer: Optional[asyncio.TimerHandle] + _last_scan_interval: Optional[float] _msg_id_counter: int - _pending_requests: dict[int, MIoTLanRequestData] + _pending_requests: dict[int, _MIoTLanRequestData] _device_msg_matcher: MIoTMatcher - _device_state_sub_map: dict[str, MIoTLanSubDeviceState] - _reply_msg_buffer: dict[str, TimeoutHandle] + _device_state_sub_map: dict[str, _MIoTLanSubDeviceData] + _reply_msg_buffer: dict[str, asyncio.TimerHandle] - _lan_state_sub_map: dict[str, Callable[[bool], asyncio.Future]] + _lan_state_sub_map: dict[str, Callable[[bool], Coroutine]] _lan_ctrl_vote_map: dict[str, bool] _profile_models: dict[str, dict] + _init_lock: asyncio.Lock _init_done: bool +# The following should be called from the main loop + def __init__( - self, - net_ifs: list[str], - network: MIoTNetwork, - mips_service: MipsService, - enable_subscribe: bool = False, - virtual_did: Optional[int] = None, - loop: Optional[asyncio.AbstractEventLoop] = None + self, + net_ifs: list[str], + network: MIoTNetwork, + mips_service: MipsService, + enable_subscribe: bool = False, + virtual_did: Optional[int] = None, + loop: Optional[asyncio.AbstractEventLoop] = None ) -> None: if not network: raise ValueError('network is required') @@ -526,13 +501,16 @@ def __init__( self._net_ifs = set(net_ifs) self._network = network self._network.sub_network_info( - key='miot_lan', handler=self.__on_network_info_change) + key='miot_lan', + handler=self.__on_network_info_change_external_async) self._mips_service = mips_service self._mips_service.sub_service_change( key='miot_lan', group_id='*', handler=self.__on_mips_service_change) self._enable_subscribe = enable_subscribe - self._virtual_did = virtual_did or str(secrets.randbits(64)) + self._virtual_did = str(virtual_did) \ + if (virtual_did is not None) \ + else str(secrets.randbits(64)) # Init socket probe message probe_bytes = bytearray(self.OT_PROBE_LEN) probe_bytes[:20] = ( @@ -558,6 +536,7 @@ def __init__( self._lan_state_sub_map = {} self._lan_ctrl_vote_map = {} + self._init_lock = asyncio.Lock() self._init_done = False if ( @@ -574,63 +553,71 @@ def virtual_did(self) -> str: return self._virtual_did @property - def mev(self) -> MIoTEventLoop: - return self._mev + def internal_loop(self) -> asyncio.AbstractEventLoop: + return self._internal_loop @property def init_done(self) -> bool: return self._init_done async def init_async(self) -> None: - if self._init_done: - _LOGGER.info('miot lan already init') - return - if len(self._net_ifs) == 0: - _LOGGER.info('no net_ifs') - return - if not any(self._lan_ctrl_vote_map.values()): - _LOGGER.info('no vote for lan ctrl') - return - if len(self._mips_service.get_services()) > 0: - _LOGGER.info('central hub gateway service exist') - return - for if_name in list(self._network.network_info.keys()): - self._available_net_ifs.add(if_name) - if len(self._available_net_ifs) == 0: - _LOGGER.info('no available net_ifs') - return - if self._net_ifs.isdisjoint(self._available_net_ifs): - _LOGGER.info('no valid net_ifs') - return - try: - self._profile_models = await self._main_loop.run_in_executor( - None, load_yaml_file, - gen_absolute_path(self.PROFILE_MODELS_FILE)) - except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('load profile models error, %s', err) - self._profile_models = {} - self._mev = MIoTEventLoop() - self._queue = queue.Queue() - self._cmd_event_fd = os.eventfd(0, os.O_NONBLOCK) - self._mev.set_read_handler( - self._cmd_event_fd, self.__cmd_read_handler, None) - self._thread = threading.Thread(target=self.__lan_thread_handler) - self._thread.name = 'miot_lan' - self._thread.daemon = True - self._thread.start() - self._init_done = True - for handler in list(self._lan_state_sub_map.values()): - self._main_loop.create_task(handler(True)) - _LOGGER.info( - 'miot lan init, %s ,%s', self._net_ifs, self._available_net_ifs) + # Avoid race condition + async with self._init_lock: + if self._init_done: + _LOGGER.info('miot lan already init') + return + if len(self._net_ifs) == 0: + _LOGGER.info('no net_ifs') + return + if not any(self._lan_ctrl_vote_map.values()): + _LOGGER.info('no vote for lan ctrl') + return + if len(self._mips_service.get_services()) > 0: + _LOGGER.info('central hub gateway service exist') + return + for if_name in list(self._network.network_info.keys()): + self._available_net_ifs.add(if_name) + if len(self._available_net_ifs) == 0: + _LOGGER.info('no available net_ifs') + return + if self._net_ifs.isdisjoint(self._available_net_ifs): + _LOGGER.info('no valid net_ifs') + return + try: + self._profile_models = await self._main_loop.run_in_executor( + None, load_yaml_file, + gen_absolute_path(self.PROFILE_MODELS_FILE)) + except Exception as err: # pylint: disable=broad-exception-caught + _LOGGER.error('load profile models error, %s', err) + self._profile_models = {} + self._internal_loop = asyncio.new_event_loop() + # All tasks meant for the internal loop should happen in this thread + self._thread = threading.Thread(target=self.__internal_loop_thread) + self._thread.name = 'miot_lan' + self._thread.daemon = True + self._thread.start() + self._init_done = True + for handler in list(self._lan_state_sub_map.values()): + self._main_loop.create_task(handler(True)) + _LOGGER.info( + 'miot lan init, %s ,%s', self._net_ifs, self._available_net_ifs) + + def __internal_loop_thread(self) -> None: + _LOGGER.info('miot lan thread start') + self.__init_socket() + self._scan_timer = self._internal_loop.call_later( + int(3*random.random()), self.__scan_devices) + self._internal_loop.run_forever() + _LOGGER.info('miot lan thread exit') async def deinit_async(self) -> None: if not self._init_done: _LOGGER.info('miot lan not init') return self._init_done = False - self.__lan_send_cmd(MIoTLanCmdType.DEINIT, None) + self._internal_loop.call_soon_threadsafe(self.__deinit) self._thread.join() + self._internal_loop.close() self._profile_models = {} self._lan_devices = {} @@ -670,9 +657,9 @@ async def update_net_ifs_async(self, net_ifs: list[str]) -> None: self._net_ifs = set(net_ifs) await self.init_async() return - self.__lan_send_cmd( - cmd=MIoTLanCmdType.NET_IFS_UPDATE, - data=net_ifs) + self._internal_loop.call_soon_threadsafe( + self.__update_net_ifs, + net_ifs) async def vote_for_lan_ctrl_async(self, key: str, vote: bool) -> None: _LOGGER.info('vote for lan ctrl, %s, %s', key, vote) @@ -687,25 +674,24 @@ async def update_subscribe_option(self, enable_subscribe: bool) -> None: if not self._init_done: self._enable_subscribe = enable_subscribe return - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.OPTIONS_UPDATE, - data={ - 'enable_subscribe': enable_subscribe, }) + self._internal_loop.call_soon_threadsafe( + self.__update_subscribe_option, + {'enable_subscribe': enable_subscribe}) def update_devices(self, devices: dict[str, dict]) -> bool: _LOGGER.info('update devices, %s', devices) - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.DEVICE_UPDATE, - data=devices) + self._internal_loop.call_soon_threadsafe( + self.__update_devices, devices) + return True def delete_devices(self, devices: list[str]) -> bool: _LOGGER.info('delete devices, %s', devices) - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.DEVICE_DELETE, - data=devices) + self._internal_loop.call_soon_threadsafe( + self.__delete_devices, devices) + return True def sub_lan_state( - self, key: str, handler: Callable[[bool], asyncio.Future] + self, key: str, handler: Callable[[bool], Coroutine] ) -> None: self._lan_state_sub_map[key] = handler @@ -714,76 +700,99 @@ def unsub_lan_state(self, key: str) -> None: @final def sub_device_state( - self, key: str, handler: Callable[[str, dict, any], None], - handler_ctx: any = None + self, key: str, handler: Callable[[str, dict, Any], Coroutine], + handler_ctx: Any = None ) -> bool: - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.SUB_DEVICE_STATE, - data=MIoTLanSubDeviceState( + self._internal_loop.call_soon_threadsafe( + self.__sub_device_state, + _MIoTLanSubDeviceData( key=key, handler=handler, handler_ctx=handler_ctx)) + return True @final def unsub_device_state(self, key: str) -> bool: - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.UNSUB_DEVICE_STATE, - data=MIoTLanUnsubDeviceState(key=key)) + self._internal_loop.call_soon_threadsafe( + self.__unsub_device_state, _MIoTLanUnsubDeviceData(key=key)) + return True @final def sub_prop( - self, did: str, handler: Callable[[dict, any], None], - siid: int = None, piid: int = None, handler_ctx: any = None + self, + did: str, + handler: Callable[[dict, Any], None], + siid: Optional[int] = None, + piid: Optional[int] = None, + handler_ctx: Any = None ) -> bool: if not self._enable_subscribe: return False key = ( f'{did}/p/' f'{"#" if siid is None or piid is None else f"{siid}/{piid}"}') - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.REG_BROADCAST, - data=MIoTLanRegisterBroadcastData( + self._internal_loop.call_soon_threadsafe( + self.__sub_broadcast, + _MIoTLanRegisterBroadcastData( key=key, handler=handler, handler_ctx=handler_ctx)) + return True @final - def unsub_prop(self, did: str, siid: int = None, piid: int = None) -> bool: + def unsub_prop( + self, + did: str, + siid: Optional[int] = None, + piid: Optional[int] = None + ) -> bool: if not self._enable_subscribe: return False key = ( f'{did}/p/' f'{"#" if siid is None or piid is None else f"{siid}/{piid}"}') - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.UNREG_BROADCAST, - data=MIoTLanUnregisterBroadcastData(key=key)) + self._internal_loop.call_soon_threadsafe( + self.__unsub_broadcast, + _MIoTLanUnregisterBroadcastData(key=key)) + return True @final def sub_event( - self, did: str, handler: Callable[[dict, any], None], - siid: int = None, eiid: int = None, handler_ctx: any = None + self, + did: str, + handler: Callable[[dict, Any], None], + siid: Optional[int] = None, + eiid: Optional[int] = None, + handler_ctx: Any = None ) -> bool: if not self._enable_subscribe: return False key = ( f'{did}/e/' f'{"#" if siid is None or eiid is None else f"{siid}/{eiid}"}') - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.REG_BROADCAST, - data=MIoTLanRegisterBroadcastData( + self._internal_loop.call_soon_threadsafe( + self.__sub_broadcast, + _MIoTLanRegisterBroadcastData( key=key, handler=handler, handler_ctx=handler_ctx)) + return True @final - def unsub_event(self, did: str, siid: int = None, eiid: int = None) -> bool: + def unsub_event( + self, + did: str, + siid: Optional[int] = None, + eiid: Optional[int] = None + ) -> bool: if not self._enable_subscribe: return False key = ( f'{did}/e/' f'{"#" if siid is None or eiid is None else f"{siid}/{eiid}"}') - return self.__lan_send_cmd( - cmd=MIoTLanCmdType.UNREG_BROADCAST, - data=MIoTLanUnregisterBroadcastData(key=key)) + self._internal_loop.call_soon_threadsafe( + self.__unsub_broadcast, + _MIoTLanUnregisterBroadcastData(key=key)) + return True @final async def get_prop_async( self, did: str, siid: int, piid: int, timeout_ms: int = 10000 - ) -> any: + ) -> Any: result_obj = await self.__call_api_async( did=did, msg={ 'method': 'get_properties', @@ -801,7 +810,7 @@ async def get_prop_async( @final async def set_prop_async( - self, did: str, siid: int, piid: int, value: any, + self, did: str, siid: int, piid: int, value: Any, timeout_ms: int = 10000 ) -> dict: result_obj = await self.__call_api_async( @@ -857,18 +866,68 @@ def get_device_list_handler(msg: dict, fut: asyncio.Future): fut.set_result, msg) fut: asyncio.Future = self._main_loop.create_future() - if self.__lan_send_cmd( - MIoTLanCmdType.GET_DEV_LIST, - MIoTLanGetDevListData( + self._internal_loop.call_soon_threadsafe( + self.__get_dev_list, + _MIoTLanGetDevListData( handler=get_device_list_handler, handler_ctx=fut, - timeout_ms=timeout_ms)): - return await fut - _LOGGER.error('get_dev_list_async error, send cmd failed') - fut.set_result({}) + timeout_ms=timeout_ms)) + return await fut + + async def __call_api_async( + self, did: str, msg: dict, timeout_ms: int = 10000 + ) -> dict: + def call_api_handler(msg: dict, fut: asyncio.Future): + self._main_loop.call_soon_threadsafe( + fut.set_result, msg) + + fut: asyncio.Future = self._main_loop.create_future() + self._internal_loop.call_soon_threadsafe( + self.__call_api, did, msg, call_api_handler, fut, timeout_ms) return await fut - def ping(self, if_name: str, target_ip: str) -> None: + async def __on_network_info_change_external_async( + self, + status: InterfaceStatus, + info: NetworkInfo + ) -> None: + _LOGGER.info( + 'on network info change, status: %s, info: %s', status, info) + available_net_ifs = set() + for if_name in list(self._network.network_info.keys()): + available_net_ifs.add(if_name) + if len(available_net_ifs) == 0: + await self.deinit_async() + self._available_net_ifs = available_net_ifs + return + if self._net_ifs.isdisjoint(available_net_ifs): + _LOGGER.info('no valid net_ifs') + await self.deinit_async() + self._available_net_ifs = available_net_ifs + return + if not self._init_done: + self._available_net_ifs = available_net_ifs + await self.init_async() + return + self._internal_loop.call_soon_threadsafe( + self.__on_network_info_chnage, + _MIoTLanNetworkUpdateData(status=status, if_name=info.name)) + + async def __on_mips_service_change( + self, group_id: str, state: MipsServiceState, data: dict + ) -> None: + _LOGGER.info( + 'on mips service change, %s, %s, %s', group_id, state, data) + if len(self._mips_service.get_services()) > 0: + _LOGGER.info('find central service, deinit miot lan') + await self.deinit_async() + else: + _LOGGER.info('no central service, init miot lan') + await self.init_async() + +# The folowing methods SHOULD ONLY be called in the internal loop + + def ping(self, if_name: str | None, target_ip: str) -> None: if not target_ip: return self.__sendto( @@ -878,13 +937,13 @@ def ping(self, if_name: str, target_ip: str) -> None: def send2device( self, did: str, msg: dict, - handler: Optional[Callable[[dict, any], None]] = None, - handler_ctx: any = None, + handler: Optional[Callable[[dict, Any], None]] = None, + handler_ctx: Any = None, timeout_ms: Optional[int] = None ) -> None: if timeout_ms and not handler: raise ValueError('handler is required when timeout_ms is set') - device: MIoTLanDevice = self._lan_devices.get(did) + device: _MIoTLanDevice | None = self._lan_devices.get(did) if not device: raise ValueError('invalid device') if not device.cipher: @@ -900,7 +959,7 @@ def send2device( did=did, offset=int(time.time())-device.offset) - return self.make_request( + return self.__make_request( msg_id=in_msg['id'], msg=self._write_buffer[0: msg_len], if_name=device.if_name, @@ -909,33 +968,33 @@ def send2device( handler_ctx=handler_ctx, timeout_ms=timeout_ms) - def make_request( + def __make_request( self, msg_id: int, msg: bytearray, if_name: str, ip: str, - handler: Callable[[dict, any], None], - handler_ctx: any = None, + handler: Optional[Callable[[dict, Any], None]], + handler_ctx: Any = None, timeout_ms: Optional[int] = None ) -> None: - def request_timeout_handler(req_data: MIoTLanRequestData): + def request_timeout_handler(req_data: _MIoTLanRequestData): self._pending_requests.pop(req_data.msg_id, None) - if req_data: + if req_data and req_data.handler: req_data.handler({ 'code': MIoTErrorCode.CODE_TIMEOUT.value, 'error': 'timeout'}, req_data.handler_ctx) - timer: Optional[TimeoutHandle] = None - request_data = MIoTLanRequestData( + timer: Optional[asyncio.TimerHandle] = None + request_data = _MIoTLanRequestData( msg_id=msg_id, handler=handler, handler_ctx=handler_ctx, timeout=timer) if timeout_ms: - timer = self._mev.set_timeout( - timeout_ms, request_timeout_handler, request_data) + timer = self._internal_loop.call_later( + timeout_ms/1000, request_timeout_handler, request_data) request_data.timeout = timer self._pending_requests[msg_id] = request_data self.__sendto(if_name=if_name, data=msg, address=ip, port=self.OT_PORT) @@ -954,175 +1013,137 @@ def __gen_msg_id(self) -> int: self._msg_id_counter = 1 return self._msg_id_counter - def __lan_send_cmd(self, cmd: MIoTLanCmd, data: any) -> bool: + def __call_api( + self, + did: str, + msg: dict, + handler: Callable, + handler_ctx: Any, + timeout_ms: int = 10000 + ) -> None: try: - self._queue.put(MIoTLanCmd(type_=cmd, data=data)) - os.eventfd_write(self._cmd_event_fd, 1) - return True + self.send2device( + did=did, + msg={'from': 'ha.xiaomi_home', **msg}, + handler=handler, + handler_ctx=handler_ctx, + timeout_ms=timeout_ms) except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('send cmd error, %s, %s', cmd, err) - return False - - async def __call_api_async( - self, did: str, msg: dict, timeout_ms: int = 10000 - ) -> dict: - def call_api_handler(msg: dict, fut: asyncio.Future): - self._main_loop.call_soon_threadsafe( - fut.set_result, msg) - - fut: asyncio.Future = self._main_loop.create_future() - if self.__lan_send_cmd( - cmd=MIoTLanCmdType.CALL_API, - data=MIoTLanCallApiData( - did=did, - msg=msg, - handler=call_api_handler, - handler_ctx=fut, - timeout_ms=timeout_ms)): - return await fut - - fut.set_result({ - 'code': MIoTErrorCode.CODE_UNAVAILABLE.value, - 'error': 'send cmd error'}) - return await fut - - def __lan_thread_handler(self) -> None: - _LOGGER.info('miot lan thread start') - self.__init_socket() - # Create scan devices timer - self._scan_timer = self._mev.set_timeout( - int(3000*random.random()), self.__scan_devices, None) - self._mev.loop_forever() - _LOGGER.info('miot lan thread exit') - - def __cmd_read_handler(self, ctx: any) -> None: - fd_value = os.eventfd_read(self._cmd_event_fd) - if fd_value == 0: - return - while not self._queue.empty(): - mips_cmd: MIoTLanCmd = self._queue.get(block=False) - if mips_cmd.type_ == MIoTLanCmdType.CALL_API: - call_api_data: MIoTLanCallApiData = mips_cmd.data - try: - self.send2device( - did=call_api_data.did, - msg={'from': 'ha.xiaomi_home', **call_api_data.msg}, - handler=call_api_data.handler, - handler_ctx=call_api_data.handler_ctx, - timeout_ms=call_api_data.timeout_ms) - except Exception as err: # pylint: disable=broad-exception-caught - _LOGGER.error('send2device error, %s', err) - call_api_data.handler({ - 'code': MIoTErrorCode.CODE_INTERNAL_ERROR.value, - 'error': str(err)}, - call_api_data.handler_ctx) - elif mips_cmd.type_ == MIoTLanCmdType.SUB_DEVICE_STATE: - sub_data: MIoTLanSubDeviceState = mips_cmd.data - self._device_state_sub_map[sub_data.key] = sub_data - elif mips_cmd.type_ == MIoTLanCmdType.UNSUB_DEVICE_STATE: - sub_data: MIoTLanUnsubDeviceState = mips_cmd.data - self._device_state_sub_map.pop(sub_data.key, None) - elif mips_cmd.type_ == MIoTLanCmdType.REG_BROADCAST: - reg_data: MIoTLanRegisterBroadcastData = mips_cmd.data - self._device_msg_matcher[reg_data.key] = reg_data - _LOGGER.debug('lan register broadcast, %s', reg_data.key) - elif mips_cmd.type_ == MIoTLanCmdType.UNREG_BROADCAST: - unreg_data: MIoTLanUnregisterBroadcastData = mips_cmd.data - if self._device_msg_matcher.get(topic=unreg_data.key): - del self._device_msg_matcher[unreg_data.key] - _LOGGER.debug('lan unregister broadcast, %s', unreg_data.key) - elif mips_cmd.type_ == MIoTLanCmdType.GET_DEV_LIST: - get_dev_list_data: MIoTLanGetDevListData = mips_cmd.data - dev_list = { - device.did: { - 'online': device.online, - 'push_available': device.subscribed - } - for device in self._lan_devices.values() - if device.online} - get_dev_list_data.handler( - dev_list, get_dev_list_data.handler_ctx) - elif mips_cmd.type_ == MIoTLanCmdType.DEVICE_UPDATE: - devices: dict[str, dict] = mips_cmd.data - for did, info in devices.items(): - # did MUST be digit(UINT64) - if not did.isdigit(): - _LOGGER.info('invalid did, %s', did) - continue - if ( - 'model' not in info - or info['model'] in self._profile_models): - # Do not support the local control of - # Profile device for the time being - _LOGGER.info( - 'model not support local ctrl, %s, %s', - did, info.get('model')) - continue - if did not in self._lan_devices: - if 'token' not in info: - _LOGGER.error( - 'token not found, %s, %s', did, info) - continue - if len(info['token']) != 32: - _LOGGER.error( - 'invalid device token, %s, %s', did, info) - continue - self._lan_devices[did] = MIoTLanDevice( - manager=self, did=did, token=info['token'], - ip=info.get('ip', None)) - else: - self._lan_devices[did].update_info(info) - elif mips_cmd.type_ == MIoTLanCmdType.DEVICE_DELETE: - device_dids: list[str] = mips_cmd.data - for did in device_dids: - lan_device = self._lan_devices.pop(did, None) - if not lan_device: - continue - lan_device.on_delete() - elif mips_cmd.type_ == MIoTLanCmdType.NET_INFO_UPDATE: - net_data: MIoTLanNetworkUpdateData = mips_cmd.data - if net_data.status == InterfaceStatus.ADD: - self._available_net_ifs.add(net_data.if_name) - if net_data.if_name in self._net_ifs: - self.__create_socket(if_name=net_data.if_name) - elif net_data.status == InterfaceStatus.REMOVE: - self._available_net_ifs.remove(net_data.if_name) - self.__destroy_socket(if_name=net_data.if_name) - elif mips_cmd.type_ == MIoTLanCmdType.NET_IFS_UPDATE: - net_ifs: list[str] = mips_cmd.data - if self._net_ifs != set(net_ifs): - self._net_ifs = set(net_ifs) - for if_name in self._net_ifs: - self.__create_socket(if_name=if_name) - for if_name in list(self._broadcast_socks.keys()): - if if_name not in self._net_ifs: - self.__destroy_socket(if_name=if_name) - elif mips_cmd.type_ == MIoTLanCmdType.OPTIONS_UPDATE: - options: dict = mips_cmd.data - if 'enable_subscribe' in options: - if options['enable_subscribe'] != self._enable_subscribe: - self._enable_subscribe = options['enable_subscribe'] - if not self._enable_subscribe: - # Unsubscribe all - for device in self._lan_devices.values(): - device.unsubscribe() - elif mips_cmd.type_ == MIoTLanCmdType.DEINIT: - # stop the thread - if self._scan_timer: - self._mev.clear_timeout(self._scan_timer) - self._scan_timer = None - for device in self._lan_devices.values(): - device.on_delete() - self._lan_devices.clear() - for req_data in self._pending_requests.values(): - self._mev.clear_timeout(req_data.timeout) - self._pending_requests.clear() - for timer in self._reply_msg_buffer.values(): - self._mev.clear_timeout(timer) - self._reply_msg_buffer.clear() - self._device_msg_matcher = MIoTMatcher() - self.__deinit_socket() - self._mev.loop_stop() + _LOGGER.error('send2device error, %s', err) + handler({ + 'code': MIoTErrorCode.CODE_INTERNAL_ERROR.value, + 'error': str(err)}, + handler_ctx) + + def __sub_device_state(self, data: _MIoTLanSubDeviceData) -> None: + self._device_state_sub_map[data.key] = data + + def __unsub_device_state(self, data: _MIoTLanUnsubDeviceData) -> None: + self._device_state_sub_map.pop(data.key, None) + + def __sub_broadcast(self, data: _MIoTLanRegisterBroadcastData) -> None: + self._device_msg_matcher[data.key] = data + _LOGGER.debug('lan register broadcast, %s', data.key) + + def __unsub_broadcast(self, data: _MIoTLanUnregisterBroadcastData) -> None: + if self._device_msg_matcher.get(topic=data.key): + del self._device_msg_matcher[data.key] + _LOGGER.debug('lan unregister broadcast, %s', data.key) + + def __get_dev_list(self, data: _MIoTLanGetDevListData) -> None: + dev_list = { + device.did: { + 'online': device.online, + 'push_available': device.subscribed + } + for device in self._lan_devices.values() + if device.online} + data.handler( + dev_list, data.handler_ctx) + + def __update_devices(self, devices: dict[str, dict]) -> None: + for did, info in devices.items(): + # did MUST be digit(UINT64) + if not did.isdigit(): + _LOGGER.info('invalid did, %s', did) + continue + if ( + 'model' not in info + or info['model'] in self._profile_models): + # Do not support the local control of + # Profile device for the time being + _LOGGER.info( + 'model not support local ctrl, %s, %s', + did, info.get('model')) + continue + if did not in self._lan_devices: + if 'token' not in info: + _LOGGER.error( + 'token not found, %s, %s', did, info) + continue + if len(info['token']) != 32: + _LOGGER.error( + 'invalid device token, %s, %s', did, info) + continue + self._lan_devices[did] = _MIoTLanDevice( + manager=self, did=did, token=info['token'], + ip=info.get('ip', None)) + else: + self._lan_devices[did].update_info(info) + + def __delete_devices(self, devices: list[str]) -> None: + for did in devices: + lan_device = self._lan_devices.pop(did, None) + if not lan_device: + continue + lan_device.on_delete() + + def __on_network_info_chnage(self, data: _MIoTLanNetworkUpdateData) -> None: + if data.status == InterfaceStatus.ADD: + self._available_net_ifs.add(data.if_name) + if data.if_name in self._net_ifs: + self.__create_socket(if_name=data.if_name) + elif data.status == InterfaceStatus.REMOVE: + self._available_net_ifs.remove(data.if_name) + self.__destroy_socket(if_name=data.if_name) + + def __update_net_ifs(self, net_ifs: list[str]) -> None: + if self._net_ifs != set(net_ifs): + self._net_ifs = set(net_ifs) + for if_name in self._net_ifs: + self.__create_socket(if_name=if_name) + for if_name in list(self._broadcast_socks.keys()): + if if_name not in self._net_ifs: + self.__destroy_socket(if_name=if_name) + + def __update_subscribe_option(self, options: dict) -> None: + if 'enable_subscribe' in options: + if options['enable_subscribe'] != self._enable_subscribe: + self._enable_subscribe = options['enable_subscribe'] + if not self._enable_subscribe: + # Unsubscribe all + for device in self._lan_devices.values(): + device.unsubscribe() + + def __deinit(self) -> None: + # Release all resources + if self._scan_timer: + self._scan_timer.cancel() + self._scan_timer = None + for device in self._lan_devices.values(): + device.on_delete() + self._lan_devices.clear() + for req_data in self._pending_requests.values(): + if req_data.timeout: + req_data.timeout.cancel() + req_data.timeout = None + self._pending_requests.clear() + for timer in self._reply_msg_buffer.values(): + timer.cancel() + self._reply_msg_buffer.clear() + self._device_msg_matcher = MIoTMatcher() + self.__deinit_socket() + self._internal_loop.stop() def __init_socket(self) -> None: self.__deinit_socket() @@ -1145,7 +1166,7 @@ def __create_socket(self, if_name: str) -> None: sock.setsockopt( socket.SOL_SOCKET, socket.SO_BINDTODEVICE, if_name.encode()) sock.bind(('', self._local_port or 0)) - self._mev.set_read_handler( + self._internal_loop.add_reader( sock.fileno(), self.__socket_read_handler, (if_name, sock)) self._broadcast_socks[if_name] = sock self._local_port = self._local_port or sock.getsockname()[1] @@ -1163,7 +1184,7 @@ def __destroy_socket(self, if_name: str) -> None: sock = self._broadcast_socks.pop(if_name, None) if not sock: return - self._mev.set_read_handler(sock.fileno(), None, None) + self._internal_loop.remove_reader(sock.fileno()) sock.close() _LOGGER.info('destroyed socket, %s', if_name) @@ -1190,7 +1211,7 @@ def __raw_message_handler( return # Keep alive message did: str = str(struct.unpack('>Q', data[4:12])[0]) - device: MIoTLanDevice = self._lan_devices.get(did) + device: _MIoTLanDevice | None = self._lan_devices.get(did) if not device: return timestamp: int = struct.unpack('>I', data[12:16])[0] @@ -1230,11 +1251,15 @@ def __message_handler(self, did: str, msg: dict) -> None: _LOGGER.warning('invalid message, no id, %s, %s', did, msg) return # Reply - req: MIoTLanRequestData = self._pending_requests.pop(msg['id'], None) + req: _MIoTLanRequestData | None = \ + self._pending_requests.pop(msg['id'], None) if req: - self._mev.clear_timeout(req.timeout) - self._main_loop.call_soon_threadsafe( - req.handler, msg, req.handler_ctx) + if req.timeout: + req.timeout.cancel() + req.timeout = None + if req.handler is not None: + self._main_loop.call_soon_threadsafe( + req.handler, msg, req.handler_ctx) return # Handle up link message if 'method' not in msg or 'params' not in msg: @@ -1254,7 +1279,7 @@ def __message_handler(self, did: str, msg: dict) -> None: 'invalid message, no siid or piid, %s, %s', did, msg) continue key = f'{did}/p/{param["siid"]}/{param["piid"]}' - subs: list[MIoTLanRegisterBroadcastData] = list( + subs: list[_MIoTLanRegisterBroadcastData] = list( self._device_msg_matcher.iter_match(key)) for sub in subs: self._main_loop.call_soon_threadsafe( @@ -1265,7 +1290,7 @@ def __message_handler(self, did: str, msg: dict) -> None: and 'eiid' in msg['params'] ): key = f'{did}/e/{msg["params"]["siid"]}/{msg["params"]["eiid"]}' - subs: list[MIoTLanRegisterBroadcastData] = list( + subs: list[_MIoTLanRegisterBroadcastData] = list( self._device_msg_matcher.iter_match(key)) for sub in subs: self._main_loop.call_soon_threadsafe( @@ -1281,15 +1306,16 @@ def __filter_dup_message(self, did: str, msg_id: int) -> bool: filter_id = f'{did}.{msg_id}' if filter_id in self._reply_msg_buffer: return True - self._reply_msg_buffer[filter_id] = self._mev.set_timeout( - 5000, + self._reply_msg_buffer[filter_id] = self._internal_loop.call_later( + 5, lambda filter_id: self._reply_msg_buffer.pop(filter_id, None), filter_id) + return False def __sendto( - self, if_name: str, data: bytes, address: str, port: int + self, if_name: str | None, data: bytes, address: str, port: int ) -> None: - if address == '255.255.255.255': + if if_name is None: # Broadcast for if_n, sock in self._broadcast_socks.items(): _LOGGER.debug('send broadcast, %s', if_n) @@ -1302,58 +1328,25 @@ def __sendto( return sock.sendto(data, socket.MSG_DONTWAIT, (address, port)) - def __scan_devices(self, ctx: any) -> None: + def __scan_devices(self) -> None: if self._scan_timer: - self._mev.clear_timeout(self._scan_timer) - # Scan devices - self.ping(if_name=None, target_ip='255.255.255.255') + self._scan_timer.cancel() + self._scan_timer = None + try: + # Scan devices + self.ping(if_name=None, target_ip='255.255.255.255') + except Exception as err: # pylint: disable=broad-exception-caught + # Ignore any exceptions to avoid blocking the loop + _LOGGER.error('ping device error, %s', err) + pass scan_time = self.__get_next_scan_time() - self._scan_timer = self._mev.set_timeout( - scan_time, self.__scan_devices, None) - _LOGGER.debug('next scan time: %sms', scan_time) + self._scan_timer = self._internal_loop.call_later( + scan_time, self.__scan_devices) + _LOGGER.debug('next scan time: %ss', scan_time) - def __get_next_scan_time(self) -> int: + def __get_next_scan_time(self) -> float: if not self._last_scan_interval: self._last_scan_interval = self.OT_PROBE_INTERVAL_MIN self._last_scan_interval = min( self._last_scan_interval*2, self.OT_PROBE_INTERVAL_MAX) return self._last_scan_interval - - async def __on_network_info_change( - self, - status: InterfaceStatus, - info: NetworkInfo - ) -> None: - _LOGGER.info( - 'on network info change, status: %s, info: %s', status, info) - available_net_ifs = set() - for if_name in list(self._network.network_info.keys()): - available_net_ifs.add(if_name) - if len(available_net_ifs) == 0: - await self.deinit_async() - self._available_net_ifs = available_net_ifs - return - if self._net_ifs.isdisjoint(available_net_ifs): - _LOGGER.info('no valid net_ifs') - await self.deinit_async() - self._available_net_ifs = available_net_ifs - return - if not self._init_done: - self._available_net_ifs = available_net_ifs - await self.init_async() - return - self.__lan_send_cmd( - MIoTLanCmdType.NET_INFO_UPDATE, MIoTLanNetworkUpdateData( - status=status, if_name=info.name)) - - async def __on_mips_service_change( - self, group_id: str, state: MipsServiceState, data: dict - ) -> None: - _LOGGER.info( - 'on mips service change, %s, %s, %s', group_id, state, data) - if len(self._mips_service.get_services()) > 0: - _LOGGER.info('find central service, deinit miot lan') - await self.deinit_async() - else: - _LOGGER.info('no central service, init miot lan') - await self.init_async() diff --git a/custom_components/xiaomi_home/miot/miot_network.py b/custom_components/xiaomi_home/miot/miot_network.py index a4606eb..160d660 100644 --- a/custom_components/xiaomi_home/miot/miot_network.py +++ b/custom_components/xiaomi_home/miot/miot_network.py @@ -52,7 +52,7 @@ from dataclasses import dataclass from enum import Enum, auto import subprocess -from typing import Callable, Optional +from typing import Callable, Coroutine, Optional import psutil import ipaddress @@ -97,7 +97,7 @@ class MIoTNetwork: _sub_list_network_status: dict[str, Callable[[bool], asyncio.Future]] _sub_list_network_info: dict[str, Callable[[ - InterfaceStatus, NetworkInfo], asyncio.Future]] + InterfaceStatus, NetworkInfo], Coroutine]] _ping_address_priority: int @@ -155,7 +155,7 @@ def unsub_network_status(self, key: str) -> None: def sub_network_info( self, key: str, - handler: Callable[[InterfaceStatus, NetworkInfo], asyncio.Future] + handler: Callable[[InterfaceStatus, NetworkInfo], Coroutine] ) -> None: self._sub_list_network_info[key] = handler