Skip to content

Commit 5dd5978

Browse files
committed
switch to actor
Signed-off-by: SumanthRH <[email protected]>
1 parent f55909f commit 5dd5978

File tree

2 files changed

+83
-32
lines changed

2 files changed

+83
-32
lines changed

recipes/sky-t1-preview/postprocess.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,24 @@
2121
Now, try to solve the following question through the above guidelines:"
2222

2323

24-
def convert_to_sharegpt_format(row: Dict[str, Any]):
25-
prompt = row["user_input"]
24+
def convert_to_sharegpt_format(row: Dict[str, Any], prompt_column, response_column):
25+
prompt = row[prompt_column]
2626
# accept
2727
# Create the conversation format
2828
conversations = [
2929
{"from": "user", "value": prompt},
3030
{
3131
"from": "assistant",
32-
"value": row["formatted_response"],
32+
"value": row[response_column],
3333
},
3434
]
3535

3636
# Prepare the final structure
3737
cur_data = {
3838
"system": STILL2_SYSTEM_PROMPT,
3939
"conversations": conversations,
40+
# TODO: remove this
41+
**row,
4042
}
4143

4244
return cur_data

recipes/sky-t1-preview/recipe.py

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@
3333
SYSTEM_PROMPT = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." # noqa: E501
3434
MAX_TOKENS = 16384
3535
# 1. Load datasets
36-
apps_ds = datasets.load_dataset(
37-
"codeparrot/apps",
38-
split="test",
36+
apps_ds = datasets.load_dataset("codeparrot/apps", split="test", trust_remote_code=True)
37+
taco_ds_medium = datasets.load_dataset(
38+
"BAAI/TACO", split="test", name="MEDIUM", trust_remote_code=True
39+
)
40+
numina_ds = datasets.load_dataset(
41+
"AI-MO/NuminaMath-CoT", split="train", trust_remote_code=True
3942
)
40-
taco_ds_medium = datasets.load_dataset("BAAI/TACO", split="test", name="MEDIUM")
41-
numina_ds = datasets.load_dataset("AI-MO/NuminaMath-CoT", split="train")
4243

4344
# convert all to ray dataset
4445
apps_ds = ray.data.from_huggingface(apps_ds)
@@ -75,38 +76,65 @@
7576
]
7677

