Skip to content

Commit fab4d13

Browse files
committed
x
Signed-off-by: SumanthRH <[email protected]>
1 parent 834dbf7 commit fab4d13

File tree

4 files changed

+67
-10
lines changed

4 files changed

+67
-10
lines changed

skythought/evals/cli.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,14 @@ def score(
530530
case_sensitive=False,
531531
),
532532
],
533+
idx: Annotated[
534+
str,
535+
typer.Option(
536+
...,
537+
help="Unique index of the sample in the results JSON to re-score. "
538+
"If provided, only the scores for this sample are computed/re-computed.",
539+
),
540+
] = None,
533541
):
534542
if not os.path.exists(run_dir):
535543
raise ValueError(f"Run directory {run_dir} does not exist.")
@@ -556,7 +564,7 @@ def score(
556564

557565
run_summary = SummaryResults(**run_summary)
558566

559-
score_results(handler, run_dir, run_summary)
567+
score_results(handler, run_dir, run_summary, idx)
560568

561569

562570
def main():

skythought/evals/inference_and_check.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def generate_responses_for_dataset(
294294
def score_responses(
295295
handler: TaskHandler,
296296
id_to_results: Dict[str, Dict[str, Any]],
297+
*,
297298
max_workers: int = 32,
298299
) -> Tuple[float, Dict[str, List[int]], int]:
299300
"""Computes correctness for model responses for the given task
@@ -341,7 +342,7 @@ def score_responses(
341342
# TODO (sumanthrh): this can be improved
342343
if unique_id not in id_to_scores:
343344
id_to_scores[unique_id] = [0 for _ in range(N)]
344-
id_to_scores[unique_id][i] = new_response_entry["correctness"]
345+
id_to_scores[unique_id][i] = int(new_response_entry["correctness"])
345346

346347
total_correct += new_response_entry["correctness"]
347348
total_finish += 1
@@ -350,6 +351,40 @@ def score_responses(
350351
return accuracy, id_to_scores, total_finish
351352

352353

354+
def score_responses_for_idx(
355+
handler: TaskHandler,
356+
id_to_results: Dict[str, Dict[str, Any]],
357+
*,
358+
idx: str,
359+
) -> List[int]:
360+
"""Computes correctness for model responses for the given task for the unique index `idx`.
361+
362+
The 'id_to_results' dictionary is assumed to be a mapping from problem ID -> { responses: [...], ... }.
363+
This is updated in-place.
364+
365+
Returns:
366+
- list of scores
367+
"""
368+
if not id_to_results:
369+
return []
370+
371+
# Figure out how many generations per problem
372+
N = len(next(iter(id_to_results.values()))["responses"])
373+
record = id_to_results[idx]
374+
scores = []
375+
for i in range(N):
376+
content = record["responses"][i]["content"]
377+
response_entry = handler.update_results(record, content)
378+
379+
# Update correctness and reason in the original results dict
380+
id_to_results[idx]["responses"][i]["correctness"] = response_entry[
381+
"correctness"
382+
]
383+
id_to_results[idx]["responses"][i]["reason"] = response_entry["reason"]
384+
scores.append(response_entry["correctness"])
385+
return scores
386+
387+
353388
def generate_and_score(
354389
handler: TaskHandler,
355390
model_config: ModelConfig,
@@ -480,17 +515,29 @@ def generate_and_save(
480515

481516

482517
def score_results(
483-
handler: TaskHandler, run_dir: Path, run_summary: SummaryResults
518+
handler: TaskHandler,
519+
run_dir: Path,
520+
run_summary: SummaryResults,
521+
idx: Optional[str] = None,
484522
) -> None:
485523
# load existing results
486524
result_file = run_dir / RESULTS_FILENAME
487525
summary_file = run_dir / SUMMARY_FILENAME
488526
id_to_results = load_existing_results(result_file)
489527
logger.info(f"Loaded {len(id_to_results)} existing results for scoring.")
490528

491-
accuracy, id_to_scores, total_finish = score_responses(handler, id_to_results)
492-
493-
logger.info(f"Accuracy: {accuracy}")
529+
if not idx:
530+
accuracy, id_to_scores, total_finish = score_responses(handler, id_to_results)
531+
else:
532+
N = len(next(iter(id_to_results.values()))["responses"])
533+
score_responses_for_idx(handler, id_to_results, idx=idx)
534+
id_to_scores = {
535+
index: [
536+
id_to_results[index]["responses"][i]["correctness"] for i in range(N)
537+
]
538+
for index in id_to_results
539+
}
540+
accuracy = sum(map(sum, id_to_scores.values())) / (len(id_to_scores) * N)
494541

495542
sample_count = 0
496543
if id_to_results:
@@ -501,7 +548,9 @@ def score_results(
501548

502549
run_summary.accuracy = accuracy
503550
run_summary.pass_at_k = pass_at_k_metrics
551+
552+
logger.info(f"Accuracy: {accuracy}")
504553
save_summary(summary_file, run_summary)
505554

506555
save_results(result_file, id_to_results)
507-
logger.info(f"Re-scored results saved to {result_file}")
556+
logger.info(f"Scored results saved to {result_file}")

skythought/evals/tasks/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def check_correctness(
5858
pass
5959

6060
@abstractmethod
61-
def update_results(self, problem: Dict[str, Any], response: str):
61+
def update_results(self, problem: Dict[str, Any], response: str) -> Dict[str, Any]:
6262
pass
6363

6464
def make_conversations(

skythought/evals/util/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import math
33
from collections import defaultdict
4-
from typing import Any, Dict
4+
from typing import Dict, List
55

66
import numpy as np
77

@@ -17,7 +17,7 @@ def _pass_at_k(n, c, k):
1717
return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
1818

1919

20-
def pass_at_k(N: int, id_to_scores: Dict[str, Dict[str, Any]]):
20+
def pass_at_k(N: int, id_to_scores: Dict[str, List[int]]):
2121
final_passk_scores = {}
2222
k_to_passk_scores = defaultdict(list) # k -> list of scores
2323
for _, sample_scores in id_to_scores.items():

0 commit comments

Comments
 (0)