9 | 9 | import torch |
10 | 10 | import pickle |
11 | 11 | import json |
| 12 | +import csv |
| 13 | +from collections import defaultdict |
12 | 14 |
13 | 15 | @st.cache_resource |
14 | 16 | def load_link_data(path): |
@@ -181,36 +183,213 @@ def add_children(node_name): |
181 | 183 | tree.append(add_children(node["name"])) |
182 | 184 | return tree |
183 | 185 |
184 | | -def parse_topic_folder_name(folder_name, path=None): |
| 186 | +DEBUG = False  # set True to print debug traces to the terminal |
| 187 | + |
| 188 | +def _debug(msg): |
| 189 | + if DEBUG: |
| 190 | + print(msg, flush=True) |
| 191 | + |
| 192 | +# Matches cluster_for_k=N*.csv (case-insensitive, with optional extra suffix) |
| 193 | +_K_FILE_PAT = re.compile(r'^cluster_for_k=(\d+)(?:[^/]*)\.csv$', re.IGNORECASE) |
| 194 | + |
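A quick self-contained sketch of what `_K_FILE_PAT` does and does not accept; the filenames are hypothetical:

```python
import re

_K_FILE_PAT = re.compile(r'^cluster_for_k=(\d+)(?:[^/]*)\.csv$', re.IGNORECASE)

assert _K_FILE_PAT.match("cluster_for_k=12.csv").group(1) == "12"      # plain k
assert _K_FILE_PAT.match("cluster_for_k=7_final.csv").group(1) == "7"  # extra suffix
assert _K_FILE_PAT.match("CLUSTER_FOR_K=3.CSV") is not None            # case-insensitive
assert _K_FILE_PAT.match("cluster_for_k=.csv") is None                 # digits required
```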
| 195 | +# Cache per root path |
| 196 | +_COUNTS_CACHE = {} |
| 197 | +_COUNTS_SRC_CACHE = {} |
| 198 | + |
| 199 | +def _list_topic_dirs(root_path: str): |
| 200 | + try: |
| 201 | + dirs = [ |
| 202 | + d for d in os.listdir(root_path) |
| 203 | + if os.path.isdir(os.path.join(root_path, d)) |
| 204 | + ] |
| 205 | + # Heuristic: folders that start with digits are topic dirs (e.g., "3-foo", "12", etc.) |
| 206 | + topic_dirs = [d for d in dirs if re.match(r'^\d+\b', d)] |
| 207 | + return sorted(topic_dirs) |
| 208 | + except FileNotFoundError: |
| 209 | + return [] |
| 210 | + |
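The digit-prefix heuristic in isolation, run on made-up folder names:

```python
import re

names = ["3-foo", "12", "0-bar_5-documents", "misc", "v2-notes"]
topic_dirs = sorted(d for d in names if re.match(r'^\d+\b', d))
print(topic_dirs)  # ['0-bar_5-documents', '12', '3-foo']
```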
| 211 | +def _iter_candidate_csvs(root_path: str): |
| 212 | + """Yield (abs_path, k) for cluster_for_k=*.csv at root and one level below.""" |
| 213 | + if not root_path or not os.path.isdir(root_path): |
| 214 | + _debug(f"[scan] Not a dir: {root_path}") |
| 215 | + return |
| 216 | + |
| 217 | + _debug(f"[scan] Looking for cluster_for_k=*.csv under: {root_path}") |
| 218 | + |
| 219 | + # Top-level files |
| 220 | + try: |
| 221 | + for entry in os.scandir(root_path): |
| 222 | + if entry.is_file(): |
| 223 | + m = _K_FILE_PAT.match(entry.name) |
| 224 | + if m: |
| 225 | + _debug(f"[scan] Found CSV: {entry.name} (k={m.group(1)})") |
| 226 | + yield entry.path, int(m.group(1)) |
| 227 | + except FileNotFoundError: |
| 228 | + pass |
| 229 | + |
| 230 | + # One level down (subfolders only) |
| 231 | + try: |
| 232 | + for entry in os.scandir(root_path): |
| 233 | + if entry.is_dir(): |
| 234 | + for sub in os.scandir(entry.path): |
| 235 | + if sub.is_file(): |
| 236 | + m = _K_FILE_PAT.match(sub.name) |
| 237 | + if m: |
| 238 | + _debug(f"[scan] Found CSV (subdir): {entry.name}/{sub.name} (k={m.group(1)})") |
| 239 | + yield sub.path, int(m.group(1)) |
| 240 | + except FileNotFoundError: |
| 241 | + pass |
| 242 | + |
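A hedged usage sketch; the root path is a placeholder, and the generator simply yields nothing if the directory does not exist:

```python
for csv_path, k in _iter_candidate_csvs("/data/clusters"):  # hypothetical root
    print(f"k={k}: {csv_path}")
```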
| 243 | +def _norm_cluster_key(val): |
| 244 | + """Normalize cluster key to a stringified integer if possible (e.g., '3', 3, '3.0', 'cluster_3').""" |
| 245 | + if val is None: |
| 246 | + return None |
| 247 | + if isinstance(val, (int, float)): |
| 248 | + try: |
| 249 | + return str(int(val)) |
| 250 | + except Exception: |
| 251 | + return str(val).strip() |
| 252 | + s = str(val).strip() |
| 253 | + m = re.search(r'(\d+)', s) |
| 254 | + return m.group(1) if m else s # fallback to raw string |
| 255 | + |
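Assuming the function above, these normalizations should hold (a small self-check, not an exhaustive test):

```python
assert _norm_cluster_key(3) == "3"
assert _norm_cluster_key(3.0) == "3"
assert _norm_cluster_key(" 3 ") == "3"
assert _norm_cluster_key("cluster_3") == "3"
assert _norm_cluster_key("noise") == "noise"  # no digits: stripped raw string
assert _norm_cluster_key(None) is None
```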
| 256 | +def _read_counts_map_from_csv(csv_path: str): |
| 257 | + """Build counts map {cluster_id(str): count(int)} using the 'cluster' column (case-insensitive).""" |
| 258 | + _debug(f"[counts] Reading CSV: {csv_path}") |
| 259 | + counts = defaultdict(int) |
| 260 | + try: |
| 261 | + with open(csv_path, "r", newline="", encoding="utf-8") as f: |
| 262 | + reader = csv.DictReader(f) |
| 263 | + if not reader.fieldnames: |
| 264 | + _debug("[counts] No header/fieldnames found.") |
| 265 | + return {} |
| 266 | + |
| 267 | + # find 'cluster' column case-insensitively |
| 268 | + cluster_col = None |
| 269 | + for name in reader.fieldnames: |
| 270 | + if name and name.strip().lower() == "cluster": |
| 271 | + cluster_col = name |
| 272 | + break |
| 273 | + |
| 274 | + if not cluster_col: |
| 275 | + _debug(f"[counts] 'cluster' column not found in {reader.fieldnames}") |
| 276 | + return {} |
| 277 | + |
| 278 | + row_total = 0 |
| 279 | + for row in reader: |
| 280 | + row_total += 1 |
| 281 | + key = _norm_cluster_key(row.get(cluster_col)) |
| 282 | + if key is not None and key != "": |
| 283 | + counts[key] += 1 |
| 284 | + |
| 285 | + _debug(f"[counts] Total rows read: {row_total}. Unique clusters: {len(counts)}") |
| 286 | + except Exception as e: |
| 287 | + _debug(f"[counts] Error reading {csv_path}: {e}") |
| 288 | + return {} |
| 289 | + return dict(counts) |
| 290 | + |
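To make the counting concrete, here is the same logic applied to a hypothetical CSV; real files may carry more columns, but only the 'cluster' column matters:

```python
import csv
import io
from collections import defaultdict

sample = "doc_id,Cluster\na,0\nb,0\nc,1\n"  # hypothetical file contents

reader = csv.DictReader(io.StringIO(sample))
# Case-insensitive lookup of the 'cluster' column, as in the function above.
cluster_col = next(n for n in reader.fieldnames if n.strip().lower() == "cluster")
counts = defaultdict(int)
for row in reader:
    counts[row[cluster_col]] += 1
print(dict(counts))  # {'0': 2, '1': 1}
```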
| 291 | +def _choose_best_cluster_csv(root_path: str): |
| 292 | + """Choose the best cluster_for_k=*.csv. Prefer k == number of topic dirs; else largest k.""" |
| 293 | + candidates = list(_iter_candidate_csvs(root_path)) |
| 294 | + if not candidates: |
| 295 | + _debug("[choose] No cluster_for_k=*.csv candidates found.") |
| 296 | + return None |
| 297 | + |
| 298 | + topic_dirs = _list_topic_dirs(root_path) |
| 299 | + k_target = len(topic_dirs) |
| 300 | + _debug(f"[choose] Topic dirs detected: {k_target}") |
| 301 | + |
| 302 | + # Prefer exact match to number of topic dirs |
| 303 | + exact = [p for p in candidates if p[1] == k_target and k_target > 0] |
| 304 | + if exact: |
| 305 | + # If multiple with same k, pick the shortest path (heuristic) |
| 306 | + chosen = sorted(exact, key=lambda x: (len(x[0]), x[0]))[0] |
| 307 | + _debug(f"[choose] Using exact k match: {os.path.basename(chosen[0])} (k={chosen[1]})") |
| 308 | + return chosen[0] |
| 309 | + |
| 310 | + # Else pick largest k |
| 311 | + candidates.sort(key=lambda x: x[1], reverse=True) |
| 312 | + chosen = candidates[0] |
| 313 | + _debug(f"[choose] Using largest k: {os.path.basename(chosen[0])} (k={chosen[1]})") |
| 314 | + return chosen[0] |
| 315 | + |
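The selection policy on hypothetical (path, k) candidates; an exact match on the topic-dir count wins, with ties broken by shortest path:

```python
candidates = [("/r/cluster_for_k=5.csv", 5), ("/r/old/cluster_for_k=8.csv", 8)]
k_target = 5  # pretend five topic dirs were detected

exact = [c for c in candidates if c[1] == k_target]
chosen = (sorted(exact, key=lambda x: (len(x[0]), x[0]))[0]
          if exact else max(candidates, key=lambda x: x[1]))
print(chosen[0])  # /r/cluster_for_k=5.csv
```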
| 316 | +def _get_counts_map(root_path: str): |
| 317 | + """Get (and cache) the counts map for this root path.""" |
| 318 | + if root_path in _COUNTS_CACHE: |
| 319 | + return _COUNTS_CACHE[root_path], _COUNTS_SRC_CACHE.get(root_path) |
| 320 | + |
| 321 | + csv_path = _choose_best_cluster_csv(root_path) |
| 322 | + if not csv_path: |
| 323 | + _COUNTS_CACHE[root_path] = {} |
| 324 | + _COUNTS_SRC_CACHE[root_path] = None |
| 325 | + return {}, None |
| 326 | + |
| 327 | + counts_map = _read_counts_map_from_csv(csv_path) |
| 328 | + _COUNTS_CACHE[root_path] = counts_map |
| 329 | + _COUNTS_SRC_CACHE[root_path] = csv_path |
| 330 | + return counts_map, csv_path |
| 331 | + |
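Callers only touch this entry point; a minimal sketch with a placeholder root:

```python
counts, src = _get_counts_map("/data/clusters")  # hypothetical root
print(src, counts.get("3", 0))
_ = _get_counts_map("/data/clusters")  # second call is a cache hit, no rescan
```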
| 332 | +def parse_topic_folder_name(folder_name: str, path: str | None = None): |
185 | 333 | """ |
186 | 334 | Extracts topic number, label, and document count from folder name. |
187 | | -
188 | | - Expected format: topic_number-label_with_underscores-documents_count-documents |
189 | | - Example: '3-label_of_the_topic_9-documents' -> ('3', 'Label of the topic', '9') |
190 | | -
191 | | - Args: |
192 | | - folder_name (str): Folder name formatted as: <topic_number>-<label_with_underscores>_<document_count>-documents |
193 | | -
194 | | - Returns: |
195 | | - tuple: (str, str, str) -> (topic_number, cleaned_label, document_count) |
196 | | - or (None, None, None) if parsing fails. |
| 335 | + If the count isn't encoded in the name, infer it from cluster_for_k=N.csv at the root (which |
| 336 | + covers all clusters) by counting rows per cluster ID in the 'cluster' column. |
197 | 337 | """ |
198 | | - match = re.match(r'(\d+)-([^-]+)_(\d+)-documents', folder_name) |
| 338 | + _debug(f"\n[parse] Folder: {folder_name}") |
| 339 | + match = re.match(r'^(\d+)-([^-]+?)(?:[_-](\d+)-documents)?$', folder_name) |
199 | 340 | if match: |
200 | | - topic_number = match.group(1) # Extract topic number |
201 | | - label = match.group(2).replace("_", " ").strip() # Replace underscores with spaces |
202 | | - document_count = match.group(3) # Extract document count |
| 341 | + topic_number = match.group(1) |
| 342 | + label = match.group(2).replace("_", " ").strip() |
| 343 | + document_count = match.group(3) |
| 344 | + _debug(f"[parse] Parsed -> number={topic_number}, label='{label}', name_count={document_count}") |
| 345 | + |
| 346 | + if document_count is None and path: |
| 347 | + counts_map, csv_src = _get_counts_map(path) |
| 348 | + key = _norm_cluster_key(topic_number) |
| 349 | + inferred = counts_map.get(key) |
| 350 | + if inferred is not None: |
| 351 | + document_count = str(inferred) |
| 352 | + _debug(f"[parse] Inferred from CSV ({os.path.basename(csv_src) if csv_src else 'n/a'}): {document_count}") |
| 353 | + else: |
| 354 | + _debug(f"[parse] No count for cluster={key} in CSV.") |
| 355 | + document_count = "Unknown" |
| 356 | + |
| 357 | + if document_count is None: |
| 358 | + document_count = "Unknown" |
| 359 | + |
| 360 | + _debug(f"[parse] Result -> ({topic_number}, '{label}', {document_count})") |
203 | 361 | return topic_number, label, document_count |
| 362 | + |
204 | 363 | elif folder_name.isdigit(): |
205 | | - labels = load_csv_file_items(path, suffix="cluster_summaries", ends=False, column="label") |
206 | | - if path is None or labels is None: |
207 | | - return folder_name, "Topic", "Unknown" |
208 | | - else: |
209 | | - return folder_name, labels[int(folder_name)], "Unknown" |
| 364 | + topic_number = folder_name |
| 365 | + # Try label mapping via the existing cluster_summaries helper |
| 366 | + label = "Topic" |
| 367 | + try: |
| 368 | + labels = load_csv_file_items(path, suffix="cluster_summaries", ends=False, column="label") |
| 369 | + if path is not None and labels is not None: |
| 370 | + label = labels[int(folder_name)] |
| 371 | + except Exception as e: |
| 372 | + _debug(f"[label] Could not map label: {e}") |
| 373 | + |
| 374 | + document_count = "Unknown" |
| 375 | + if path: |
| 376 | + counts_map, csv_src = _get_counts_map(path) |
| 377 | + key = _norm_cluster_key(topic_number) |
| 378 | + inferred = counts_map.get(key) |
| 379 | + if inferred is not None: |
| 380 | + document_count = str(inferred) |
| 381 | + _debug(f"[parse] Inferred (numeric folder) from CSV ({os.path.basename(csv_src) if csv_src else 'n/a'}): {document_count}") |
| 382 | + else: |
| 383 | + _debug(f"[parse] No count for cluster={key} in CSV (numeric folder).") |
| 384 | + |
| 385 | + _debug(f"[parse] Numeric folder -> ({topic_number}, '{label}', {document_count})") |
| 386 | + return topic_number, label, document_count |
210 | 387 |
| 388 | + _debug("[parse] Unrecognized folder name format.") |
211 | 389 | return None, None, None |
212 | 390 |
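The parser end to end, on hypothetical folder names (no `path` argument, so CSV inference is skipped):

```python
print(parse_topic_folder_name("3-label_of_the_topic_9-documents"))
# -> ('3', 'label of the topic', '9')

print(parse_topic_folder_name("5-some_topic"))
# -> ('5', 'some topic', 'Unknown')  (no count in the name, no path to infer from)

print(parse_topic_folder_name("notes"))
# -> (None, None, None)  (unrecognized format)
```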
213 | 391 |
| 392 | + |
214 | 393 | def map_folder_to_logical_name(folder_name: str, root_name) -> str: |
215 | 394 | """ |
216 | 395 | Convert a physical folder name (like '*_0', '*_1_2', or '0_1') |