Skip to content

Commit 27bf724

Browse files
committed
x
Signed-off-by: SumanthRH <[email protected]>
1 parent 6fdca5c commit 27bf724

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

recipes/sky-t1-preview/preprocess.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import json
22

3+
import pyarrow as pa
4+
from ray.data import Schema
5+
36

47
class APPSPreprocessor:
58
WITH_FN_NAME_TEMPLATE = "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}" # noqa: E501
@@ -68,3 +71,17 @@ def __call__(self, row):
6871
prompt = row["problem"]
6972
_input = self.TEMPLATE.format(prompt=prompt)
7073
return {**row, "user_input": _input}
74+
75+
76+
def taco_coerce_types(row, schema: Schema):
    """Coerce the values of *row* toward the column types given by *schema*.

    For each ``(name, type)`` pair in ``schema.names`` / ``schema.types``,
    the Arrow type of the row's value is inferred by wrapping it in a
    one-element ``pa.array``. When the inferred type differs from the
    schema type:

    * string columns get ``str(value)`` (or ``""`` if that conversion raises);
    * null columns get ``None``;
    * any other mismatch is left untouched.

    The row is modified in place and the same object is returned.
    """
    for column, expected_type in zip(schema.names, schema.types):
        inferred_type = pa.array([row[column]]).type
        if inferred_type == expected_type:
            # Value already matches the schema; nothing to do.
            continue

        if expected_type == pa.string():
            # Best-effort stringification; fall back to an empty string.
            try:
                row[column] = str(row[column])
            except Exception:
                row[column] = ""
        elif expected_type == pa.null():
            row[column] = None

    return row

recipes/sky-t1-preview/recipe.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
from skythought.evals.scoring.taco import TACOScorer
1919

2020
from .postprocess import convert_to_sharegpt_format
21-
from .preprocess import APPSPreprocessor, NUMINAPreprocessor, TACOPreprocessor
21+
from .preprocess import (
22+
APPSPreprocessor,
23+
NUMINAPreprocessor,
24+
TACOPreprocessor,
25+
taco_coerce_types,
26+
)
2227
from .prompts import CONVERT_PROMPT, CONVERT_PROMPT_EXAMPLE
2328

2429
parser = argparse.ArgumentParser()
@@ -38,6 +43,9 @@
3843
# convert all to ray dataset
3944
apps_ds = ray.data.from_huggingface(apps_ds)
4045
taco_ds_medium = ray.data.from_huggingface(taco_ds_medium)
46+
taco_ds_medium = taco_ds_medium.map(
47+
taco_coerce_types, fn_args=(taco_ds_medium.schema(),)
48+
)
4149
numina_ds = ray.data.from_huggingface(numina_ds)
4250

4351

@@ -77,7 +85,7 @@
7785
)
7886
scorers = [
7987
APPSScorer(response_column="formatted_response"),
80-
TACOScorer(response_column="formatted_response"),
88+
TACOScorer(response_column="formatted_response", backend="ray"),
8189
numina_scorer,
8290
numina_scorer,
8391
numina_scorer,
@@ -168,4 +176,6 @@
168176

169177
# 6. Save datasets
170178
dir_name = f"sky-t1-preview-{i}_parquet"
179+
datasets[i] = datasets[i].materialize()
180+
# breakpoint()
171181
datasets[i].write_parquet(os.path.abspath(dir_name))

skythought/evals/scoring/apps/apps.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import json
3-
from typing import Any, Dict
3+
from typing import Any, Dict, Literal
44

55
import numpy as np
66
import ray
@@ -19,11 +19,13 @@ def __init__(
1919
response_column="response",
2020
answer_column="solutions",
2121
input_column="input_output",
22+
backend: Literal["mp", "ray"] = "ray",
2223
) -> None:
2324
super().__init__()
2425
self.response_column = response_column
2526
self.answer_column = answer_column
2627
self.input_column = input_column
28+
self.backend = backend
2729

2830
def score(self, row: Dict[str, Any]):
2931
TIMEOUT = 10

0 commit comments

Comments
 (0)