Skip to content

Commit 9e3203d

Browse files
committed
improve recipe
Signed-off-by: SumanthRH <[email protected]>
1 parent 6a8bb2f commit 9e3203d

File tree

1 file changed

+61
-23
lines changed

1 file changed

+61
-23
lines changed

recipes/sky-t1-preview/recipe.py

Lines changed: 61 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import argparse
66
import os
7+
from pathlib import Path
78

89
import datasets
910
import ray
@@ -28,40 +29,66 @@
2829

2930
parser = argparse.ArgumentParser()
3031
parser.add_argument("--as-test", action="store_true")
32+
parser.add_argument("--save-dir", type=str, required=True, help="Output directory")
3133
args = parser.parse_args()
34+
args.save_dir = Path(args.save_dir)
3235

3336
SYSTEM_PROMPT = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." # noqa: E501
3437
MAX_TOKENS = 16384
38+
# We explicitly set the target number of blocks to help tune performance.
39+
# For materialized datasets, the number of blocks determined by ray data can be small,
40+
# especially for a multi-stage pipeline like the one here.
41+
TARGET_NUM_ROWS_PER_BLOCK = 100
42+
43+
# Enable more detailed logging of tasks per actor
44+
ray.init(runtime_env={"env_vars": {"RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING": 1}})
45+
3546
# 1. Load datasets
36-
apps_ds = datasets.load_dataset("codeparrot/apps", split="test", trust_remote_code=True)
47+
apps_ds = datasets.load_dataset(
48+
"codeparrot/apps", split="test", trust_remote_code=True
49+
) # 10K
3750
taco_ds_medium = datasets.load_dataset(
38-
"BAAI/TACO", split="test", name="MEDIUM", trust_remote_code=True
39-
)
51+
"BAAI/TACO", split="train", name="MEDIUM", trust_remote_code=True
52+
) # 3244
53+
taco_ds_test = datasets.load_dataset(
54+
"BAAI/TACO", split="test", name="ALL", trust_remote_code=True
55+
) # 1000
4056
numina_ds = datasets.load_dataset(
4157
"AI-MO/NuminaMath-CoT", split="train", trust_remote_code=True
4258
)
4359

4460
# convert all to ray dataset
45-
apps_ds = ray.data.from_huggingface(apps_ds)
46-
taco_ds_medium = ray.data.from_huggingface(taco_ds_medium)
61+
apps_ds = ray.data.from_huggingface(apps_ds).repartition(
62+
num_blocks=None, target_num_rows_per_block=TARGET_NUM_ROWS_PER_BLOCK
63+
)
64+
taco_ds_medium = ray.data.from_huggingface(
65+
taco_ds_medium,
66+
).repartition(num_blocks=None, target_num_rows_per_block=TARGET_NUM_ROWS_PER_BLOCK)
4767
taco_ds_medium = taco_ds_medium.map(
4868
taco_coerce_types, fn_args=(taco_ds_medium.schema(),)
4969
)
50-
numina_ds = ray.data.from_huggingface(numina_ds)
70+
taco_ds_test = ray.data.from_huggingface(
71+
taco_ds_test,
72+
).repartition(num_blocks=None, target_num_rows_per_block=TARGET_NUM_ROWS_PER_BLOCK)
73+
taco_ds_test = taco_ds_test.map(taco_coerce_types, fn_args=(taco_ds_test.schema(),))
74+
numina_ds = ray.data.from_huggingface(
75+
numina_ds,
76+
).repartition(num_blocks=None, target_num_rows_per_block=TARGET_NUM_ROWS_PER_BLOCK)
5177

5278

5379
# get subsets from numina based on the source column
54-
numina_ds_amc_aime = numina_ds.filter(lambda x: x["source"] == "amc_aime")
80+
numina_ds_amc_aime = numina_ds.filter(lambda x: x["source"] == "amc_aime") # 4070
5581
numina_ds_olympiads = numina_ds.filter(lambda x: x["source"] == "olympiads").limit(
5682
20000
57-
)
58-
numina_ds_math = numina_ds.filter(lambda x: x["source"] == "math")
83+
) # 20k
84+
numina_ds_math = numina_ds.filter(lambda x: x["source"] == "math") # 7477
5985

6086

