Thanks for your great work!
I noticed that the git repository doesn't seem to include the implementation of the trajectory-level reward calculation mentioned in Equation (11). According to the code in fsdp_workers.py, only the reward from the first step is selected to represent the overall reward score. Is this the intended implementation? Could you please clarify how the trajectory-level reward is computed in practice?
ReasonFlux/ReasonFlux_PRM/Application/verl/workers/fsdp_workers.py
Lines 1358 to 1375 in c6605f7
def make_step_rewards(logits, token_masks):
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    probabilities = probabilities * token_masks.unsqueeze(-1)

    all_scores_res = []
    for i in range(probabilities.size(0)):
        sample = probabilities[i]
        positive_probs = sample[sample != 0].view(-1, 2)[:, 1]
        non_zero_elements_list = positive_probs.cpu().tolist()
        all_scores_res.append(non_zero_elements_list)
    return all_scores_res

step_sep_id = self.tokenizer.encode("<extra_0>")[0]
token_masks = (micro_batch['input_ids'] == step_sep_id)
step_reward = make_step_rewards(rm_score[0], token_masks)
# output.append(rm_score)
output.append(step_reward[0][0])

BETA = 0.75
scores = torch.tensor([BETA * o1 + (1-BETA) * o2 for (o1, o2) in zip(output, rulebased_res)])
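For comparison, here is a minimal sketch of what aggregating all step rewards into a trajectory-level score could look like, instead of taking only `step_reward[0][0]`. The helper name `aggregate_trajectory_reward` and the choice of reduction (mean/min/prod) are my assumptions for illustration; I don't know which aggregation Equation (11) actually specifies.

```python
import torch

def aggregate_trajectory_reward(step_rewards, reduction="mean"):
    """Collapse the per-step PRM scores of one trajectory into a single scalar.

    step_rewards: list of floats, e.g. make_step_rewards(rm_score[0], token_masks)[0].
    Hypothetical helper; the real Equation (11) aggregation may differ.
    """
    if len(step_rewards) == 0:
        return 0.0
    t = torch.tensor(step_rewards, dtype=torch.float32)
    if reduction == "mean":
        return t.mean().item()
    if reduction == "min":
        return t.min().item()
    if reduction == "prod":
        return t.prod().item()
    raise ValueError(f"unknown reduction: {reduction}")

# Usage: replace `output.append(step_reward[0][0])` with
# output.append(aggregate_trajectory_reward(step_reward[0]))
```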