Skip to content

Commit 6a4e974

Browse files
committed
✨ Add SSL support for database connections and refactor connection configuration handling
1 parent e2e190e commit 6a4e974

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

src/database/config/__init__.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Copyright (c) NiceBots
22
# SPDX-License-Identifier: MIT
33

4+
import ssl
45
from collections import defaultdict
56
from logging import getLogger
7+
from pathlib import Path
68
from typing import Any
79

810
import aerich
@@ -13,12 +15,40 @@
1315
logger = getLogger("bot").getChild("database")
1416

1517

18+
def create_ssl_context() -> ssl.SSLContext:
19+
"""Create SSL context for Database connection."""
20+
# Try to find the certificate in common locations
21+
cert_paths = [
22+
Path(__file__).parent.parent.parent / ".postgres" / "root.crt", # Docker/production
23+
Path.cwd() / ".postgres" / "root.crt", # Local dev
24+
]
25+
26+
cert_path = None
27+
for path in cert_paths:
28+
if path.exists():
29+
cert_path = path
30+
break
31+
32+
if cert_path is None:
33+
logger.warning("Database certificate not found, SSL verification will fail")
34+
return True # fallback to basic SSL
35+
36+
logger.info(f"Using SSL certificate: {cert_path}")
37+
ssl_context = ssl.create_default_context(cafile=str(cert_path))
38+
ssl_context.check_hostname = False
39+
ssl_context.verify_mode = ssl.CERT_REQUIRED
40+
return ssl_context
41+
42+
1643
def apply_params(uri: str, params: dict[str, Any] | None) -> str:
1744
if params is None:
1845
return uri
1946

2047
first: bool = True
2148
for param, value in params.items():
49+
# Skip 'ssl' param as we'll handle it separately
50+
if param == "ssl":
51+
continue
2252
if value is not None:
2353
uri += f"{'?' if first else '&'}{param}={value}"
2454
first = False
@@ -42,26 +72,35 @@ def get_url_apps_mapping() -> dict[str, list[str]]:
4272
return mapping
4373

4474

45-
def parse_url_apps_mapping(url_apps_mapping: dict[str, list[str]]) -> tuple[dict[str, str], dict[str, str]]:
75+
def parse_url_apps_mapping(url_apps_mapping: dict[str, list[str]]) -> tuple[dict[str, str], dict[str, dict[str, Any]]]:
4676
app_connection: dict[str, str] = {}
47-
connection_url: dict[str, str] = {}
77+
connection_config: dict[str, dict[str, Any]] = {}
78+
79+
# Create SSL context once if needed
80+
ssl_context = create_ssl_context() if config.db.params and config.db.params.get("ssl") else None
4881

4982
for i, (url, apps) in enumerate(url_apps_mapping.items()):
5083
connection_name = f"connection_{i}"
51-
connection_url[connection_name] = url
84+
connection_config[connection_name] = {"engine": "tortoise.backends.asyncpg", "credentials": {"dsn": url}}
85+
if ssl_context:
86+
connection_config[connection_name]["credentials"]["ssl"] = ssl_context
87+
5288
for app in apps:
5389
app_connection[app] = connection_name
5490

5591
app_connection["models"] = "default"
56-
connection_url["default"] = apply_params(config.db.url, config.db.params)
92+
default_url = apply_params(config.db.url, config.db.params)
93+
connection_config["default"] = {"engine": "tortoise.backends.asyncpg", "credentials": {"dsn": default_url}}
94+
if ssl_context:
95+
connection_config["default"]["credentials"]["ssl"] = ssl_context
5796

58-
return app_connection, connection_url
97+
return app_connection, connection_config
5998

6099

61100
APP_CONNECTION_MAPPING: dict[str, str]
62-
CONNECTION_URL_MAPPING: dict[str, str]
101+
CONNECTION_CONFIG_MAPPING: dict[str, dict[str, Any]]
63102

64-
APP_CONNECTION_MAPPING, CONNECTION_URL_MAPPING = parse_url_apps_mapping(get_url_apps_mapping()) # pyright: ignore[reportConstantRedefinition]
103+
APP_CONNECTION_MAPPING, CONNECTION_CONFIG_MAPPING = parse_url_apps_mapping(get_url_apps_mapping())
65104

66105

67106
def get_apps() -> dict[str, dict[str, list[str] | str]]:
@@ -80,7 +119,7 @@ def get_apps() -> dict[str, dict[str, list[str] | str]]:
80119

81120

82121
TORTOISE_ORM = {
83-
"connections": CONNECTION_URL_MAPPING,
122+
"connections": CONNECTION_CONFIG_MAPPING,
84123
"apps": get_apps(),
85124
}
86125

@@ -93,7 +132,7 @@ async def init() -> None:
93132
)
94133
await command.init()
95134
migrated = await command.upgrade(run_in_transaction=True)
96-
logger.success(f"Successfully migrated {migrated} migrations") # pyright: ignore [reportAttributeAccessIssue]
135+
logger.success(f"Successfully migrated {migrated} migrations")
97136
await Tortoise.init(config=TORTOISE_ORM)
98137

99138

0 commit comments

Comments
 (0)