Skip to content

Commit 1d6baf7

Browse files
committed
Add fast fail ovveride for benchmarks
Signed-off-by: Tadiwa Magwenzi <[email protected]>
1 parent 57e1d23 commit 1d6baf7

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

benchmark/conf/hydra/sweeper/multi.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
defaults:
33
- base
44

5-
_target_: hydra_plugins.smart_sweeper.smart_benchmark_sweeper.SmartBenchmarkSweeper
5+
_target_: hydra_plugins.smart_sweeper.smart_benchmark_sweeper.SmartBenchmarkSweeper
6+
fail_fast: false

benchmark/hydra_plugins/smart_sweeper/smart_benchmark_sweeper.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,19 @@ class SmartBenchmarkSweeperConf:
2121
_target_: str = "hydra_plugins.smart_sweeper.smart_benchmark_sweeper.SmartBenchmarkSweeper"
2222
max_batch_size: Optional[int] = None
2323
params: Optional[Dict[str, str]] = None
24+
fail_fast: bool = False
2425

2526

2627
ConfigStore.instance().store(group="hydra/sweeper", name="smart_benchmark", node=SmartBenchmarkSweeperConf)
2728

2829

2930
class SmartBenchmarkSweeper(Sweeper):
30-
def __init__(self, max_batch_size: Optional[int] = None, params: Optional[Dict[str, str]] = None):
31+
def __init__(
32+
self, max_batch_size: Optional[int] = None, params: Optional[Dict[str, str]] = None, fail_fast: bool = False
33+
):
3134
self.max_batch_size = max_batch_size
3235
self.params = params or {}
36+
self.fail_fast = fail_fast
3337
self.config: Optional[DictConfig] = None
3438
self.launcher: Optional[Launcher] = None
3539
self.hydra_context: Optional[HydraContext] = None
@@ -86,8 +90,19 @@ def sweep(self, arguments: List[str]) -> Any:
8690
initial_job_idx = 0
8791
if all_combinations:
8892
self.validate_batch_is_legal(all_combinations)
89-
results = self.launcher.launch(all_combinations, initial_job_idx=initial_job_idx)
90-
returns.append(results)
93+
94+
batch_size = 1 if self.fail_fast else len(all_combinations)
95+
96+
for i in range(0, len(all_combinations), batch_size):
97+
batch = all_combinations[i : i + batch_size]
98+
results = self.launcher.launch(batch, initial_job_idx=initial_job_idx)
99+
100+
if self.fail_fast:
101+
for r in results:
102+
_ = r.return_value
103+
104+
initial_job_idx += len(batch)
105+
returns.append(results)
91106

92107
return returns
93108

0 commit comments

Comments
 (0)