|
44 | 44 | import ssl |
45 | 45 | import uuid |
46 | 46 | import warnings |
| 47 | +import re |
| 48 | +import time |
| 49 | +import signal |
| 50 | +import select |
| 51 | +import sys |
47 | 52 | from collections import deque |
48 | 53 | from struct import unpack |
49 | 54 |
|
@@ -303,6 +308,13 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: |
303 | 308 | self.address_list = _AddressList(self.options['host'], self.options['port'], |
304 | 309 | self.options['backup_server_node'], self._logger) |
305 | 310 |
|
| 311 | + # TOTP support |
| 312 | + self.totp = self.options.get('totp') |
| 313 | + if self.totp is not None: |
| 314 | + if not isinstance(self.totp, str): |
| 315 | + raise TypeError('The value of connection option "totp" should be a string') |
| 316 | + self._logger.info('TOTP received in connection options') |
| 317 | + |
306 | 318 | # OAuth authentication setup |
307 | 319 | self.options.setdefault('oauth_access_token', DEFAULT_OAUTH_ACCESS_TOKEN) |
308 | 320 | if not isinstance(self.options['oauth_access_token'], str): |
@@ -918,16 +930,112 @@ def startup_connection(self) -> None: |
918 | 930 | else: |
919 | 931 | auth_category = '' |
920 | 932 |
|
921 | | - self.write(messages.Startup(user, database, session_label, os_user_name, autocommit, binary_transfer, |
922 | | - request_complex_types, oauth_access_token, workload, auth_category)) |
| 933 | + # Check if user has provided TOTP in options |
| 934 | + totp = self.options.get("totp", None) |
| 935 | + retried_totp = False |
| 936 | + |
| 937 | + def send_startup(totp_value=None): |
| 938 | + self.write(messages.Startup( |
| 939 | + user, database, session_label, os_user_name, |
| 940 | + autocommit, binary_transfer, request_complex_types, |
| 941 | + oauth_access_token, workload, auth_category, |
| 942 | + totp_value |
| 943 | + )) |
| 944 | + |
923 | 945 |
|
| 946 | + send_startup(totp_value=totp) # ✅ First attempt |
924 | 947 | while True: |
925 | 948 | message = self.read_message() |
926 | | - |
| 949 | + self._logger.debug(f"Received message: {type(message).__name__}") |
| 950 | + self._logger.debug(f"Message code: {getattr(message, 'code', None)}") |
927 | 951 | if isinstance(message, messages.Authentication): |
928 | 952 | if message.code == messages.Authentication.OK: |
929 | 953 | self._logger.info("User {} successfully authenticated" |
930 | 954 | .format(self.options['user'])) |
| 955 | + # 🔁 Continue reading messages after successful authentication |
| 956 | + while True: |
| 957 | + message = self.read_message() |
| 958 | + self._logger.debug(f"Post-auth message: {type(message).__name__}") |
| 959 | + if isinstance(message, messages.ReadyForQuery): |
| 960 | + self.transaction_status = message.transaction_status |
| 961 | + # self.session_id = message.session_id |
| 962 | + self._logger.info("Connection is ready") |
| 963 | + break |
| 964 | + elif isinstance(message, messages.ParameterStatus): |
| 965 | + self.parameters[message.key] = message.value |
| 966 | + elif isinstance(message, messages.BackendKeyData): |
| 967 | + self.backend_pid = message.pid |
| 968 | + self.backend_key = message.key |
| 969 | + elif isinstance(message, messages.ErrorResponse): |
| 970 | + error_msg = message.error_message() |
| 971 | + |
| 972 | + # Extract only the "Message: ..." part |
| 973 | + match = re.search(r'Message: (.+?)(?:, Sqlstate|$)', error_msg, re.DOTALL) |
| 974 | + short_msg = match.group(1).strip() if match else error_msg.strip() |
| 975 | + |
| 976 | + if "Invalid TOTP" in short_msg: |
| 977 | + print("Authentication failed: Invalid TOTP token.") |
| 978 | + self._logger.error("Authentication failed: Invalid TOTP token.") |
| 979 | + self.close_socket() |
| 980 | + raise errors.ConnectionError("Authentication failed: Invalid TOTP token.") |
| 981 | + |
| 982 | + # Generic error fallback |
| 983 | + print(f"Authentication failed: {short_msg}") |
| 984 | + self._logger.error(short_msg) |
| 985 | + raise errors.ConnectionError(f"Authentication failed: {short_msg}") |
| 986 | + else: |
| 987 | + self._logger.warning(f"Unexpected message type: {type(message).__name__}") |
| 988 | + |
| 989 | + break |
| 990 | + elif message.code == messages.Authentication.TOTP: |
| 991 | + if retried_totp: |
| 992 | + raise errors.ConnectionError("TOTP authentication failed.") |
| 993 | + |
| 994 | + # ✅ If TOTP not provided initially, prompt only once |
| 995 | + if not totp: |
| 996 | + timeout_seconds = 30 # 5 minutes timeout |
| 997 | + try: |
| 998 | + print("Enter TOTP: ", end="", flush=True) |
| 999 | + ready, _, _ = select.select([sys.stdin], [], [], timeout_seconds) |
| 1000 | + if ready: |
| 1001 | + totp_input = sys.stdin.readline().strip() |
| 1002 | + |
| 1003 | + # ❌ Blank TOTP entered |
| 1004 | + if not totp_input: |
| 1005 | + self._logger.error("Invalid TOTP: Cannot be empty.") |
| 1006 | + raise errors.ConnectionError("Invalid TOTP: Cannot be empty.") |
| 1007 | + |
| 1008 | + # ❌ Validate TOTP format (must be 6 digits) |
| 1009 | + if not totp_input.isdigit() or len(totp_input) != 6: |
| 1010 | + print("Invalid TOTP format. Please enter a 6-digit code.") |
| 1011 | + self._logger.error("Invalid TOTP format entered.") |
| 1012 | + raise errors.ConnectionError("Invalid TOTP format: Must be a 6-digit number.") |
| 1013 | + # ✅ Valid TOTP — retry connection |
| 1014 | + totp = totp_input |
| 1015 | + self.close_socket() |
| 1016 | + self.socket = self.establish_socket_connection(self.address_list) |
| 1017 | + self._logger.info(f"Retrying with TOTP: '{totp}'") |
| 1018 | + |
| 1019 | + # ✅ Re-init required attributes |
| 1020 | + self.backend_pid = 0 |
| 1021 | + self.backend_key = 0 |
| 1022 | + self.transaction_status = None |
| 1023 | + self.session_id = None |
| 1024 | + |
| 1025 | + self._logger.debug("Startup message sent with TOTP.") |
| 1026 | + send_startup(totp_value=totp) |
| 1027 | + |
| 1028 | + else: |
| 1029 | + self._logger.error("Session timeout: No TOTP entered within time limit.") |
| 1030 | + self.close_socket() |
| 1031 | + raise errors.ConnectionError("Session timeout: No TOTP entered within time limit.") |
| 1032 | + except (KeyboardInterrupt, EOFError): |
| 1033 | + raise errors.ConnectionError("TOTP input cancelled.") |
| 1034 | + else: |
| 1035 | + raise errors.ConnectionError("TOTP was requested but not provided.") |
| 1036 | + retried_totp = True |
| 1037 | + continue |
| 1038 | + |
931 | 1039 | elif message.code == messages.Authentication.CHANGE_PASSWORD: |
932 | 1040 | msg = "The password for user {} has expired".format(self.options['user']) |
933 | 1041 | self._logger.error(msg) |
|
0 commit comments