Skip to content

Commit dfbe285

Browse files
committed
cloud sql connect add backup
1 parent d547216 commit dfbe285

File tree

2 files changed

+75
-82
lines changed

2 files changed

+75
-82
lines changed

agent/memory_prototyper.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
knn_search_error_full_with_norm,
1515
update_stats_from_buffer,
1616
maybe_register_successful_fix,
17+
cloud_sql_connect_smart
1718
)
18-
from results import BuildResult
19+
from results import BuildResult, Result
1920
from llm_toolkit.text_embedder import VertexEmbeddingModel
2021

2122

@@ -99,7 +100,22 @@ def __init__(self, *args, **kwargs) -> None:
99100
temperature=0.0 # Required by base __init__, but ignored by logic
100101
)
101102

102-
103+
104+
def _initial_prompt(self, results: list[Result]) -> Prompt:
105+
# we do a DB connection check before prompt, to reduce execution time and resource waste
106+
# to delete in real environment
107+
try:
108+
with cloud_sql_connect_smart() as conn:
109+
with conn.cursor() as cursor:
110+
cursor.execute("SHOW TABLES")
111+
result_sql = cursor.fetchall()
112+
logger.info(f"connection successful, query returned: {result_sql} \n continue to " , trial=results[-1].trial)
113+
except Exception as e:
114+
logger.error(f"SQL connection fail, early abort, connection error message is: {e}", trial=results[-1].trial)
115+
raise RuntimeError("Agent execution aborted: Database is unreachable.") from e
116+
return super()._initial_prompt(results)
117+
118+
103119
def chat_llm(self, *args, **kwargs) -> str:
104120
"""Wrapper around Prototyper.chat_llm that also caches raw responses.
105121

memory_helper/cloudsql.py

Lines changed: 57 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def get_credentials():
110110
if hasattr(credentials, "service_account_email"):
111111
# Note: Some SA creds return 'default' or None, so we check for actual content
112112
if credentials.service_account_email and credentials.service_account_email != "default":
113-
#logger.info(f"[AUTH] found service account email {credentials.service_account_email}", "", trial=-1)
114113
db_user = credentials.service_account_email.split('@')[0]
114+
logger.info (f"[AUTH] found service account {db_user}", trial=1)
115115
return db_user
116116

117117

@@ -138,93 +138,70 @@ def get_credentials():
138138
print(f"Failed to auto-detect email: {e}")
139139
raise
140140

141-
142141
@contextmanager
143-
def mysql_connection_no_plain_text_auth():
142+
def cloud_sql_connect_smart():
144143
"""
145-
Context manager that supports:
146-
1. K8s Sidecar with --auto-iam-authn (No password/token needed in Python)
147-
2. Local/Dev Direct Connection (Uses Python Connector with IAM)
144+
Attempts to connect via Local Proxy first.
145+
If that fails, falls back to Google Python Connector.
148146
"""
149147
global _DB_USER
150148
if _DB_USER is None:
151149
_DB_USER = get_credentials()
152-
# we hard-code proxy's IP address and port in the k8s environment
153-
#host = os.getenv("DB_HOST", "127.0.0.1")
154-
#port = int(os.getenv("DB_PORT", "3306")) # Standard MySQL port
155-
db = os.getenv("DB_NAME", "ofg")
156-
connector = Connector(ip_type=IPTypes.PUBLIC, refresh_strategy="LAZY")
150+
157151
conn = None
158-
# hard-code the default option of using proxy to true, the PR exp runs with proxy for anyways.
159-
USE_PROXY = os.getenv("USE_PROXY", True)
152+
connector = None # Only initialized if we use the fallback method
153+
154+
# ---------------------------------------------------------
155+
# Attempt 1: Primary (Local Proxy)
156+
# ---------------------------------------------------------
160157
try:
158+
conn = pymysql.connect(
159+
host="127.0.0.1",
160+
port=3306,
161+
database='ofg',
162+
user=_DB_USER,
163+
password="",
164+
ssl_disabled=True,
165+
connect_timeout=10
166+
)
167+
except Exception as proxy_e:
168+
_log_info(
169+
f"proxy connection failed: {proxy_e}",trial= 1 )
170+
171+
# ---------------------------------------------------------
172+
# Attempt 2: Fallback (Google Connector)
173+
# ---------------------------------------------------------
161174
try:
162-
if USE_PROXY:
163-
conn = pymysql.connect(
164-
host = "127.0.0.1",
165-
port = 3306,
166-
database=db,
167-
user=_DB_USER,
168-
password="", # Proxy handle auth using IAM
169-
ssl_disabled=True, # Proxy handles encryption
170-
connect_timeout=10
171-
)
172-
else:
173-
conn = connector.connect(
174-
INSTANCE_CONNECTION_NAME,
175-
"pymysql",
176-
user=_DB_USER,
177-
db=db,
178-
enable_iam_auth=True
179-
)
180-
except pymysql.err.OperationalError as e:
181-
# Handle MySQL-layer errors with concise, identity-aware messages
182-
error_code = e.args[0]
183-
184-
if error_code == 1045:
185-
# Access Denied
186-
raise RuntimeError(
187-
f"[SQL 1045] Auth failed for DB User '{_DB_USER}'. "
188-
f"Ensure this specific user is added to Cloud SQL 'Users' (IAM)."
189-
) from e
190-
191-
elif error_code == 1049:
192-
# Database Missing
193-
raise RuntimeError(
194-
f"[SQL 1049] Database '{DB_NAME}' not found on instance."
195-
) from e
196-
197-
elif error_code == 2003:
198-
# Connectivity
199-
raise RuntimeError(
200-
f"[SQL 2003] Failed to reach '{INSTANCE_CONNECTION_NAME}' via Public IP."
201-
) from e
202-
203-
else:
204-
# Fallback for other MySQL errors
205-
raise RuntimeError(f"[SQL Error {error_code}] {e}") from e
206-
207-
except Exception as e:
208-
# Handle Google Cloud API-layer errors
209-
raise RuntimeError(
210-
f"[GCP API Error] Connector failed initialization. "
211-
f"Identity '{_DB_USER}' may lack 'Cloud SQL Client' role. Details: {e}"
212-
) from e
213-
214-
# 4. Success Yield
175+
_log_info("fallback to connect with GSA account", trial= 1)
176+
connector = Connector(ip_type=IPTypes.PUBLIC, refresh_strategy="LAZY")
177+
conn = connector.connect(
178+
INSTANCE_CONNECTION_NAME,
179+
"pymysql",
180+
user=_DB_USER,
181+
db="ofg",
182+
enable_iam_auth=True
183+
)
184+
except Exception as connector_e:
185+
_log_info(f"Fallback connection also failed: {connector_e}", trial= 1)
186+
# If the connector was created but connect() failed, close it.
187+
if connector:
188+
connector.close()
189+
raise connector_e # Raise the final error to the runner
190+
191+
# ---------------------------------------------------------
192+
# Phase 3: Yield & Cleanup
193+
# ---------------------------------------------------------
194+
try:
195+
# Yield the successful connection (from either source) to the inner block
215196
yield conn
216-
217197
finally:
218-
# 5. Cleanup
198+
# Cleanup Connection
219199
if conn:
220-
try:
221-
conn.close()
222-
except Exception:
223-
pass
224-
try:
200+
conn.close()
201+
202+
# Cleanup Connector (Only if it was initialized during fallback)
203+
if connector:
225204
connector.close()
226-
except Exception:
227-
pass
228205

229206

230207
"""
@@ -352,7 +329,7 @@ def knn_search_error(
352329
LIMIT %s
353330
"""
354331
rows: List[Dict[str, Any]] = []
355-
with mysql_connection_no_plain_text_auth() as conn:
332+
with cloud_sql_connect_smart() as conn:
356333
with conn.cursor() as cur:
357334
_log_info("[KNN] Executing simple SQL top_k=%d", top_k, trial=trial)
358335
cur.execute(sql, (vec_str, top_k))
@@ -426,7 +403,7 @@ def knn_search_error_dbg(
426403
LIMIT %s
427404
"""
428405
rows: List[Dict[str, Any]] = []
429-
with mysql_connection_no_plain_text_auth() as conn:
406+
with cloud_sql_connect_smart() as conn:
430407
with conn.cursor() as cur:
431408
_log_info("[KNN-DBG] Executing debug SQL top_k=%d", top_k, trial=trial)
432409
cur.execute(sql, (vec_str, top_k))
@@ -544,7 +521,7 @@ def _knn_search_error_full_core(
544521
"""
545522

546523
rows: List[Dict[str, Any]] = []
547-
with mysql_connection_no_plain_text_auth() as conn:
524+
with cloud_sql_connect_smart() as conn:
548525
with conn.cursor() as cur:
549526
_log_info("[KNN] Executing SQL top_k=%d, conf=%s", top_k, confidence_levels, trial=trial)
550527
# Param order: vector_json, *conf_levels, top_k
@@ -711,7 +688,7 @@ def update_stats_from_buffer(
711688
last_used_at = UTC_TIMESTAMP()
712689
"""
713690

714-
with mysql_connection_no_plain_text_auth() as conn:
691+
with cloud_sql_connect_smart() as conn:
715692
with conn.cursor() as cur:
716693
for entry_id, deltas in sorted(stats_buffer.items()):
717694
# Skip pure-zero deltas to avoid useless writes.
@@ -857,7 +834,7 @@ def maybe_register_successful_fix(
857834
# 3) Deduplicate: check if a very similar entry already exists for this project.
858835
DEDUP_THRESHOLD = 0.04 # tune as needed
859836

860-
with mysql_connection_no_plain_text_auth() as conn:
837+
with cloud_sql_connect_smart() as conn:
861838
with conn.cursor() as cur:
862839
dedup_sql = """
863840
SELECT

0 commit comments

Comments
 (0)