6187
if args.as_test:
62-
num_samples = 100
88+
num_samples = 5000
6389
apps_ds = apps_ds.limit(num_samples)
6490
taco_ds_medium = taco_ds_medium.limit(num_samples)
91+
taco_ds_test = taco_ds_test.limit(num_samples)
6592
numina_ds_amc_aime = numina_ds_amc_aime.limit(num_samples)
6693
numina_ds_olympiads = numina_ds_olympiads.limit(num_samples)
6794
numina_ds_math = numina_ds_math.limit(num_samples)
@@ -70,6 +97,7 @@
7097
datasets = [
7198
apps_ds,
7299
taco_ds_medium,
100+
taco_ds_test,
73101
numina_ds_amc_aime,
74102
numina_ds_olympiads,
75103
numina_ds_math,
@@ -79,12 +107,20 @@
79107
preprocessors = [
80108
APPSPreprocessor,
81109
TACOPreprocessor,
110+
TACOPreprocessor,
82111
NUMINAPreprocessor,
83112
NUMINAPreprocessor,
84113
NUMINAPreprocessor,
85114
]
86115

87-
dataset_names = ["apps", "taco", "numina_amc_aime", "numina_math", "numina_olympiads"]
116+
dataset_names = [
117+
"apps",
118+
"taco_train",
119+
"taco_test",
120+
"numina_amc_aime",
121+
"numina_math",
122+
"numina_olympiads",
123+
]
88124
scorer_configs = [
89125
dict(
90126
cls=APPSScorer, fn_constructor_kwargs=dict(response_column="formatted_response")
@@ -93,6 +129,10 @@
93129
cls=TACOScorer,
94130
fn_constructor_kwargs=dict(response_column="formatted_response", backend="ray"),
95131
),
132+
dict(
133+
cls=TACOScorer,
134+
fn_constructor_kwargs=dict(response_column="formatted_response", backend="ray"),
135+
),
96136
dict(
97137
cls=MathEqualScorer,
98138
fn_constructor_kwargs=dict(
@@ -114,8 +154,6 @@
114154
]
115155

116156
for i, ds in enumerate(datasets):
117-
if i < 1:
118-
continue
119157
# 1. Preprocess and get model prompts
120158
preprocess_cls = preprocessors[i]
121159
datasets[i] = ds.map(
@@ -126,15 +164,16 @@
126164
# 2. Get model responses
127165

128166
config = vLLMEngineProcessorConfig(
129-
# model="Qwen/QwQ-32B-Preview",
130-
model="Qwen/Qwen2-0.5B-Instruct",
167+
model="Qwen/QwQ-32B-Preview",
168+
# model="Qwen/Qwen2-0.5B-Instruct",
131169
engine_kwargs=dict(
132170
enable_prefix_caching=True,
133171
enable_chunked_prefill=True,
134-
max_num_batched_tokens=16384,
172+
max_num_batched_tokens=4096,
173+
tensor_parallel_size=4,
135174
),
136175
concurrency=2,
137-
batch_size=20,
176+
batch_size=128,
138177
)
139178

140179
processor = build_llm_processor(
@@ -145,9 +184,8 @@
145184
{"role": "user", "content": row["user_input"]},
146185
],
147186
sampling_params=dict(
148-
temperature=0.3,
187+
temperature=0,
149188
max_tokens=MAX_TOKENS,
150-
detokenize=False,
151189
),
152190
),
153191
postprocess=lambda row: dict(
@@ -166,7 +204,7 @@
166204
# number of processors to run in parallel
167205
# Each handles a batch of requests
168206
concurrency=1,
169-
batch_size=64,
207+
batch_size=16,
170208
)
171209
# define the reformatter
172210
reformatter = build_llm_processor(
@@ -219,9 +257,9 @@
219257
)
220258

221259
# 6. Save datasets
222-
dir_name = f"data/sky-t1-preview-{dataset_names[i]}"
223-
datasets[i] = datasets[i].materialize()
224-
datasets[i].write_json(os.path.abspath(dir_name))
260+
dir_name = args.save_dir / f"sky-t1-preview-{dataset_names[i]}"
261+
# use absolute path while saving with ray data
262+
datasets[i].write_json(str(dir_name.expanduser().absolute()))
225263

226264

227265
# 7. Union

0 commit comments

Comments (0)