|
| 1 | +""" |
| 2 | +Data access helpers for the Query Insights persistence layer. |
| 3 | +""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +import json |
| 8 | +from datetime import datetime |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +import structlog |
| 12 | + |
| 13 | +from cairo_coder.db.models import UserInteraction |
| 14 | +from cairo_coder.db.session import get_pool |
| 15 | + |
| 16 | +logger = structlog.get_logger(__name__) |
| 17 | + |
| 18 | + |
| 19 | +def _serialize_json_field(value: Any) -> str | None: |
| 20 | + """ |
| 21 | + Serialize a Python object to JSON string for database storage. |
| 22 | +
|
| 23 | + Args: |
| 24 | + value: Python object to serialize (dict, list, etc.) |
| 25 | +
|
| 26 | + Returns: |
| 27 | + JSON string or None if value is None/empty |
| 28 | + """ |
| 29 | + if value is None: |
| 30 | + return None |
| 31 | + return json.dumps(value) |
| 32 | + |
| 33 | + |
| 34 | +def _normalize_json_field(value: Any, default: Any = None) -> Any: |
| 35 | + """ |
| 36 | + Normalize a JSON field from database (may be string or already parsed). |
| 37 | +
|
| 38 | + Args: |
| 39 | + value: Value from database (string, dict, list, or None) |
| 40 | + default: Default value to use if parsing fails or value is None |
| 41 | +
|
| 42 | + Returns: |
| 43 | + Parsed JSON object or default value |
| 44 | + """ |
| 45 | + if value is None: |
| 46 | + return default |
| 47 | + if isinstance(value, str): |
| 48 | + try: |
| 49 | + return json.loads(value) |
| 50 | + except (json.JSONDecodeError, TypeError): |
| 51 | + return default |
| 52 | + return value |
| 53 | + |
| 54 | + |
| 55 | +def _normalize_row(row: dict | None, fields_with_defaults: dict[str, Any]) -> dict | None: |
| 56 | + """ |
| 57 | + Parse stringified JSON fields in a row dictionary and apply defaults for None values. |
| 58 | +
|
| 59 | + Args: |
| 60 | + row: Dictionary from database row (or None) |
| 61 | + fields_with_defaults: Mapping of field names to default values |
| 62 | +
|
| 63 | + Returns: |
| 64 | + Normalized dictionary with parsed JSON fields, or None if input row is None |
| 65 | + """ |
| 66 | + if row is None: |
| 67 | + return None |
| 68 | + |
| 69 | + d = dict(row) |
| 70 | + for field, default_val in fields_with_defaults.items(): |
| 71 | + d[field] = _normalize_json_field(d.get(field), default_val) |
| 72 | + return d |
| 73 | + |
| 74 | + |
| 75 | +async def create_user_interaction(interaction: UserInteraction) -> None: |
| 76 | + """Persist a user interaction in the database.""" |
| 77 | + pool = await get_pool() |
| 78 | + try: |
| 79 | + async with pool.acquire() as connection: |
| 80 | + await connection.execute( |
| 81 | + """ |
| 82 | + INSERT INTO user_interactions ( |
| 83 | + id, |
| 84 | + agent_id, |
| 85 | + mcp_mode, |
| 86 | + chat_history, |
| 87 | + query, |
| 88 | + generated_answer, |
| 89 | + retrieved_sources, |
| 90 | + llm_usage |
| 91 | + ) |
| 92 | + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) |
| 93 | + """, |
| 94 | + interaction.id, |
| 95 | + interaction.agent_id, |
| 96 | + interaction.mcp_mode, |
| 97 | + _serialize_json_field(interaction.chat_history), |
| 98 | + interaction.query, |
| 99 | + interaction.generated_answer, |
| 100 | + _serialize_json_field(interaction.retrieved_sources), |
| 101 | + _serialize_json_field(interaction.llm_usage), |
| 102 | + ) |
| 103 | + logger.debug("User interaction logged successfully", interaction_id=str(interaction.id)) |
| 104 | + except Exception as exc: # pragma: no cover - defensive logging |
| 105 | + logger.error("Failed to log user interaction", error=str(exc), exc_info=True) |
| 106 | + |
| 107 | + |
| 108 | +async def get_interactions( |
| 109 | + start_date: datetime | None, |
| 110 | + end_date: datetime | None, |
| 111 | + agent_id: str | None, |
| 112 | + limit: int, |
| 113 | + offset: int, |
| 114 | + query_text: str | None = None, |
| 115 | +) -> tuple[list[dict[str, Any]], int]: |
| 116 | + """Fetch paginated interactions matching the supplied filters. |
| 117 | +
|
| 118 | + If start_date and end_date are not provided, returns the last N interactions |
| 119 | + ordered by created_at DESC. |
| 120 | + """ |
| 121 | + pool = await get_pool() |
| 122 | + async with pool.acquire() as connection: |
| 123 | + params: list[Any] = [] |
| 124 | + filters = [] |
| 125 | + |
| 126 | + if start_date is not None: |
| 127 | + params.append(start_date) |
| 128 | + filters.append(f"created_at >= ${len(params)}") |
| 129 | + |
| 130 | + if end_date is not None: |
| 131 | + params.append(end_date) |
| 132 | + filters.append(f"created_at <= ${len(params)}") |
| 133 | + |
| 134 | + if agent_id: |
| 135 | + params.append(agent_id) |
| 136 | + filters.append(f"agent_id = ${len(params)}") |
| 137 | + |
| 138 | + if query_text: |
| 139 | + params.append(f"%{query_text}%") |
| 140 | + filters.append(f"query ILIKE ${len(params)}") |
| 141 | + |
| 142 | + where_clause = "WHERE " + " AND ".join(filters) if filters else "" |
| 143 | + |
| 144 | + count_query = f""" |
| 145 | + SELECT COUNT(*) |
| 146 | + FROM user_interactions |
| 147 | + {where_clause} |
| 148 | + """ |
| 149 | + total = await connection.fetchval(count_query, *params) |
| 150 | + |
| 151 | + params.extend([limit, offset]) |
| 152 | + limit_placeholder = len(params) - 1 |
| 153 | + offset_placeholder = len(params) |
| 154 | + data_query = f""" |
| 155 | + SELECT id, created_at, agent_id, query, chat_history, generated_answer |
| 156 | + FROM user_interactions |
| 157 | + {where_clause} |
| 158 | + ORDER BY created_at DESC |
| 159 | + LIMIT ${limit_placeholder} |
| 160 | + OFFSET ${offset_placeholder} |
| 161 | + """ |
| 162 | + rows = await connection.fetch(data_query, *params) |
| 163 | + |
| 164 | + # Normalize JSON fields that may be returned as strings by asyncpg |
| 165 | + items = [_normalize_row(dict(row), {"chat_history": []}) for row in rows] |
| 166 | + return items, int(total) |
| 167 | + |
| 168 | + |
| 169 | +async def migrate_user_interaction(interaction: UserInteraction) -> tuple[bool, bool]: |
| 170 | + """ |
| 171 | + Persist a user interaction for migration purposes with upsert behavior. |
| 172 | +
|
| 173 | + Uses ON CONFLICT DO UPDATE to override existing entries based on the ID. |
| 174 | + This allows re-running migrations to update data if needed. |
| 175 | +
|
| 176 | + Args: |
| 177 | + interaction: UserInteraction model with pre-set ID from LangSmith |
| 178 | +
|
| 179 | + Returns: |
| 180 | + Tuple of (was_modified, was_inserted) where: |
| 181 | + - was_modified: True if any action was taken (insert or update) |
| 182 | + - was_inserted: True if inserted, False if updated |
| 183 | + """ |
| 184 | + pool = await get_pool() |
| 185 | + try: |
| 186 | + async with pool.acquire() as connection: |
| 187 | + # Single upsert round-trip; infer insert vs update via system column |
| 188 | + row = await connection.fetchrow( |
| 189 | + """ |
| 190 | + INSERT INTO user_interactions ( |
| 191 | + id, |
| 192 | + created_at, |
| 193 | + agent_id, |
| 194 | + mcp_mode, |
| 195 | + chat_history, |
| 196 | + query, |
| 197 | + generated_answer, |
| 198 | + retrieved_sources, |
| 199 | + llm_usage |
| 200 | + ) |
| 201 | + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) |
| 202 | + ON CONFLICT (id) DO UPDATE SET |
| 203 | + created_at = EXCLUDED.created_at, |
| 204 | + agent_id = EXCLUDED.agent_id, |
| 205 | + mcp_mode = EXCLUDED.mcp_mode, |
| 206 | + chat_history = EXCLUDED.chat_history, |
| 207 | + query = EXCLUDED.query, |
| 208 | + generated_answer = EXCLUDED.generated_answer, |
| 209 | + retrieved_sources = EXCLUDED.retrieved_sources, |
| 210 | + llm_usage = EXCLUDED.llm_usage |
| 211 | + RETURNING (xmax = 0) AS inserted |
| 212 | + """, |
| 213 | + interaction.id, |
| 214 | + interaction.created_at, |
| 215 | + interaction.agent_id, |
| 216 | + interaction.mcp_mode, |
| 217 | + _serialize_json_field(interaction.chat_history), |
| 218 | + interaction.query, |
| 219 | + interaction.generated_answer, |
| 220 | + _serialize_json_field(interaction.retrieved_sources), |
| 221 | + _serialize_json_field(interaction.llm_usage), |
| 222 | + ) |
| 223 | + |
| 224 | + if row is None: |
| 225 | + logger.warning("Unexpected: no result from upsert", interaction_id=str(interaction.id)) |
| 226 | + return False, False |
| 227 | + |
| 228 | + was_inserted = bool(row["inserted"]) if "inserted" in row else False |
| 229 | + if was_inserted: |
| 230 | + logger.debug("User interaction inserted", interaction_id=str(interaction.id)) |
| 231 | + else: |
| 232 | + logger.debug("User interaction updated", interaction_id=str(interaction.id)) |
| 233 | + return True, was_inserted |
| 234 | + except Exception as exc: # pragma: no cover - defensive logging |
| 235 | + logger.error("Failed to migrate user interaction", error=str(exc), exc_info=True) |
| 236 | + raise |
| 237 | + |
0 commit comments