@@ -21,15 +21,19 @@ class SmartBenchmarkSweeperConf:
2121 _target_ : str = "hydra_plugins.smart_sweeper.smart_benchmark_sweeper.SmartBenchmarkSweeper"
2222 max_batch_size : Optional [int ] = None
2323 params : Optional [Dict [str , str ]] = None
24+ fail_fast : bool = False
2425
2526
2627ConfigStore .instance ().store (group = "hydra/sweeper" , name = "smart_benchmark" , node = SmartBenchmarkSweeperConf )
2728
2829
2930class SmartBenchmarkSweeper (Sweeper ):
30- def __init__ (self , max_batch_size : Optional [int ] = None , params : Optional [Dict [str , str ]] = None ):
31+ def __init__ (
32+ self , max_batch_size : Optional [int ] = None , params : Optional [Dict [str , str ]] = None , fail_fast : bool = False
33+ ):
3134 self .max_batch_size = max_batch_size
3235 self .params = params or {}
36+ self .fail_fast = fail_fast
3337 self .config : Optional [DictConfig ] = None
3438 self .launcher : Optional [Launcher ] = None
3539 self .hydra_context : Optional [HydraContext ] = None
@@ -86,8 +90,19 @@ def sweep(self, arguments: List[str]) -> Any:
8690 initial_job_idx = 0
8791 if all_combinations :
8892 self .validate_batch_is_legal (all_combinations )
89- results = self .launcher .launch (all_combinations , initial_job_idx = initial_job_idx )
90- returns .append (results )
93+
94+ batch_size = 1 if self .fail_fast else len (all_combinations )
95+
96+ for i in range (0 , len (all_combinations ), batch_size ):
97+ batch = all_combinations [i : i + batch_size ]
98+ results = self .launcher .launch (batch , initial_job_idx = initial_job_idx )
99+
100+ if self .fail_fast :
101+ for r in results :
102+ _ = r .return_value
103+
104+ initial_job_idx += len (batch )
105+ returns .append (results )
91106
92107 return returns
93108
0 commit comments