11# Copyright (c) NiceBots
22# SPDX-License-Identifier: MIT
33
4+ import ssl
45from collections import defaultdict
56from logging import getLogger
7+ from pathlib import Path
68from typing import Any
79
810import aerich
1315logger = getLogger ("bot" ).getChild ("database" )
1416
1517
18+ def create_ssl_context () -> ssl .SSLContext :
19+ """Create SSL context for Database connection."""
20+ # Try to find the certificate in common locations
21+ cert_paths = [
22+ Path (__file__ ).parent .parent .parent / ".postgres" / "root.crt" , # Docker/production
23+ Path .cwd () / ".postgres" / "root.crt" , # Local dev
24+ ]
25+
26+ cert_path = None
27+ for path in cert_paths :
28+ if path .exists ():
29+ cert_path = path
30+ break
31+
32+ if cert_path is None :
33+ logger .warning ("Database certificate not found, SSL verification will fail" )
34+ return True # fallback to basic SSL
35+
36+ logger .info (f"Using SSL certificate: { cert_path } " )
37+ ssl_context = ssl .create_default_context (cafile = str (cert_path ))
38+ ssl_context .check_hostname = False
39+ ssl_context .verify_mode = ssl .CERT_REQUIRED
40+ return ssl_context
41+
42+
1643def apply_params (uri : str , params : dict [str , Any ] | None ) -> str :
1744 if params is None :
1845 return uri
1946
2047 first : bool = True
2148 for param , value in params .items ():
49+ # Skip 'ssl' param as we'll handle it separately
50+ if param == "ssl" :
51+ continue
2252 if value is not None :
2353 uri += f"{ '?' if first else '&' } { param } ={ value } "
2454 first = False
@@ -42,26 +72,35 @@ def get_url_apps_mapping() -> dict[str, list[str]]:
4272 return mapping
4373
4474
45- def parse_url_apps_mapping (url_apps_mapping : dict [str , list [str ]]) -> tuple [dict [str , str ], dict [str , str ]]:
75+ def parse_url_apps_mapping (url_apps_mapping : dict [str , list [str ]]) -> tuple [dict [str , str ], dict [str , dict [ str , Any ] ]]:
4676 app_connection : dict [str , str ] = {}
47- connection_url : dict [str , str ] = {}
77+ connection_config : dict [str , dict [str , Any ]] = {}
78+
79+ # Create SSL context once if needed
80+ ssl_context = create_ssl_context () if config .db .params and config .db .params .get ("ssl" ) else None
4881
4982 for i , (url , apps ) in enumerate (url_apps_mapping .items ()):
5083 connection_name = f"connection_{ i } "
51- connection_url [connection_name ] = url
84+ connection_config [connection_name ] = {"engine" : "tortoise.backends.asyncpg" , "credentials" : {"dsn" : url }}
85+ if ssl_context :
86+ connection_config [connection_name ]["credentials" ]["ssl" ] = ssl_context
87+
5288 for app in apps :
5389 app_connection [app ] = connection_name
5490
5591 app_connection ["models" ] = "default"
56- connection_url ["default" ] = apply_params (config .db .url , config .db .params )
92+ default_url = apply_params (config .db .url , config .db .params )
93+ connection_config ["default" ] = {"engine" : "tortoise.backends.asyncpg" , "credentials" : {"dsn" : default_url }}
94+ if ssl_context :
95+ connection_config ["default" ]["credentials" ]["ssl" ] = ssl_context
5796
58- return app_connection , connection_url
97+ return app_connection , connection_config
5998
6099
61100APP_CONNECTION_MAPPING : dict [str , str ]
62- CONNECTION_URL_MAPPING : dict [str , str ]
101+ CONNECTION_CONFIG_MAPPING : dict [str , dict [ str , Any ] ]
63102
64- APP_CONNECTION_MAPPING , CONNECTION_URL_MAPPING = parse_url_apps_mapping (get_url_apps_mapping ()) # pyright: ignore[reportConstantRedefinition]
103+ APP_CONNECTION_MAPPING , CONNECTION_CONFIG_MAPPING = parse_url_apps_mapping (get_url_apps_mapping ())
65104
66105
67106def get_apps () -> dict [str , dict [str , list [str ] | str ]]:
@@ -80,7 +119,7 @@ def get_apps() -> dict[str, dict[str, list[str] | str]]:
80119
81120
82121TORTOISE_ORM = {
83- "connections" : CONNECTION_URL_MAPPING ,
122+ "connections" : CONNECTION_CONFIG_MAPPING ,
84123 "apps" : get_apps (),
85124}
86125
@@ -93,7 +132,7 @@ async def init() -> None:
93132 )
94133 await command .init ()
95134 migrated = await command .upgrade (run_in_transaction = True )
96- logger .success (f"Successfully migrated { migrated } migrations" ) # pyright: ignore [reportAttributeAccessIssue]
135+ logger .success (f"Successfully migrated { migrated } migrations" )
97136 await Tortoise .init (config = TORTOISE_ORM )
98137
99138
0 commit comments