7778
# these are user-defined simple preprocessing functions to go from entry -> prompt
78-
preprocess_fns = [
79-
APPSPreprocessor(),
80-
TACOPreprocessor(),
81-
NUMINAPreprocessor(),
82-
NUMINAPreprocessor(),
83-
NUMINAPreprocessor(),
79+
preprocessors = [
80+
APPSPreprocessor,
81+
TACOPreprocessor,
82+
NUMINAPreprocessor,
83+
NUMINAPreprocessor,
84+
NUMINAPreprocessor,
8485
]
8586

86-
numina_scorer = MathEqualScorer(
87-
response_column="formatted_response", answer_column="solution"
88-
)
89-
scorers = [
90-
APPSScorer(response_column="formatted_response"),
91-
TACOScorer(response_column="formatted_response", backend="ray"),
92-
numina_scorer,
93-
numina_scorer,
94-
numina_scorer,
95-
]
9687
dataset_names = ["apps", "taco", "numina_amc_aime", "numina_math", "numina_olympiads"]
88+
scorer_configs = [
89+
dict(
90+
cls=APPSScorer, fn_constructor_kwargs=dict(response_column="formatted_response")
91+
),
92+
dict(
93+
cls=TACOScorer,
94+
fn_constructor_kwargs=dict(response_column="formatted_response", backend="ray"),
95+
),
96+
dict(
97+
cls=MathEqualScorer,
98+
fn_constructor_kwargs=dict(
99+
response_column="formatted_response", answer_column="solution"
100+
),
101+
),
102+
dict(
103+
cls=MathEqualScorer,
104+
fn_constructor_kwargs=dict(
105+
response_column="formatted_response", answer_column="solution"
106+
),
107+
),
108+
dict(
109+
cls=MathEqualScorer,
110+
fn_constructor_kwargs=dict(
111+
response_column="formatted_response", answer_column="solution"
112+
),
113+
),
114+
]
115+
97116
for i, ds in enumerate(datasets):
98-
datasets[i] = ds.map(preprocess_fns[i])
117+
if i < 1:
118+
continue
119+
# 1. Preprocess and get model prompts
120+
preprocess_cls = preprocessors[i]
121+
datasets[i] = ds.map(
122+
preprocess_cls,
123+
concurrency=5,
124+
)
125+
126+
# 2. Get model responses
99127

100128
config = vLLMEngineProcessorConfig(
101-
model="Qwen/QwQ-32B-Preview",
102-
# model="Qwen/Qwen2-0.5B-Instruct",
129+
# model="Qwen/QwQ-32B-Preview",
130+
model="Qwen/Qwen2-0.5B-Instruct",
103131
engine_kwargs=dict(
104132
enable_prefix_caching=True,
105133
enable_chunked_prefill=True,
106134
max_num_batched_tokens=16384,
107135
),
108136
concurrency=2,
109-
batch_size=64,
137+
batch_size=20,
110138
)
111139

112140
processor = build_llm_processor(
@@ -118,7 +146,7 @@
118146
],
119147
sampling_params=dict(
120148
temperature=0.3,
121-
max_tokens=20,
149+
max_tokens=MAX_TOKENS,
122150
detokenize=False,
123151
),
124152
),
@@ -130,13 +158,15 @@
130158
datasets[i] = processor(datasets[i])
131159

132160
# 3. Reformat the examples into a structured format
161+
133162
# define a configuration for the reformatter
134163
config = HttpRequestProcessorConfig(
135164
url="https://api.openai.com/v1/chat/completions",
136165
headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"},
137166
# number of processors to run in parallel
138167
# Each handles a batch of requests
139168
concurrency=1,
169+
batch_size=64,
140170
)
141171
# define the reformatter
142172
reformatter = build_llm_processor(
@@ -170,14 +200,33 @@
170200
datasets[i] = reformatter(datasets[i])
171201

172202
# 4. Rejection Sampling based on scoring
173-
datasets[i] = datasets[i].map(scorers[i])
174-
score_column = scorers[i].SCORE_COLUMN
203+
scorer_cls, fn_constructor_kwargs = (
204+
scorer_configs[i]["cls"],
205+
scorer_configs[i]["fn_constructor_kwargs"],
206+
)
207+
datasets[i] = datasets[i].map(
208+
scorer_cls, concurrency=4, fn_constructor_kwargs=fn_constructor_kwargs
209+
)
210+
score_column = scorer_cls.SCORE_COLUMN
175211
datasets[i] = datasets[i].filter(lambda x, sc=score_column: x[sc])
176212

177213
# 5. Convert to ShareGPT format
178-
datasets[i] = datasets[i].map(convert_to_sharegpt_format)
214+
datasets[i] = datasets[i].map(
215+
convert_to_sharegpt_format,
216+
fn_kwargs=dict(
217+
prompt_column="user_input", response_column="formatted_response"
218+
),
219+
)
179220

180221
# 6. Save datasets
181222
dir_name = f"data/sky-t1-preview-{dataset_names[i]}"
182223
datasets[i] = datasets[i].materialize()
183224
datasets[i].write_json(os.path.abspath(dir_name))
225+
226+
227+
# 7. Union
228+
229+
# final_dataset = datasets[0].union(*datasets[1:])
230+
# dir_name = f"data/sky-t1-preview-full"
231+
# # save in folder as a single JSON file
232+
# final_dataset.repartition(1).write_json(os.path.abspath(dir_name))

0 commit comments

Comments
 (0)