From 40cce6dd9a521d5403dd96e0f02621459977929e Mon Sep 17 00:00:00 2001 From: DoronZ Date: Wed, 30 Oct 2024 00:53:46 +0200 Subject: [PATCH] maintenance: refactor and add docstrings --- pymobiledevice3/pair_records.py | 78 +++++++++- pymobiledevice3/service_connection.py | 203 +++++++++++++++++++++++--- 2 files changed, 262 insertions(+), 19 deletions(-) diff --git a/pymobiledevice3/pair_records.py b/pymobiledevice3/pair_records.py index ba86adfd..145b2b14 100644 --- a/pymobiledevice3/pair_records.py +++ b/pymobiledevice3/pair_records.py @@ -19,12 +19,31 @@ def generate_host_id(hostname: str = None) -> str: + """ + Generate a unique host ID based on the hostname. + + :param hostname: The hostname to use for generating the host ID. + If None, the current hostname is used. + :type hostname: str, optional + :return: The generated host ID. + :rtype: str + """ hostname = platform.node() if hostname is None else hostname host_id = uuid.uuid3(uuid.NAMESPACE_DNS, hostname) return str(host_id).upper() def get_usbmux_pairing_record(identifier: str, usbmux_address: Optional[str] = None): + """ + Retrieve the pairing record from usbmuxd. + + :param identifier: The identifier of the device. + :type identifier: str + :param usbmux_address: The address of the usbmuxd server. + :type usbmux_address: Optional[str], optional + :return: The pairing record if found, otherwise None. + :rtype: dict or None + """ with suppress(NotPairedError, MuxException): with usbmux.create_mux(usbmux_address=usbmux_address) as mux: if isinstance(mux, PlistMuxConnection): @@ -35,6 +54,14 @@ def get_usbmux_pairing_record(identifier: str, usbmux_address: Optional[str] = N def get_itunes_pairing_record(identifier: str) -> Optional[dict]: + """ + Retrieve the pairing record from iTunes. + + :param identifier: The identifier of the device. + :type identifier: str + :return: The pairing record if found, otherwise None. + :rtype: Optional[dict] + """ filename = OSUTILS.pair_record_path / f'{identifier}.plist' try: with open(filename, 'rb') as f: @@ -45,6 +72,16 @@ def get_itunes_pairing_record(identifier: str) -> Optional[dict]: def get_local_pairing_record(identifier: str, pairing_records_cache_folder: Path) -> Optional[dict]: + """ + Retrieve the pairing record from local storage. + + :param identifier: The identifier of the device. + :type identifier: str + :param pairing_records_cache_folder: The path to the local pairing records cache folder. + :type pairing_records_cache_folder: Path + :return: The pairing record if found, otherwise None. + :rtype: Optional[dict] + """ logger.debug('Looking for pymobiledevice3 pairing record') path = pairing_records_cache_folder / f'{identifier}.{PAIRING_RECORD_EXT}' if not path.exists(): @@ -56,12 +93,20 @@ def get_local_pairing_record(identifier: str, pairing_records_cache_folder: Path def get_preferred_pair_record(identifier: str, pairing_records_cache_folder: Path, usbmux_address: Optional[str] = None) -> dict: """ - look for an existing pair record to connected device by following order: + Look for an existing pair record for the connected device in the following order: - usbmuxd - iTunes - local storage - """ + :param identifier: The identifier of the device. + :type identifier: str + :param pairing_records_cache_folder: The path to the local pairing records cache folder. + :type pairing_records_cache_folder: Path + :param usbmux_address: The address of the usbmuxd server. + :type usbmux_address: Optional[str], optional + :return: The preferred pairing record. + :rtype: dict + """ # usbmuxd pair_record = get_usbmux_pairing_record(identifier=identifier, usbmux_address=usbmux_address) if pair_record is not None: @@ -77,6 +122,15 @@ def get_preferred_pair_record(identifier: str, pairing_records_cache_folder: Pat def create_pairing_records_cache_folder(pairing_records_cache_folder: Path = None) -> Path: + """ + Create the pairing records cache folder if it does not exist. + + :param pairing_records_cache_folder: The path to the local pairing records cache folder. + If None, the home folder is used. + :type pairing_records_cache_folder: Path, optional + :return: The path to the pairing records cache folder. + :rtype: Path + """ if pairing_records_cache_folder is None: pairing_records_cache_folder = get_home_folder() else: @@ -86,13 +140,33 @@ def create_pairing_records_cache_folder(pairing_records_cache_folder: Path = Non def get_remote_pairing_record_filename(identifier: str) -> str: + """ + Generate the filename for the remote pairing record. + + :param identifier: The identifier of the device. + :type identifier: str + :return: The filename for the remote pairing record. + :rtype: str + """ return f'remote_{identifier}' def iter_remote_pair_records() -> Generator[Path, None, None]: + """ + Iterate over the remote pairing records in the home folder. + + :return: A generator yielding paths to the remote pairing records. + :rtype: Generator[Path, None, None] + """ return get_home_folder().glob('remote_*') def iter_remote_paired_identifiers() -> Generator[str, None, None]: + """ + Iterate over the identifiers of the remote paired devices. + + :return: A generator yielding the identifiers of the remote paired devices. + :rtype: Generator[str, None, None] + """ for file in iter_remote_pair_records(): yield file.parts[-1].split('remote_', 1)[1].split('.', 1)[0] diff --git a/pymobiledevice3/service_connection.py b/pymobiledevice3/service_connection.py index 6d3cc7e9..cd143c0d 100755 --- a/pymobiledevice3/service_connection.py +++ b/pymobiledevice3/service_connection.py @@ -39,19 +39,41 @@ def build_plist(d: dict, endianity: str = '>', fmt: Enum = plistlib.FMT_XML) -> bytes: + """ + Convert a dictionary to a plist-formatted byte string prefixed with a length field. + + :param d: The dictionary to convert. + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :param fmt: The plist format (e.g., plistlib.FMT_XML). + :return: The plist-formatted byte string. + """ payload = plistlib.dumps(d, fmt=fmt) message = struct.pack(endianity + 'L', len(payload)) return message + payload -def parse_plist(payload): +def parse_plist(payload: bytes) -> dict: + """ + Parse a plist-formatted byte string into a dictionary. + + :param payload: The plist-formatted byte string to parse. + :return: The parsed dictionary. + :raises PyMobileDevice3Exception: If the payload is invalid. + """ try: return plistlib.loads(payload) except plistlib.InvalidFileException: raise PyMobileDevice3Exception(f'parse_plist invalid data: {payload[:100].hex()}') -def create_context(certfile, keyfile=None): +def create_context(certfile: str, keyfile: Optional[str] = None) -> ssl.SSLContext: + """ + Create an SSL context for a secure connection. + + :param certfile: The path to the certificate file. + :param keyfile: The path to the key file (optional). + :return: An SSL context object. + """ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) if ssl.OPENSSL_VERSION.lower().startswith('openssl'): context.set_ciphers('ALL:!aNULL:!eNULL:@SECLEVEL=0') @@ -68,6 +90,12 @@ class ServiceConnection: """ wrapper for tcp-relay connections """ def __init__(self, sock: socket.socket, mux_device: MuxDevice = None): + """ + Initialize a ServiceConnection object. + + :param sock: The socket to use for the connection. + :param mux_device: The MuxDevice associated with the connection (optional). + """ self.logger = logging.getLogger(__name__) self.socket = sock self._offset = 0 @@ -81,6 +109,15 @@ def __init__(self, sock: socket.socket, mux_device: MuxDevice = None): @staticmethod def create_using_tcp(hostname: str, port: int, keep_alive: bool = True, create_connection_timeout: int = DEFAULT_TIMEOUT) -> 'ServiceConnection': + """ + Create a ServiceConnection using a TCP connection. + + :param hostname: The hostname of the server to connect to. + :param port: The port to connect to. + :param keep_alive: Whether to enable TCP keep-alive. + :param create_connection_timeout: The timeout for creating the connection. + :return: A ServiceConnection object. + """ sock = socket.create_connection((hostname, port), timeout=create_connection_timeout) sock.settimeout(None) if keep_alive: @@ -90,6 +127,17 @@ def create_using_tcp(hostname: str, port: int, keep_alive: bool = True, @staticmethod def create_using_usbmux(udid: Optional[str], port: int, connection_type: str = None, usbmux_address: Optional[str] = None) -> 'ServiceConnection': + """ + Create a ServiceConnection using a USBMux connection. + + :param udid: The UDID of the target device. + :param port: The port to connect to. + :param connection_type: The type of connection to use. + :param usbmux_address: The address of the usbmuxd socket. + :return: A ServiceConnection object. + :raises DeviceNotFoundError: If the device with the specified UDID is not found. + :raises NoDeviceConnectedError: If no device is connected. + """ target_device = select_device(udid, connection_type=connection_type, usbmux_address=usbmux_address) if target_device is None: if udid: @@ -99,12 +147,19 @@ def create_using_usbmux(udid: Optional[str], port: int, connection_type: str = N return ServiceConnection(sock, mux_device=target_device) def setblocking(self, blocking: bool) -> None: + """ + Set the blocking mode of the socket. + + :param blocking: If True, set the socket to blocking mode; otherwise, set it to non-blocking mode. + """ self.socket.setblocking(blocking) def close(self) -> None: + """ Close the connection. """ self.socket.close() async def aio_close(self) -> None: + """ Asynchronously close the connection. """ if self.writer is None: return self.writer.close() @@ -115,25 +170,59 @@ async def aio_close(self) -> None: self.writer = None self.reader = None - def recv(self, length=4096) -> bytes: - """ socket.recv() normal behavior. attempt to receive a single chunk """ + def recv(self, length: int = 4096) -> bytes: + """ + Receive data from the socket. + + :param length: The maximum amount of data to receive. + :return: The received data. + """ return self.socket.recv(length) def sendall(self, data: bytes) -> None: + """ + Send data to the socket. + + :param data: The data to send. + :raises ConnectionTerminatedError: If the connection is terminated abruptly. + """ try: self.socket.sendall(data) except ssl.SSLEOFError as e: raise ConnectionTerminatedError from e - def send_recv_plist(self, data: dict, endianity='>', fmt=plistlib.FMT_XML) -> Any: + def send_recv_plist(self, data: dict, endianity: str = '>', fmt: Enum = plistlib.FMT_XML) -> Any: + """ + Send a plist to the socket and receive a plist response. + + :param data: The dictionary to send as a plist. + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :param fmt: The plist format (e.g., plistlib.FMT_XML). + :return: The received plist as a dictionary. + """ self.send_plist(data, endianity=endianity, fmt=fmt) return self.recv_plist(endianity=endianity) - async def aio_send_recv_plist(self, data: dict, endianity='>', fmt=plistlib.FMT_XML) -> Any: + async def aio_send_recv_plist(self, data: dict, endianity: str = '>', fmt: Enum = plistlib.FMT_XML) -> Any: + """ + Asynchronously send a plist to the socket and receive a plist response. + + :param data: The dictionary to send as a plist. + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :param fmt: The plist format (e.g., plistlib.FMT_XML). + :return: The received plist as a dictionary. + """ await self.aio_send_plist(data, endianity=endianity, fmt=fmt) return await self.aio_recv_plist(endianity=endianity) def recvall(self, size: int) -> bytes: + """ + Receive all data of a specified size from the socket. + + :param size: The amount of data to receive. + :return: The received data. + :raises ConnectionAbortedError: If the connection is aborted. + """ data = b'' while len(data) < size: chunk = self.recv(size - len(data)) @@ -142,8 +231,13 @@ def recvall(self, size: int) -> bytes: data += chunk return data - def recv_prefixed(self, endianity='>') -> bytes: - """ receive a data block prefixed with a u32 length field """ + def recv_prefixed(self, endianity: str = '>') -> bytes: + """ + Receive a data block prefixed with a length field. + + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :return: The received data block. + """ size = self.recvall(4) if not size or len(size) != 4: return b'' @@ -156,43 +250,100 @@ def recv_prefixed(self, endianity='>') -> bytes: time.sleep(0) async def aio_recvall(self, size: int) -> bytes: - """ receive a payload """ + """ + Asynchronously receive data of a specified size from the socket. + + :param size: The amount of data to receive. + :return: The received data. + """ return await self.reader.readexactly(size) - async def aio_recv_prefixed(self, endianity='>') -> bytes: - """ receive a data block prefixed with a u32 length field """ + async def aio_recv_prefixed(self, endianity: str = '>') -> bytes: + """ + Asynchronously receive a data block prefixed with a length field. + + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :return: The received data block. + """ size = await self.aio_recvall(4) size = struct.unpack(endianity + 'L', size)[0] return await self.aio_recvall(size) def send_prefixed(self, data: bytes) -> None: - """ send a data block prefixed with a u32 length field """ + """ + Send a data block prefixed with a length field. + + :param data: The data to send. + """ if isinstance(data, str): data = data.encode() hdr = struct.pack('>L', len(data)) msg = b''.join([hdr, data]) return self.sendall(msg) - def recv_plist(self, endianity='>') -> dict: + def recv_plist(self, endianity: str = '>') -> dict: + """ + Receive a plist from the socket and parse it into a dictionary. + + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :return: The received plist as a dictionary. + """ return parse_plist(self.recv_prefixed(endianity=endianity)) - async def aio_recv_plist(self, endianity='>') -> dict: + async def aio_recv_plist(self, endianity: str = '>') -> dict: + """ + Asynchronously receive a plist from the socket and parse it into a dictionary. + + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :return: The received plist as a dictionary. + """ return parse_plist(await self.aio_recv_prefixed(endianity)) - def send_plist(self, d, endianity='>', fmt=plistlib.FMT_XML) -> None: + def send_plist(self, d: dict, endianity: str = '>', fmt: Enum = plistlib.FMT_XML) -> None: + """ + Send a dictionary as a plist to the socket. + + :param d: The dictionary to send. + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :param fmt: The plist format (e.g., plistlib.FMT_XML). + """ return self.sendall(build_plist(d, endianity, fmt)) async def aio_sendall(self, payload: bytes) -> None: + """ + Asynchronously send data to the socket. + + :param payload: The data to send. + """ self.writer.write(payload) await self.writer.drain() async def aio_send_plist(self, d: dict, endianity: str = '>', fmt: Enum = plistlib.FMT_XML) -> None: + """ + Asynchronously send a dictionary as a plist to the socket. + + :param d: The dictionary to send. + :param endianity: The byte order ('>' for big-endian, '<' for little-endian). + :param fmt: The plist format (e.g., plistlib.FMT_XML). + """ await self.aio_sendall(build_plist(d, endianity, fmt)) - def ssl_start(self, certfile, keyfile=None) -> None: + def ssl_start(self, certfile: str, keyfile: Optional[str] = None) -> None: + """ + Start an SSL connection. + + :param certfile: The path to the certificate file. + :param keyfile: The path to the key file (optional). + """ self.socket = create_context(certfile, keyfile=keyfile).wrap_socket(self.socket) - async def aio_ssl_start(self, certfile, keyfile=None) -> None: + async def aio_ssl_start(self, certfile: str, keyfile: Optional[str] = None) -> None: + """ + Asynchronously start an SSL connection. + + :param certfile: The path to the certificate file. + :param keyfile: The path to the key file (optional). + """ self.reader, self.writer = await asyncio.open_connection( sock=self.socket, ssl=create_context(certfile, keyfile=keyfile), @@ -200,9 +351,11 @@ async def aio_ssl_start(self, certfile, keyfile=None) -> None: ) async def aio_start(self) -> None: + """ Asynchronously start a connection. """ self.reader, self.writer = await asyncio.open_connection(sock=self.socket) def shell(self) -> None: + """ Start an interactive shell. """ IPython.embed( header=highlight(SHELL_USAGE, lexers.PythonLexer(), formatters.Terminal256Formatter(style='native')), user_ns={ @@ -210,13 +363,29 @@ def shell(self) -> None: }) def read(self, size: int) -> bytes: + """ + Read data from the socket. + + :param size: The amount of data to read. + :return: The read data. + """ result = self.recvall(size) self._offset += size return result def write(self, data: bytes) -> None: + """ + Write data to the socket. + + :param data: The data to write. + """ self.sendall(data) self._offset += len(data) def tell(self) -> int: + """ + Get the current offset. + + :return: The current offset. + """ return self._offset