Skip to content

Commit ac44bf2

Browse files
committed
address feedback
Signed-off-by: Christian Hagemeier <chagem@amazon.com>
1 parent caefc99 commit ac44bf2

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

benchmark/analysis-scripts/autogroup.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import csv
1414
import warnings
1515
import statistics
16-
import sys
16+
1717

1818
from tabulate import tabulate
1919
from collections import defaultdict
@@ -157,8 +157,7 @@ def find_varying_parameters(all_configs: List[Dict[str, Any]]) -> Set[str]:
157157
def find_multirun_dir(index: int = 0) -> str:
158158
"""Find the Nth latest directory in multirun (0=most recent, 1=previous, etc.)"""
159159
if not Path('multirun').exists():
160-
warnings.warn("multirun directory not found")
161-
sys.exit(1)
160+
raise FileNotFoundError("multirun directory not found")
162161

163162
# This ensures that alphabetical sorting will correctly find latest
164163
sorted_subdirs = sorted(
@@ -168,9 +167,12 @@ def find_multirun_dir(index: int = 0) -> str:
168167
)
169168

170169
if not sorted_subdirs:
171-
warnings.warn("No experiment directories found in multirun")
172-
sys.exit(1)
173-
return sorted_subdirs[index][1]
170+
raise FileNotFoundError("No experiment directories found in multirun")
171+
172+
try:
173+
return sorted_subdirs[index][1]
174+
except IndexError:
175+
raise IndexError(f"Index {index} is out of range. Only {len(sorted_subdirs)} experiment directories found.")
174176

175177

176178
def main() -> None:
@@ -181,7 +183,9 @@ def main() -> None:
181183

182184
parser.add_argument('--csv-output', help='Optional CSV file to write the results to')
183185
parser.add_argument(
184-
'--runs', choices=['tri', 'all'], help='Show run numbers in results (tri=min/median/max, all=all runs)'
186+
'--runs',
187+
choices=['rep', 'all'],
188+
help='Show run numbers in results (rep(resentative)=min/median/max, all=all runs)',
185189
)
186190
args = parser.parse_args()
187191

@@ -192,8 +196,8 @@ def main() -> None:
192196
try:
193197
base_dir = find_multirun_dir()
194198
print(f"Using inferred base directory: {base_dir}")
195-
except IndexError:
196-
print("Invalid argument, cannot find latest directory")
199+
except (IndexError, FileNotFoundError) as e:
200+
print(f"Error: {e}")
197201
return
198202

199203
# List to store all results
@@ -233,33 +237,29 @@ def main() -> None:
233237

234238
results_rows = []
235239
for config_key, throughput_data in grouped_results.items():
236-
throughputs = [t for t, _ in throughput_data]
237-
run_numbers = [r for _, r in throughput_data]
240+
throughputs, run_numbers = zip(*throughput_data)
238241

239242
row = []
240243
for _, value in config_key:
241244
row.append(value)
242245

243246
# Add run numbers column if requested
244247
if args.runs:
245-
if args.runs == "tri":
248+
sorted_by_throughput = sorted(zip(throughputs, run_numbers), reverse=True)
249+
if args.runs == "rep":
246250
# Find min, max, and median run numbers based on throughput
247-
sorted_by_throughput = sorted(zip(throughputs, run_numbers))
248-
min_run = sorted_by_throughput[0][1]
249-
max_run = sorted_by_throughput[-1][1]
251+
min_run = sorted_by_throughput[-1][1]
252+
max_run = sorted_by_throughput[0][1]
250253
median_idx = len(sorted_by_throughput) // 2
251254
median_run = sorted_by_throughput[median_idx][1]
252255

253256
selected_runs = [max_run, median_run, min_run]
254-
# Remove duplicates while preserving order
255-
unique_runs = []
256-
for run in selected_runs:
257-
if run not in unique_runs:
258-
unique_runs.append(run)
257+
# Remove duplicates while preserving order using dict.fromkeys()
258+
# (works in python > 3.7)
259+
unique_runs = list(dict.fromkeys(selected_runs))
259260

260261
row.append(",".join(unique_runs))
261262
else:
262-
sorted_by_throughput = sorted(zip(throughputs, run_numbers), reverse=True)
263263
all_runs = [r for _, r in sorted_by_throughput]
264264
row.append(",".join(all_runs))
265265

@@ -300,7 +300,7 @@ def sort_key(row: List[str]) -> List[Union[int, float, str]]:
300300
# Display results
301301
if args.runs == "all":
302302
print("\nResults Summary (with all run numbers):")
303-
elif args.runs == "tri":
303+
elif args.runs == "rep":
304304
print("\nResults Summary (with representative run numbers):")
305305
else:
306306
print("\nResults Summary:")

0 commit comments

Comments
 (0)