From ae93feb3898b74f31f42061581a47d30a8cf7a2f Mon Sep 17 00:00:00 2001 From: Siting Ren Date: Wed, 7 Feb 2024 22:11:46 +0800 Subject: [PATCH] OAuth token refresh implementation (#5) --- .github/workflows/ci.yaml | 27 ++-- README.md | 52 +++++++- requirements-dev.txt | 1 + setup.py | 3 +- vertica_python/__init__.py | 4 +- vertica_python/errors.py | 8 ++ vertica_python/tests/common/base.py | 5 + .../tests/common/vp_test.conf.example | 5 + .../tests/integration_tests/base.py | 10 +- .../integration_tests/test_authentication.py | 73 ++++++++++- .../tests/unit_tests/test_parsedsn.py | 2 + vertica_python/vertica/connection.py | 63 +++++++-- .../backend_messages/authentication.py | 25 +++- vertica_python/vertica/oauth_manager.py | 123 ++++++++++++++++++ 14 files changed, 367 insertions(+), 34 deletions(-) create mode 100644 vertica_python/vertica/oauth_manager.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 770d89d1..b903c697 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -8,13 +8,20 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.9'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.10'] + + env: + REALM: test + USER: oauth_user + PASSWORD: password + CLIENT_ID: vertica + CLIENT_SECRET: P9f8350QQIUhFfK1GF5sMhq4Dm3P6Sbs steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Set up a Keycloak docker container @@ -47,12 +54,6 @@ jobs: echo "Wait for keycloak ready ..." 
bash -c 'while true; do curl -s localhost:8080 &>/dev/null; ret=$?; [[ $ret -eq 0 ]] && break; echo "..."; sleep 3; done' - REALM="test" - USER="oauth_user" - PASSWORD="password" - CLIENT_ID="vertica" - CLIENT_SECRET="P9f8350QQIUhFfK1GF5sMhq4Dm3P6Sbs" - docker exec -i keycloak /bin/bash < access_token.txt + cat oauth.json | python3 -c 'import json,sys;obj=json.load(sys.stdin);print(obj["refresh_token"])' > refresh_token.txt docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "CREATE AUTHENTICATION v_oauth METHOD 'oauth' HOST '0.0.0.0/0';" docker exec -u dbadmin vertica_docker /opt/vertica/bin/vsql -c "ALTER AUTHENTICATION v_oauth SET client_id = '${CLIENT_ID}';" @@ -95,5 +97,10 @@ jobs: run: | export VP_TEST_USER=dbadmin export VP_TEST_OAUTH_ACCESS_TOKEN=`cat access_token.txt` - export VP_TEST_OAUTH_USER=oauth_user + export VP_TEST_OAUTH_REFRESH_TOKEN=`cat refresh_token.txt` + export VP_TEST_OAUTH_USER=${USER} + export VP_TEST_OAUTH_CLIENT_ID=${CLIENT_ID} + export VP_TEST_OAUTH_CLIENT_SECRET=${CLIENT_SECRET} + export VP_TEST_OAUTH_TOKEN_URL="http://`hostname`:8080/realms/${REALM}/protocol/openid-connect/token" + export VP_TEST_OAUTH_DISCOVERY_URL="http://`hostname`:8080/realms/${REALM}/.well-known/openid-configuration" tox -e py diff --git a/README.md b/README.md index 14d57ede..85c88fd9 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ with vertica_python.connect(**conn_info) as connection: | ------------- | ------------- | | host | The server host of the connection. This can be a host name or an IP address.
**_Default_**: "localhost" | | port | The port of the connection.
**_Default_**: 5433 | -| user | The database user name to use to connect to the database.
**_Default_**: OS login user name | +| user | The database user name to use to connect to the database.
**_Default_**:
    (for non-OAuth connections) OS login user name
    (for OAuth connections) "" | | password | The password to use to log into the database.
**_Default_**: "" | | database | The database name.
**_Default_**: "" | | autocommit | See [Autocommit](#autocommit).
**_Default_**: False | @@ -103,7 +103,9 @@ with vertica_python.connect(**conn_info) as connection: | kerberos_service_name | See [Kerberos Authentication](#kerberos-authentication).
**_Default_**: "vertica" | | log_level | See [Logging](#logging). | | log_path | See [Logging](#logging). | -| oauth_access_token | To authenticate via OAuth, provide an OAuth Access Token that authorizes a user to the database.
**_Default_**: "" | +| oauth_access_token | See [OAuth Authentication](#oauth-authentication).
**_Default_**: "" | +| oauth_refresh_token | See [OAuth Authentication](#oauth-authentication).
**_Default_**: "" | +| oauth_config | See [OAuth Authentication](#oauth-authentication).
**_Default_**: {} | | request_complex_types | See [SQL Data conversion to Python objects](#sql-data-conversion-to-python-objects).
**_Default_**: True | | session_label | Sets a label for the connection on the server. This value appears in the client_label column of the _v_monitor.sessions_ system table.
**_Default_**: an auto-generated label with format of `vertica-python-{version}-{random_uuid}` | | ssl | See [TLS/SSL](#tlsssl).
**_Default_**: False (disabled) | @@ -141,7 +143,7 @@ with vertica_python.connect(dsn=connection_str, **additional_info) as conn: ``` #### TLS/SSL -You can pass `True` to `ssl` to enable TLS/SSL connection (Internally [ssl.wrap_socket(sock)](https://docs.python.org/3/library/ssl.html#ssl.wrap_socket) is called). +You can pass `True` to `ssl` to enable TLS/SSL connection (equivalent to TLSMode=require). ```python import vertica_python @@ -258,6 +260,50 @@ with vertica_python.connect(**conn_info) as conn: # do things ``` +#### OAuth Authentication +To authenticate via OAuth, one way is to provide an `oauth_access_token` that authorizes a user to the database. +```python +import vertica_python + +conn_info = {'host': '127.0.0.1', + 'port': 5433, + 'database': 'a_database', + # valid OAuth access token + 'oauth_access_token': 'xxxxxx'} + +with vertica_python.connect(**conn_info) as conn: + # do things +``` +In cases where `oauth_access_token` is not set or introspection fails (e.g. when the access token expires), the client can do a token refresh when both `oauth_refresh_token` and `oauth_config` are set. The client will retrieve a new access token from the identity provider and use it to connect with the database. +```python +import vertica_python + +conn_info = {'host': '127.0.0.1', + 'port': 5433, + 'database': 'a_database', + # OAuth refresh token and configurations + 'oauth_refresh_token': 'xxxxxx', + 'oauth_config': { + 'client_secret': 'wK3SqFbExDS', + 'client_id': 'vertica', + 'token_url': 'https://example.com:8443/realms/master/protocol/openid-connect/token', + } +} + +with vertica_python.connect(**conn_info) as conn: + # do things +``` +The following table lists the `oauth_config` parameters used to configure OAuth token refresh: + +| Parameter | Description | +| ------------- | ------------- | +| client_id | The client ID of the client application registered in the identity provider. 
| +| client_secret | The client secret of the client application registered in the identity provider.| +| token_url | The endpoint to which token refresh requests are sent. The format for this depends on your provider. For examples, see the [Keycloak](https://www.keycloak.org/docs/latest/securing_apps/#token-endpoint) and [Okta](https://developer.okta.com/docs/reference/api/oidc/#token) documentation.| +| discovery_url | Also known as the [OpenID Provider Configuration Document](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest), this endpoint contains a list of all other endpoints supported by the identity provider. If set, *token_url* does not need to be specified.
If you set both *discovery_url* and *token_url*, then *token_url* takes precedence.| +| scope | The requested OAuth scopes, delimited with spaces. These scopes define the extent of access to the resource server (in this case, Vertica) granted to the client by the access token. For details, see the [OAuth documentation](https://www.oauth.com/oauth2-servers/scope/defining-scopes/). | + + #### Logging Logging is disabled by default if neither ```log_level``` or ```log_path``` are set. Passing value to at least one of those options to enable logging. diff --git a/requirements-dev.txt b/requirements-dev.txt index 64507bf0..29741ea4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,5 +2,6 @@ pytest pytest-timeout python-dateutil six +requests tox #kerberos diff --git a/setup.py b/setup.py index b013717e..52f1c443 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,8 @@ python_requires=">=3.7", install_requires=[ 'python-dateutil>=1.5', - 'six>=1.10.0' + 'six>=1.10.0', + 'requests>=2.26.0' ], classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/vertica_python/__init__.py b/vertica_python/__init__.py index 80240163..55813e18 100644 --- a/vertica_python/__init__.py +++ b/vertica_python/__init__.py @@ -59,8 +59,8 @@ version_info = (1, 3, 8) __version__ = '.'.join(map(str, version_info)) -# The protocol version (3.15) implemented in this library. -PROTOCOL_VERSION = 3 << 16 | 15 +# The protocol version (3.16) implemented in this library. +PROTOCOL_VERSION = 3 << 16 | 16 apilevel = 2.0 threadsafety = 1 # Threads may share the module, but not connections! 
diff --git a/vertica_python/errors.py b/vertica_python/errors.py index 26bc3f64..ff83faa1 100644 --- a/vertica_python/errors.py +++ b/vertica_python/errors.py @@ -103,6 +103,14 @@ class KerberosError(ConnectionError): class SSLNotSupported(ConnectionError): pass +class OAuthConfigurationError(ConnectionError): + pass + +class OAuthEndpointDiscoveryError(ConnectionError): + pass + +class OAuthTokenRefreshError(ConnectionError): + pass class MessageError(InternalError): pass diff --git a/vertica_python/tests/common/base.py b/vertica_python/tests/common/base.py index 11d8ec43..8b4460d8 100644 --- a/vertica_python/tests/common/base.py +++ b/vertica_python/tests/common/base.py @@ -56,6 +56,11 @@ 'password': '', 'database': '', 'oauth_access_token': '', + 'oauth_refresh_token': '', + 'oauth_client_id': '', + 'oauth_client_secret': '', + 'oauth_token_url': '', + 'oauth_discovery_url': '', 'oauth_user': '', } diff --git a/vertica_python/tests/common/vp_test.conf.example b/vertica_python/tests/common/vp_test.conf.example index 9ca6d174..9181de3c 100644 --- a/vertica_python/tests/common/vp_test.conf.example +++ b/vertica_python/tests/common/vp_test.conf.example @@ -18,4 +18,9 @@ VP_TEST_LOG_DIR=mylog/vp_tox_tests_log # OAuth authentication information #VP_TEST_OAUTH_USER= #VP_TEST_OAUTH_ACCESS_TOKEN=****** +#VP_TEST_OAUTH_REFRESH_TOKEN=****** +#VP_TEST_OAUTH_CLIENT_ID=vertica +#VP_TEST_OAUTH_CLIENT_SECRET=****** +#VP_TEST_OAUTH_TOKEN_URL=http://hostname:8080/realms/test/protocol/openid-connect/token +#VP_TEST_OAUTH_DISCOVERY_URL=http://hostname:8080/realms/test/.well-known/openid-configuration diff --git a/vertica_python/tests/integration_tests/base.py b/vertica_python/tests/integration_tests/base.py index 64067ec4..022f223d 100644 --- a/vertica_python/tests/integration_tests/base.py +++ b/vertica_python/tests/integration_tests/base.py @@ -55,7 +55,10 @@ class VerticaPythonIntegrationTestCase(VerticaPythonTestCase): def setUpClass(cls): config_list = ['log_dir', 'log_level', 
'host', 'port', 'user', 'password', 'database', - 'oauth_access_token', 'oauth_user',] + 'oauth_access_token', 'oauth_refresh_token', + 'oauth_client_secret', 'oauth_client_id', + 'oauth_token_url', 'oauth_discovery_url', + 'oauth_user',] cls.test_config = cls._load_test_config(config_list) # Test logger @@ -76,6 +79,11 @@ def setUpClass(cls): } cls._oauth_info = { 'access_token': cls.test_config['oauth_access_token'], + 'refresh_token': cls.test_config['oauth_refresh_token'], + 'client_secret': cls.test_config['oauth_client_secret'], + 'client_id': cls.test_config['oauth_client_id'], + 'token_url': cls.test_config['oauth_token_url'], + 'discovery_url': cls.test_config['oauth_discovery_url'], 'user': cls.test_config['oauth_user'], } cls.db_node_num = cls._get_node_num() diff --git a/vertica_python/tests/integration_tests/test_authentication.py b/vertica_python/tests/integration_tests/test_authentication.py index 5df8f9a6..1ecfc242 100644 --- a/vertica_python/tests/integration_tests/test_authentication.py +++ b/vertica_python/tests/integration_tests/test_authentication.py @@ -15,6 +15,7 @@ from __future__ import print_function, division, absolute_import, annotations from .base import VerticaPythonIntegrationTestCase +from ...errors import OAuthTokenRefreshError class AuthenticationTestCase(VerticaPythonIntegrationTestCase): @@ -28,6 +29,10 @@ def tearDown(self): self._conn_info['password'] = self._password if 'oauth_access_token' in self._conn_info: del self._conn_info['oauth_access_token'] + if 'oauth_refresh_token' in self._conn_info: + del self._conn_info['oauth_refresh_token'] + if 'oauth_config' in self._conn_info: + del self._conn_info['oauth_config'] super(AuthenticationTestCase, self).tearDown() def test_SHA512(self): @@ -109,10 +114,12 @@ def test_password_expire(self): cur.execute("DROP AUTHENTICATION IF EXISTS testIPv6hostHash CASCADE") cur.execute("DROP AUTHENTICATION IF EXISTS testlocalHash CASCADE") - def test_oauth(self): + def 
test_oauth_access_token(self): self.require_protocol_at_least(3 << 16 | 11) if not self._oauth_info['access_token']: - self.skipTest('OAuth not set') + self.skipTest('OAuth access token not set') + if not self._oauth_info['user'] and not self._conn_info['database']: + self.skipTest('Both database and oauth_user are not set') self._conn_info['user'] = self._oauth_info['user'] self._conn_info['oauth_access_token'] = self._oauth_info['access_token'] @@ -122,5 +129,67 @@ def test_oauth(self): res = cur.fetchone() self.assertEqual(res[0], 'OAuth') + def _test_oauth_refresh(self, access_token): + self.require_protocol_at_least(3 << 16 | 11) + if not self._oauth_info['refresh_token']: + self.skipTest('OAuth refresh token not set') + if not (self._oauth_info['client_secret'] and self._oauth_info['client_id'] and self._oauth_info['token_url']): + self.skipTest('One or more OAuth config (client_id, client_secret, token_url) not set') + if not self._oauth_info['user'] and not self._conn_info['database']: + self.skipTest('Both database and oauth_user are not set') + + if access_token is not None: + self._conn_info['oauth_access_token'] = access_token + self._conn_info['user'] = self._oauth_info['user'] + self._conn_info['oauth_refresh_token'] = self._oauth_info['refresh_token'] + self._conn_info['oauth_config'] = { + 'client_secret': self._oauth_info['client_secret'], + 'client_id': self._oauth_info['client_id'], + 'token_url': self._oauth_info['token_url'], + } + with self._connect() as conn: + cur = conn.cursor() + cur.execute("SELECT authentication_method FROM sessions WHERE session_id=(SELECT current_session())") + res = cur.fetchone() + self.assertEqual(res[0], 'OAuth') + + def test_oauth_token_refresh_with_access_token_not_set(self): + self._test_oauth_refresh(access_token=None) + + def test_oauth_token_refresh_with_invalid_access_token(self): + self._test_oauth_refresh(access_token='invalid_value') + + def test_oauth_token_refresh_with_empty_access_token(self): + 
self._test_oauth_refresh(access_token='') + + def test_oauth_token_refresh_with_discovery_url(self): + self.require_protocol_at_least(3 << 16 | 11) + if not self._oauth_info['refresh_token']: + self.skipTest('OAuth refresh token not set') + if not (self._oauth_info['client_secret'] and self._oauth_info['client_id'] and self._oauth_info['discovery_url']): + self.skipTest('One or more OAuth config (client_id, client_secret, discovery_url) not set') + if not self._oauth_info['user'] and not self._conn_info['database']: + self.skipTest('Both database and oauth_user are not set') + + self._conn_info['user'] = self._oauth_info['user'] + self._conn_info['oauth_refresh_token'] = self._oauth_info['refresh_token'] + msg = 'Token URL or Discovery URL must be set.' + self.assertConnectionFail(err_type=OAuthTokenRefreshError, err_msg=msg) + + self._conn_info['oauth_config'] = { + 'client_secret': self._oauth_info['client_secret'], + 'client_id': self._oauth_info['client_id'], + 'discovery_url': self._oauth_info['discovery_url'], + } + with self._connect() as conn: + cur = conn.cursor() + cur.execute("SELECT authentication_method FROM sessions WHERE session_id=(SELECT current_session())") + res = cur.fetchone() + self.assertEqual(res[0], 'OAuth') + + # Token URL takes precedence over Discovery URL + self._conn_info['oauth_config']['token_url'] = 'invalid_value' + self.assertConnectionFail(err_type=OAuthTokenRefreshError, err_msg='Failed getting OAuth access token from a refresh token.') + exec(AuthenticationTestCase.createPrepStmtClass()) diff --git a/vertica_python/tests/unit_tests/test_parsedsn.py b/vertica_python/tests/unit_tests/test_parsedsn.py index dda09f8f..de13e7aa 100644 --- a/vertica_python/tests/unit_tests/test_parsedsn.py +++ b/vertica_python/tests/unit_tests/test_parsedsn.py @@ -42,6 +42,7 @@ def test_str_arguments(self): 'session_label=vpclient&unicode_error=strict&' 'log_path=/home/admin/vClient.log&log_level=DEBUG&' 'oauth_access_token=GciOiJSUzI1NiI&' + 
'oauth_refresh_token=1WS5TLhonGfARN5&' 'workload=python_test_workload&' 'kerberos_service_name=krb_service&kerberos_host_name=krb_host') expected = {'database': 'db1', 'host': 'localhost', 'user': 'john', @@ -49,6 +50,7 @@ def test_str_arguments(self): 'session_label': 'vpclient', 'unicode_error': 'strict', 'log_path': '/home/admin/vClient.log', 'oauth_access_token': 'GciOiJSUzI1NiI', + 'oauth_refresh_token': '1WS5TLhonGfARN5', 'workload': 'python_test_workload', 'kerberos_service_name': 'krb_service', 'kerberos_host_name': 'krb_host'} diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index d9aededb..6cfe207f 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -57,6 +57,7 @@ from .. import errors from ..vertica import messages from ..vertica.cursor import Cursor +from ..vertica.oauth_manager import OAuthManager from ..vertica.messages.message import BackendMessage, FrontendMessage from ..vertica.messages.frontend_messages import CancelRequest from ..vertica.log import VerticaLogging @@ -73,6 +74,7 @@ DEFAULT_BINARY_TRANSFER = False DEFAULT_REQUEST_COMPLEX_TYPES = True DEFAULT_OAUTH_ACCESS_TOKEN = '' +DEFAULT_OAUTH_REFRESH_TOKEN = '' DEFAULT_WORKLOAD = '' try: DEFAULT_USER = getpass.getuser() @@ -265,6 +267,7 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self.transaction_status = None self.socket = None self.socket_as_file = None + self.oauth_manager = None options = options or {} self.options = parse_dsn(options['dsn']) if 'dsn' in options else {} @@ -286,16 +289,8 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self.options.setdefault('host', DEFAULT_HOST) self.options.setdefault('port', DEFAULT_PORT) - if 'user' not in self.options: - if DEFAULT_USER: - self.options['user'] = DEFAULT_USER - else: - msg = 'Connection option "user" is required' - self._logger.error(msg) - raise KeyError(msg) self.options.setdefault('database', 
DEFAULT_DATABASE) self.options.setdefault('password', DEFAULT_PASSWORD) - self.options.setdefault('oauth_access_token', DEFAULT_OAUTH_ACCESS_TOKEN) self.options.setdefault('autocommit', DEFAULT_AUTOCOMMIT) self.options.setdefault('session_label', _generate_session_label()) self.options.setdefault('backup_server_node', DEFAULT_BACKUP_SERVER_NODE) @@ -309,6 +304,28 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self.address_list = _AddressList(self.options['host'], self.options['port'], self.options['backup_server_node'], self._logger) + # OAuth authentication setup + self.options.setdefault('oauth_access_token', DEFAULT_OAUTH_ACCESS_TOKEN) + self.options.setdefault('oauth_refresh_token', DEFAULT_OAUTH_REFRESH_TOKEN) + for option in ('oauth_access_token', 'oauth_refresh_token'): + if not isinstance(self.options[option], str): + raise TypeError(f'The value of connection option "{option}" should be a str') + self.oauth_access_token = self.options['oauth_access_token'] + if len(self.options['oauth_refresh_token']) > 0: + self.oauth_manager = OAuthManager(self.options['oauth_refresh_token']) + self.oauth_manager.set_config(self.options.get('oauth_config', {})) + + # user is required for non-OAuth connections + if 'user' not in self.options: + if len(self.oauth_access_token) > 0 or len(self.options['oauth_refresh_token']) > 0: + self.options['user'] = '' + elif DEFAULT_USER: + self.options['user'] = DEFAULT_USER + else: + msg = 'Connection option "user" is required' + self._logger.error(msg) + raise KeyError(msg) + # we only support one cursor per connection self.options.setdefault('unicode_error', None) self._cursor = Cursor(self, self._logger, cursor_type=None, @@ -337,7 +354,11 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self._logger.info('Connecting as user "{}" to database "{}" on host "{}" with port {}'.format( self.options['user'], self.options['database'], self.options['host'], self.options['port'])) - 
self.startup_connection() + + while True: + need_retry = self.startup_connection() + if not need_retry: + break # Complex types metadata is returned since protocol version 3.12 self.complex_types_enabled = self.parameters.get('protocol_version', 0) >= (3 << 16 | 12) and \ @@ -854,7 +875,7 @@ def make_GSS_authentication(self) -> None: self._logger.error(msg) raise errors.KerberosError(msg) - def startup_connection(self) -> None: + def startup_connection(self) -> bool: user = self.options['user'] database = self.options['database'] session_label = self.options['session_label'] @@ -863,9 +884,8 @@ def startup_connection(self) -> None: autocommit = self.options['autocommit'] binary_transfer = self.options['binary_transfer'] request_complex_types = self.options['request_complex_types'] - oauth_access_token = self.options['oauth_access_token'] workload = self.options['workload'] - if len(oauth_access_token) > 0: + if len(self.oauth_access_token) > 0 or len(self.options['oauth_refresh_token']) > 0: auth_category = 'OAuth' elif self.kerberos_is_set: auth_category = 'Kerberos' @@ -875,7 +895,7 @@ def startup_connection(self) -> None: auth_category = '' self.write(messages.Startup(user, database, session_label, os_user_name, autocommit, binary_transfer, - request_complex_types, oauth_access_token, workload, auth_category)) + request_complex_types, self.oauth_access_token, workload, auth_category)) while True: message = self.read_message() @@ -895,7 +915,13 @@ def startup_connection(self) -> None: elif message.code == messages.Authentication.GSS: self.make_GSS_authentication() elif message.code == messages.Authentication.OAUTH: - self.write(messages.Password(oauth_access_token, message.code)) + if self.oauth_manager: + self.oauth_manager.set_config(message.config, not_set_only=True) + # If access token is not set, will attempt to set a new one by using token refresh + if len(self.oauth_access_token) == 0 and self.oauth_manager and not self.oauth_manager.refresh_attempted: + 
self._logger.info("Issuing an OAuth access token using a refresh token") + self.oauth_access_token = self.oauth_manager.do_token_refresh() + self.write(messages.Password(self.oauth_access_token, message.code)) else: self.write(messages.Password(password, message.code, {'user': user, @@ -908,8 +934,17 @@ def startup_connection(self) -> None: break elif isinstance(message, messages.ErrorResponse): self._logger.error(message.error_message()) + # If this is an OAuth connection and the first connection failed, refresh the access token and try again + if message.sqlstate == '28000' and self.oauth_manager and not self.oauth_manager.refresh_attempted: + if message.error_code in ('2248', '3781', '4131'): + raise errors.ConnectionError("Did not receive proper OAuth Authentication response from server. Please upgrade to the latest Vertica server for OAuth Support.") + self.close_socket() + self._logger.info("Issuing a new OAuth access token using a refresh token") + self.oauth_access_token = self.oauth_manager.do_token_refresh() + return True raise errors.ConnectionError(message.error_message()) else: msg = "Received unexpected startup message: {0}".format(message) self._logger.error(msg) raise errors.MessageError(msg) + return False diff --git a/vertica_python/vertica/messages/backend_messages/authentication.py b/vertica_python/vertica/messages/backend_messages/authentication.py index f1f8cbc0..d125bd8b 100644 --- a/vertica_python/vertica/messages/backend_messages/authentication.py +++ b/vertica_python/vertica/messages/backend_messages/authentication.py @@ -35,7 +35,7 @@ from __future__ import print_function, division, absolute_import, annotations -from struct import unpack +from struct import unpack, unpack_from from ..message import BackendMessage from .... 
import errors @@ -78,6 +78,29 @@ def __init__(self, data): self.usersalt = unpack('!{0}s'.format(userSaltLen), other[8:])[0] elif self.code in [self.GSS_CONTINUE]: self.auth_data = other + elif self.code == self.OAUTH: + self.config = {} + num_of_fields = other.count(b'\x00') + # Since protocol v3.15 + if num_of_fields >= 3: + pos = 0 + auth_url = unpack_from("!{0}sx".format(other.find(b'\x00', pos) - pos), other, pos)[0] + pos += len(auth_url) + 1 + self.config['auth_url'] = auth_url.decode('utf-8') + token_url = unpack_from("!{0}sx".format(other.find(b'\x00', pos) - pos), other, pos)[0] + pos += len(token_url) + 1 + self.config['token_url'] = token_url.decode('utf-8') + client_id = unpack_from("!{0}sx".format(other.find(b'\x00', pos) - pos), other, pos)[0] + pos += len(client_id) + 1 + self.config['client_id'] = client_id.decode('utf-8') + # Since protocol v3.16 + if num_of_fields == 5: + scope = unpack_from("!{0}sx".format(other.find(b'\x00', pos) - pos), other, pos)[0] + pos += len(scope) + 1 + self.config['scope'] = scope.decode('utf-8') + validate_hostname = unpack_from("!{0}sx".format(other.find(b'\x00', pos) - pos), other, pos)[0] + pos += len(validate_hostname) + 1 + self.config['validate_hostname'] = validate_hostname.decode('utf-8') def __str__(self): return "Authentication: type={}".format(self.code) diff --git a/vertica_python/vertica/oauth_manager.py b/vertica_python/vertica/oauth_manager.py new file mode 100644 index 00000000..87e39da9 --- /dev/null +++ b/vertica_python/vertica/oauth_manager.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024 Open Text. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, division, absolute_import, annotations + +import requests +import warnings + +from ..errors import OAuthConfigurationError, OAuthEndpointDiscoveryError, OAuthTokenRefreshError + + +class OAuthManager: + def __init__(self, refresh_token): + self.refresh_token = refresh_token + self.client_id = "" + self.client_secret = "" + self.token_url = "" + self.discovery_url = "" + self.scope = "" + self.validate_cert_hostname = None + self.refresh_attempted = False + + def set_config(self, configs, not_set_only=False) -> None: + valid_keys = {'refresh_token', 'client_id', 'client_secret', 'token_url', 'discovery_url', + 'scope', 'validate_hostname', 'auth_url'} + try: + for k, v in configs.items(): + if k not in valid_keys: + invalid_key = f'Unrecognized OAuth config property: {k}' + warnings.warn(invalid_key) + continue + if v is None or v == "": # ignore empty value + continue + if k == 'refresh_token' and not (not_set_only and self.refresh_token): + self.refresh_token = str(v) + elif k == 'client_id' and not (not_set_only and self.client_id): + self.client_id = str(v) + elif k == 'client_secret' and not (not_set_only and self.client_secret): + self.client_secret = str(v) + elif k == 'token_url' and not (not_set_only and self.token_url): + self.token_url = str(v) + elif k == 'discovery_url' and not (not_set_only and self.discovery_url): + self.discovery_url = str(v) + elif k == 'scope' and not (not_set_only and self.scope): + self.scope = str(v) + elif k == 'validate_hostname' and not (not_set_only and 
self.validate_cert_hostname is not None): + self.validate_cert_hostname = bool(v) + except Exception as e: + raise OAuthConfigurationError('Failed setting OAuth configuration.') from e + + def get_access_token_using_refresh_token(self) -> str: + """Issue a new access token using a valid refresh token.""" + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Cache-Control": "no-cache", + "Pragma": "no-cache", + "Expires": "0", + } + params = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + } + if self.scope: + params["scope"] = self.scope + err_msg = 'Failed getting OAuth access token from a refresh token.' + try: + # TODO handle self.validate_cert_hostname + response = requests.post(self.token_url, headers=headers, data=params, verify=False) + response.raise_for_status() + return response.json()["access_token"] + except requests.exceptions.HTTPError as err: + msg = f'{err_msg}\n{err}\n{response.json()}' + raise OAuthTokenRefreshError(msg) + except Exception as e: + raise OAuthTokenRefreshError(err_msg) from e + + def get_token_url_from_discovery_url(self) -> str: + try: + headers = { + "Cache-Control": "no-cache", + "Pragma": "no-cache", + "Expires": "0", + } + # TODO handle self.validate_cert_hostname + response = requests.get(self.discovery_url, headers=headers, verify=False) + response.raise_for_status() + return response.json()["token_endpoint"] + except Exception as e: + err_msg = 'Failed getting token url from discovery url.' 
+ raise OAuthEndpointDiscoveryError(err_msg) from e + + def do_token_refresh(self) -> str: + self.refresh_attempted = True + + if len(self.token_url) == 0 and len(self.discovery_url) == 0: + raise OAuthTokenRefreshError('Token URL or Discovery URL must be set.') + if len(self.client_id) == 0: + raise OAuthTokenRefreshError('OAuth client id is missing.') + if len(self.client_secret) == 0: + raise OAuthTokenRefreshError('OAuth client secret is missing.') + if len(self.refresh_token) == 0: + raise OAuthTokenRefreshError('OAuth refresh token is missing.') + + # If the token url is not set, get it from the discovery url + if len(self.token_url) == 0: + self.token_url = self.get_token_url_from_discovery_url() + + return self.get_access_token_using_refresh_token() + +