11from pathlib import Path
2- from typing import Dict , Sequence , Any , Tuple
2+ from typing import Dict , Any
33from .base_block import AnimalBlock
44from .data_bundle import DataBundle , SAVE_DIR_BUNDLE_KEY
55
66from ...post_processing import ArcticFox
77from ...factorization import HNMFk
88
9+
910class ArticFoxBlock (AnimalBlock ):
11+ """
12+ Block wrapper for the ArcticFox post-process/label/stats pipeline.
13+ Use call_settings['steps'] to run any subset: ["post"], ["label"], ["stats"],
14+ or combinations like ["post","label"], ["label","stats"], ["post","stats"], ["post","label","stats"].
15+ If steps is None, legacy behavior uses the label_clusters/generate_stats booleans.
16+ """
1017
11- CANONICAL_NEEDS = ("df" , ' vocabulary' , "model_path" , )
18+ CANONICAL_NEEDS = ("df" , " vocabulary" , "model_path" )
1219
1320 def __init__ (
1421 self ,
@@ -21,62 +28,64 @@ def __init__(
2128 call_settings : Dict [str , Any ] = None ,
2229 ** kw ,
2330 ) -> None :
24-
31+
2532 self .col = col
2633 default_init = {
27- ' clean_cols_name' : self .col ,
28- ' embedding_model' : "SCINCL" ,
34+ " clean_cols_name" : self .col ,
35+ " embedding_model" : "SCINCL" ,
2936 }
3037 default_call = {
31- 'ollama_model' : "llama3.2:3b-instruct-fp16" , # Language model used for semantic label generation
32- 'label_clusters' : True , # Enable automatic labeling of clusters
33- 'generate_stats' : True , # Generate cluster-level statistics
34- 'process_parents' : True , # Propagate labels or stats upward through the hierarchy
35- 'skip_completed' : True , # Skip processing of nodes already labeled/stored
36- 'label_criteria' : { # Rules to filter generated labels
37- "minimum words" : 2 ,
38- "maximum words" : 6
39- },
40- 'label_info' : { # Additional metadata to associate with generated labels
41- "source" : "Science"
42- },
43- 'number_of_labels' : 5 # Number of candidate labels to generate per node
38+ "ollama_model" : "llama3.2:3b-instruct-fp16" , # Language model used for semantic label generation
39+ "label_clusters" : True , # Back-compat: used when steps is None
40+ "generate_stats" : True , # Back-compat: used when steps is None
41+ "process_parents" : True ,
42+ "skip_completed" : True ,
43+ "label_criteria" : {"minimum words" : 2 , "maximum words" : 6 },
44+ "label_info" : {"source" : "Science" },
45+ "number_of_labels" : 5 ,
46+ # NEW: choose subset explicitly; None keeps legacy boolean behavior
47+ # Examples: ["post"], ["label"], ["stats"], ["post","label"], ["label","stats"], ["post","stats"], ["post","label","stats"]
48+ "steps" : None ,
4449 }
4550
4651 super ().__init__ (
47- needs = needs ,
48- provides = provides ,
52+ needs = needs ,
53+ provides = provides ,
4954 init_settings = self ._merge (default_init , init_settings ),
5055 call_settings = self ._merge (default_call , call_settings ),
5156 tag = tag ,
5257 ** kw ,
5358 )
5459
55-
5660 def run (self , bundle : DataBundle ) -> None :
61+ # Resolve inputs
5762 df = self .load_path (bundle [self .needs [0 ]])
5863 vocabulary = self .load_path (bundle [self .needs [1 ]])
59-
6064 raw_model_path = str (bundle [self .needs [2 ]])
61- # Try to resolve to an absolute path for traceability; fall back to the raw string.
65+
6266 try :
6367 resolved_model_path = str (Path (raw_model_path ).expanduser ().resolve ())
6468 except Exception :
6569 resolved_model_path = raw_model_path
6670
71+ # Load HNMFk model
6772 model = HNMFk (experiment_name = raw_model_path )
6873 model .load_model ()
6974
75+ # Run selected steps (order enforced inside ArcticFox)
7076 pipeline = ArcticFox (model = model , ** self .init_settings )
71- pipeline .run_full_pipeline (data_df = df , vocab = vocabulary , ** self .call_settings )
77+ pipeline .run_full_pipeline (
78+ data_df = df ,
79+ vocab = vocabulary ,
80+ ** self .call_settings
81+ )
7282
83+ # Write a lightweight status checkpoint
7384 status_value = "Done"
74-
7585 if SAVE_DIR_BUNDLE_KEY in bundle :
7686 out_dir = Path (bundle [SAVE_DIR_BUNDLE_KEY ]) / self .tag
7787 out_dir .mkdir (parents = True , exist_ok = True )
7888 status_file = out_dir / "status.txt"
79- # Include model path info in the checkpointed status file
8089 status_file .write_text (
8190 f"status: { status_value } \n "
8291 f"model_path: { raw_model_path } \n "
@@ -86,7 +95,3 @@ def run(self, bundle: DataBundle) -> None:
8695 self .register_checkpoint (self .provides [0 ], status_file )
8796
8897 bundle [f"{ self .tag } .{ self .provides [0 ]} " ] = status_value
89-
90-
91-
92-
0 commit comments