Skip to content

Commit 25b9652

Browse files
committed
refactor: implements tls support for certificate based auth
adds initial support using ssl_context for certificate based auth also adds testsuite to use the ssl context from env varables, see docs for relevant how to use pending: add test suite to run without certificates refs #66
1 parent 2b2081a commit 25b9652

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

gallagher/cc/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
# to obtain an API key
3030
api_key: str = None
3131

32+
# Certificate file to be used for authentication
33+
file_tls_certificate: str = None
34+
35+
# Private key file to be used for authentication
36+
file_private_key: str = None
37+
3238
# By default the base API is set to the Australian Gateway
3339
# Override this with the US gateway or a local DNS/IP address
3440
api_base: str = URL.CLOUD_GATEWAY_AU

gallagher/cc/core.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from http import HTTPStatus # Provides constants for HTTP status codes
2828

29+
import ssl
2930
import httpx
3031

3132
from . import proxy as proxy_address
@@ -232,6 +233,23 @@ async def get_config(cls) -> EndpointConfig:
232233
provide additional configuration options.
233234
"""
234235
raise NotImplementedError("get_config method not implemented")
236+
237+
@classmethod
238+
def _ssl_context(cls):
239+
"""Returns the SSL context for the endpoint
240+
241+
This method can be overridden by the child class to
242+
provide additional SSL context options.
243+
"""
244+
from . import file_tls_certificate, file_private_key
245+
246+
if not file_tls_certificate:
247+
"""TLS certificate is required for SSL context"""
248+
return None
249+
250+
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
251+
context.load_cert_chain(file_tls_certificate, file_private_key)
252+
return context
235253

236254
@classmethod
237255
async def _discover(cls):
@@ -269,7 +287,10 @@ async def _discover(cls):
269287
# be called as part of the bootstrapping process
270288
from . import api_base
271289

272-
async with httpx.AsyncClient(proxy=proxy_address) as _httpx_async:
290+
async with httpx.AsyncClient(
291+
proxy=proxy_address,
292+
verify=cls._ssl_context(),
293+
) as _httpx_async:
273294
# Don't use the _get wrapper here, we need to get the raw response
274295
response = await _httpx_async.get(
275296
api_base,
@@ -490,7 +511,11 @@ async def follow(
490511
# Initial url is set to endpoint_follow
491512
url = f"{cls.__config__.endpoint_follow.href}"
492513

493-
async with httpx.AsyncClient(proxy=proxy_address) as _httpx_async:
514+
async with httpx.AsyncClient(
515+
proxy=proxy_address,
516+
verify=cls._ssl_context(),
517+
) as _httpx_async:
518+
494519
while event.is_set():
495520
try:
496521
response = await _httpx_async.get(
@@ -546,7 +571,10 @@ async def _get(
546571
:param str url: URL to fetch the data from
547572
:param AppBaseModel response_class: DTO to be used for list requests
548573
"""
549-
async with httpx.AsyncClient(proxy=proxy_address) as _httpx_async:
574+
async with httpx.AsyncClient(
575+
proxy=proxy_address,
576+
verify=cls._ssl_context(),
577+
) as _httpx_async:
550578

551579
try:
552580

@@ -594,7 +622,10 @@ async def _post(
594622
The behaviour is very similar to the _get method, except
595623
parsing and sending out a body as part of the request.
596624
"""
597-
async with httpx.AsyncClient(proxy=proxy_address) as _httpx_async:
625+
async with httpx.AsyncClient(
626+
proxy=proxy_address,
627+
verify=cls._ssl_context(),
628+
) as _httpx_async:
598629

599630
try:
600631

tests/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TODO: check if setup and teardown can be turned into async
1313
1414
"""
15+
import tempfile
1516

1617

1718
def setup_module(module):
@@ -25,9 +26,33 @@ def setup_module(module):
2526

2627
api_key = os.environ.get("GACC_API_KEY")
2728

29+
# Read these from the environment variables, if they exists
30+
# they will be written to temporary files
31+
certificate_anomaly = os.environ.get("CERTIFICATE_ANOMALY")
32+
private_key_anomaly = os.environ.get("PRIVATE_KEY_ANOMALY")
33+
34+
# Create temporary files to store the certificate and private key
35+
temp_file_certificate = tempfile.NamedTemporaryFile(
36+
suffix=".crt",
37+
delete=False
38+
)
39+
temp_file_private_key = tempfile.NamedTemporaryFile(
40+
suffix=".key",
41+
delete=False
42+
)
43+
44+
# Write the certificate and private key to temporary files
45+
if certificate_anomaly and temp_file_certificate:
46+
temp_file_certificate.write(certificate_anomaly.encode('utf-8'))
47+
48+
if private_key_anomaly and temp_file_private_key:
49+
temp_file_private_key.write(private_key_anomaly.encode('utf-8'))
50+
2851
from gallagher import cc
2952

3053
cc.api_key = api_key
54+
cc.file_tls_certificate = temp_file_certificate.name
55+
cc.file_private_key = temp_file_private_key.name
3156

3257

3358
def teardown_module(module):

0 commit comments

Comments
 (0)