|
2149 | 2149 | " labels=batch[\"rejected\"],\n",
|
2150 | 2150 | " selection_mask=batch[\"rejected_mask\"]\n",
|
2151 | 2151 | " )\n",
|
2152 |
| - " ref_chosen_log_probas = compute_logprobs(\n", |
2153 |
| - " logits=reference_model(batch[\"chosen\"]),\n", |
2154 |
| - " labels=batch[\"chosen\"],\n", |
2155 |
| - " selection_mask=batch[\"chosen_mask\"]\n", |
2156 |
| - " )\n", |
2157 |
| - " ref_rejected_log_probas = compute_logprobs(\n", |
2158 |
| - " logits=reference_model(batch[\"rejected\"]),\n", |
2159 |
| - " labels=batch[\"rejected\"],\n", |
2160 |
| - " selection_mask=batch[\"rejected_mask\"]\n", |
2161 |
| - " )\n", |
| 2152 | + " \n", |
| 2153 | + " with torch.no_grad():\n", |
| 2154 | + " ref_chosen_log_probas = compute_logprobs(\n", |
| 2155 | + " logits=reference_model(batch[\"chosen\"]),\n", |
| 2156 | + " labels=batch[\"chosen\"],\n", |
| 2157 | + " selection_mask=batch[\"chosen_mask\"]\n", |
| 2158 | + " )\n", |
| 2159 | + " ref_rejected_log_probas = compute_logprobs(\n", |
| 2160 | + " logits=reference_model(batch[\"rejected\"]),\n", |
| 2161 | + " labels=batch[\"rejected\"],\n", |
| 2162 | + " selection_mask=batch[\"rejected_mask\"]\n", |
| 2163 | + " )\n", |
2162 | 2164 | " loss, chosen_rewards, rejected_rewards = compute_dpo_loss(\n",
|
2163 | 2165 | " model_chosen_logprobs=policy_chosen_log_probas,\n",
|
2164 | 2166 | " model_rejected_logprobs=policy_rejected_log_probas,\n",
|
|
3090 | 3092 | "name": "python",
|
3091 | 3093 | "nbconvert_exporter": "python",
|
3092 | 3094 | "pygments_lexer": "ipython3",
|
3093 |
| - "version": "3.11.4" |
| 3095 | + "version": "3.10.6" |
3094 | 3096 | }
|
3095 | 3097 | },
|
3096 | 3098 | "nbformat": 4,
|
|
0 commit comments