
Commit 2292592

add proper testing
1 parent 0a90891 commit 2292592

11 files changed: +737 additions, -180 deletions

API_DOCUMENTATION.md

Lines changed: 6 additions & 6 deletions
@@ -360,11 +360,11 @@ The Query Insights API exposes raw interaction logs and lightweight analytics fo

 Fetch paginated user queries within a specific window.

-- `start_date` *(ISO 8601, required)* — inclusive lower bound.
-- `end_date` *(ISO 8601, required)* — inclusive upper bound.
-- `agent_id` *(optional)* — filter by agent id when provided.
-- `limit` *(default `100`)* — maximum rows returned.
-- `offset` *(default `0`)* — pagination offset.
+- `start_date` _(ISO 8601, required)_ — inclusive lower bound.
+- `end_date` _(ISO 8601, required)_ — inclusive upper bound.
+- `agent_id` _(optional)_ — filter by agent id when provided.
+- `limit` _(default `100`)_ — maximum rows returned.
+- `offset` _(default `0`)_ — pagination offset.

 **Response** `200 OK`

@@ -388,7 +388,7 @@ Fetch paginated user queries within a specific window.

 Trigger an asynchronous analysis job. The response returns immediately with the job identifier; the analysis runs in the background.

-**Request**
+#### Request

 ```json
 {
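
For orientation, a minimal client sketch for the paginated query endpoint documented in the first hunk above. The `/insights/queries` path, the host, and the use of httpx are assumptions not shown in this diff; only the parameter names and defaults come from the documentation.

```python
# Hypothetical client for the paginated query endpoint described above.
# The URL path and host are assumptions; the parameter names and defaults
# (start_date, end_date, agent_id, limit, offset) come from the docs.
import httpx

params = {
    "start_date": "2024-01-01T00:00:00Z",  # ISO 8601, inclusive lower bound
    "end_date": "2024-01-31T23:59:59Z",    # ISO 8601, inclusive upper bound
    "agent_id": "cairo-coder",             # optional agent filter
    "limit": 100,                          # default 100
    "offset": 0,                           # default 0
}
response = httpx.get("http://localhost:8000/insights/queries", params=params)
response.raise_for_status()
print(response.json())
```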

TASK_PRD.md

Lines changed: 0 additions & 135 deletions
This file was deleted.

python/pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -141,6 +141,9 @@ strict_optional = true
 testpaths = ["tests"]
 pythonpath = ["src"]
 asyncio_mode = "auto"
+markers = [
+    "db: marks tests that require a database (run by default, use -m 'not db' to skip)",
+]
 filterwarnings = [
     "ignore::DeprecationWarning",
     "ignore::PendingDeprecationWarning",

python/src/cairo_coder/db/repository.py

Lines changed: 37 additions & 5 deletions
@@ -17,6 +17,33 @@
 logger = structlog.get_logger(__name__)


+def _normalize_row(row: dict | None, fields_with_defaults: dict[str, Any]) -> dict | None:
+    """
+    Parse stringified JSON fields in a row dictionary and apply defaults for None values.
+
+    Args:
+        row: Dictionary from database row (or None)
+        fields_with_defaults: Mapping of field names to default values
+
+    Returns:
+        Normalized dictionary with parsed JSON fields, or None if input row is None
+    """
+    if row is None:
+        return None
+
+    d = dict(row)
+    for field, default_val in fields_with_defaults.items():
+        val = d.get(field)
+        if isinstance(val, str):
+            try:
+                d[field] = json.loads(val)
+            except (json.JSONDecodeError, TypeError):
+                d[field] = default_val
+        elif val is None:
+            d[field] = default_val
+    return d
+
+
 async def create_user_interaction(interaction: UserInteraction) -> None:
     """Persist a user interaction in the database."""
     pool = await get_pool()

@@ -89,22 +116,24 @@ async def get_interactions(
             OFFSET ${offset_placeholder}
         """
         rows = await connection.fetch(data_query, *params)
-        return [dict(row) for row in rows], int(total)
+
+        # Normalize JSON fields that may be returned as strings by asyncpg
+        items = [_normalize_row(dict(row), {"chat_history": []}) for row in rows]
+        return items, int(total)


 async def create_analysis_job(params: dict[str, Any]) -> uuid.UUID:
     """Insert a new analysis job and return its identifier."""
     pool = await get_pool()
     async with pool.acquire() as connection:
-        job_id = await connection.fetchval(
+        return await connection.fetchval(
             """
             INSERT INTO query_analyses (status, analysis_parameters)
             VALUES ('pending', $1)
             RETURNING id
             """,
             json.dumps(params),
         )
-        return job_id


 async def update_analysis_job(

@@ -144,7 +173,7 @@ async def get_analysis_jobs(limit: int = 100) -> list[dict[str, Any]]:
         """,
         limit,
     )
-    return [dict(row) for row in rows]
+    return [_normalize_row(dict(row), {"analysis_parameters": {}}) for row in rows]


 async def get_analysis_job_by_id(job_id: uuid.UUID) -> dict[str, Any] | None:

@@ -164,4 +193,7 @@ async def get_analysis_job_by_id(job_id: uuid.UUID) -> dict[str, Any] | None:
         """,
         job_id,
     )
-    return dict(row) if row is not None else None
+    return _normalize_row(
+        dict(row) if row else None,
+        {"analysis_parameters": {}, "analysis_result": None}
+    )
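
To illustrate what the new `_normalize_row` helper guards against, a small hedged example (the sample rows below are invented; the behavior follows the function exactly as defined in the first hunk): asyncpg can hand back JSON columns as raw strings, and the helper parses them or substitutes the supplied default.

```python
# Stringified JSON from asyncpg is parsed back into Python structures:
row = {"id": 1, "chat_history": '[{"role": "user", "content": "hi"}]'}
result = _normalize_row(row, {"chat_history": []})
assert result["chat_history"] == [{"role": "user", "content": "hi"}]

# None and unparseable values both fall back to the supplied default:
assert _normalize_row({"chat_history": None}, {"chat_history": []})["chat_history"] == []
assert _normalize_row({"chat_history": "not json"}, {"chat_history": []})["chat_history"] == []
```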

python/src/cairo_coder/db/session.py

Lines changed: 23 additions & 8 deletions
@@ -4,19 +4,29 @@

 from __future__ import annotations

+import asyncio
+
 import asyncpg
 import structlog

 from cairo_coder.config.manager import ConfigManager

 logger = structlog.get_logger(__name__)

-pool: asyncpg.Pool | None = None
+# Maintain one pool per running event loop to avoid cross-loop usage issues
+pools: dict[int, asyncpg.Pool] = {}


 async def get_pool() -> asyncpg.Pool:
-    """Return the global asyncpg connection pool, lazily creating it."""
-    global pool
+    """Return an asyncpg connection pool bound to the current event loop.
+
+    FastAPI's TestClient and AnyIO can run application code across different
+    event loops. Using a single cached pool may lead to cross-loop errors.
+    To prevent this, we maintain a pool per loop.
+    """
+    loop = asyncio.get_running_loop()
+    key = id(loop)
+    pool = pools.get(key)
     if pool is None:
         config = ConfigManager.load_config()
         try:

@@ -25,6 +35,7 @@ async def get_pool() -> asyncpg.Pool:
                 min_size=2,
                 max_size=10,
             )
+            pools[key] = pool
             logger.info("Database connection pool created successfully.")
         except Exception as exc:  # pragma: no cover - defensive logging
             logger.error("Failed to create database connection pool", error=str(exc))

@@ -34,11 +45,15 @@

 async def close_pool() -> None:
     """Close the asyncpg connection pool if it is active."""
-    global pool
-    if pool is not None:
-        await pool.close()
-        pool = None
-        logger.info("Database connection pool closed.")
+    import contextlib
+
+    # Close and clear all pools
+    global pools
+    for p in list(pools.values()):
+        with contextlib.suppress(Exception):
+            await p.close()
+    pools.clear()
+    logger.info("Database connection pool(s) closed.")


 async def execute_schema_scripts() -> None:
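
To make the per-loop keying concrete, here is a standalone sketch of the same caching pattern, with a plain `object()` standing in for `asyncpg.create_pool`, which needs a live database. Both loops are kept open at once so their ids cannot collide; keying on `id(loop)` assumes the loop object stays alive while its pool is cached.

```python
# Two live event loops have distinct ids, so each gets its own cached
# resource; this is the property get_pool() relies on to avoid cross-loop use.
import asyncio

pools: dict[int, object] = {}

async def get_resource() -> object:
    key = id(asyncio.get_running_loop())
    if key not in pools:
        pools[key] = object()  # stand-in for: await asyncpg.create_pool(...)
    return pools[key]

loop_a, loop_b = asyncio.new_event_loop(), asyncio.new_event_loop()
try:
    first = loop_a.run_until_complete(get_resource())
    second = loop_b.run_until_complete(get_resource())
    assert first is not second  # separate loops, separate pools
finally:
    loop_a.close()
    loop_b.close()
```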

python/src/cairo_coder/server/app.py

Lines changed: 2 additions & 2 deletions
@@ -28,11 +28,11 @@
 from cairo_coder.core.config import VectorStoreConfig
 from cairo_coder.core.rag_pipeline import RagPipeline
 from cairo_coder.core.types import Message, Role, StreamEventType
-from cairo_coder.dspy.document_retriever import SourceFilteredPgVectorRM
-from cairo_coder.dspy.suggestion_program import SuggestionGeneration
 from cairo_coder.db import session as db_session
 from cairo_coder.db.models import UserInteraction
 from cairo_coder.db.repository import create_user_interaction
+from cairo_coder.dspy.document_retriever import SourceFilteredPgVectorRM
+from cairo_coder.dspy.suggestion_program import SuggestionGeneration
 from cairo_coder.server.insights_api import router as insights_router
 from cairo_coder.utils.logging import setup_logging
