diff --git a/configs/train_configs/dpo/tulu_3_preview_test_if_faeze.yaml b/configs/train_configs/dpo/tulu_3_preview_test_if_faeze.yaml
new file mode 100644
index 000000000..2d9c669cc
--- /dev/null
+++ b/configs/train_configs/dpo/tulu_3_preview_test_if_faeze.yaml
@@ -0,0 +1,31 @@
+model_name_or_path: /model
+model_revision: main
+use_flash_attn: true
+gradient_checkpointing: true
+tokenizer_name: /model
+use_slow_tokenizer: true
+dataset_mixer:
+ allenai/ultrafeedback_binarized_cleaned_train: 1.0
+ ai2-adapt-dev/DaringAnteater-prefs-RM-filter: 1.0
+ ai2-adapt-dev/WildChat-prefs-280824: 1.0
+ ai2-adapt-dev/personahub_if_pref_data_v1: 1.0
+ # ai2-adapt-dev/ultrafeedback-replication-p5: 1.0
+ # ai2-adapt-dev/numina_math_gsm8k_prefs_balance_minerva_format_v2: 1.0
+max_seq_length: 2048
+preprocessing_num_workers: 16
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 16 # designed for 8 GPUs, so batch size 128
+learning_rate: 5.0e-7
+lr_scheduler_type: linear
+warmup_ratio: 0.1
+weight_decay: 0.0
+num_train_epochs: 1
+output_dir: /output
+with_tracking: true
+report_to:
+ - wandb
+logging_steps: 1
+use_lora: false
+dpo_loss_type: dpo_norm
+dpo_beta: 5
+checkpointing_steps: 2000
diff --git a/configs/train_configs/sft/tulu3_8b_preview_mix_v3.6.yaml b/configs/train_configs/sft/tulu3_8b_preview_mix_v3.6.yaml
new file mode 100644
index 000000000..c22f8425b
--- /dev/null
+++ b/configs/train_configs/sft/tulu3_8b_preview_mix_v3.6.yaml
@@ -0,0 +1,183 @@
+model_name_or_path: meta-llama/Meta-Llama-3-8B
+model_revision: main
+use_flash_attn: true
+tokenizer_name: meta-llama/Meta-Llama-3-8B
+use_slow_tokenizer: true
+dataset_mixer:
+ # This final 3.4 mix is based on the ablations in the `tulu3_8b_preview_mix_v3.4.x` folder.
+ # In the end, we selected v3.4.23, which performs the best after applying DPO.
+ # ------------------------------------------------------
+ # no_robots dataset, human written, for general chat.
+ # Total: 9500
+ # Pro: created by Surge AI with high cost, should be high quality.
+ # Con: small, not diverse enough, may not be in a consistent style.
+ HuggingFaceH4/no_robots: 9500
+ # ------------------------------------------------------
+ # OpenAssistant dataset, human written, for general chat.
+ # Here, only the highest rated paths are extracted.
+ # Total: 7708
+ # Pro: created and reviewed by human volunteers, has multi-turn chat.
+ # Con: small, still has some noise, the writing quality may not be as good/careful as that of paid workers, and the style may be inconsistent.
+ # TODO: need to check if this version corresponds to the highest rated paths.
+ allenai/openassistant-guanaco-reformatted: 7708
+ # ------------------------------------------------------
+ # LIMA dataset, human written, for general chat.
+ # Some instances were filtered in building Tulu 2, probably due to some identity keywords.
+ # Total: 1018
+ # Pro: created by researchers at Meta, aiming for diversity and high quality.
+ # Con: small; created quite early, so it might not reflect some of the latest chatbot answering styles.
+ # natolambert/tulu-v2-sft-mixture-lima: 1018
+ # ------------------------------------------------------
+ # Aya dataset, human written, for general chat (multilingual).
+ # Total: 202362
+ # Pro: created by ..., aiming for very diverse languages ().
+ # Con: answers may not be in the perfect style.
+ ai2-adapt-dev/aya_dataset-reformat: 202362
+ # ------------------------------------------------------
+ # Tulu hard-coded examples, human written, for identity-related questions.
+ # Total: 14
+ # Pro: necessary to make Tulu aware of itself and its builders.
+ # Con: small, low coverage of possible questions from users.
+ # TODO: we should later find ways to replicate this multiple times.
+ ai2-adapt-dev/tulu_hard_coded_examples: 14
+ # ------------------------------------------------------
+ # CoT subset in FLAN v2, human (researchers) converted from existing datasets, for reasoning.
+ # Here, we use the subset processed in Tulu v2.
+ # Total: 48747
+ # Pro: researchers converted from 9 chain-of-thought datasets about arithmetic, multi-hop reasoning, and NLI.
+ # Con: limited in task types, written early, may have inconsistent styles compared to today's chatbots.
+ # natolambert/tulu-v2-sft-mixture-cot: 49747
+ # ------------------------------------------------------
+ # SciRIFF dataset, human (researchers) converted from existing datasets, for scientific literature understanding.
+ # Here, we use the subset extracted by the author in building allenai/SciRIFF-train-mix.
+ # Total: 35357
+ # Pro: researchers converted from existing datasets for 54 scientific literature understanding tasks.
+ # Con: limited in task types, may have inconsistent styles compared to today's chatbots.
+ # TODO: need to ablate and compare with the one in the tulu 2 mixture natolambert/tulu-v2-sft-mixture-science.
+ # natolambert/tulu-v2-sft-mixture-science: 7468 # original data slightly different
+ ai2-adapt-dev/SciRIFF-train-mix-science: 10000
+ # ------------------------------------------------------
+ # SlimOrca dataset, gpt4 generated, for general chat.
+ # Total: 517982
+ # Pro: Pairing FLAN v2 inputs with system prompts, and regenerating the outputs using GPT4, potentially in a better style.
+ # Con: GPT4 responses may contain errors, which may be mitigated by the filtering in SlimOrca.
+ # TODO: need to ablate and compare with the 300K one Faeze created. May benefit from regeneration.
+ # ai2-adapt-dev/slim-orca-300k: 100000
+ ai2-adapt-dev/SlimOrca-reformat: 100000
+ # ------------------------------------------------------
+ # WizardLM evol instruct dataset, gpt4 generated, for general chat.
+ # Total: 196000
+ # Pro: the approach deepens the complexity of gpt4-generated data.
+ # Con: GPT4 generations have errors, and may also inherit the biases/styles of GPT4.
+ # TODO: need to ablate.
+ WizardLMTeam/WizardLM_evol_instruct_V2_196k: 30000
+ # ------------------------------------------------------
+ # WildChat dataset, real user queries + gpt4 responses, for general chat.
+ # Total: 254663 (1M if including those interacting with gpt 3.5)
+ # Pro: real user queries, may contain diverse challenging scenarios, as well as unsafe prompts. Multi-turn.
+ # Con: user queries are usually not that well-formatted, and contain a lot of noise.
+ ai2-adapt-dev/WildChat-1M-Full-GPT4-Only: 254663
+ # ------------------------------------------------------
+ # ShareGPT dataset, real user shared queries + gpt4 responses, for general chat.
+ # Total: 114046
+ # Pro: user shared queries usually contain interesting phenomena. Multi-turn.
+ # Con: unclear licensing, and the responses were generated using an earlier version of GPT4.
+ # TODO: need to ablate. May benefit from regeneration.
+ # Vtuber-plan/sharegpt-cleaned: 114046
+ # ------------------------------------------------------
+ # Daring-Anteater, a mix of existing datasets, for general chat.
+ # Total: 99532
+ # Pro: a good mix of precise_instruction_following / json_format_following / complex instructions.
+ # Con: the constraint-following part is too small.
+ # TODO: need to ablate if excluding the main chat subset is helpful.
+ # TODO: data needs to be reformatted to consider the system prompt.
+ ai2-adapt-dev/Daring-Anteater-reformat: 99532
+ # ------------------------------------------------------
+ # MetaMathQA dataset, augmented using gpt4, for math capability.
+ # Total: 395000
+ # Pro: augmented towards GSM/MATH, so good performance on these two benchmarks (probably similar questions too).
+ # Con: may be too targeted for the two benchmarks and fail to generalize to other math problems in different styles.
+ ai2-adapt-dev/metamath-qa-reformat: 100000
+ # ------------------------------------------------------
+ # WebInstruct dataset, extracted & rewritten using gpt4, (mainly) for math/science related questions.
+ # Here, we are using their released subset.
+ # Total: 2335220
+ # Pro: the generation benefits from GPT4's answering style & the correctness of grounding to web documents.
+ # Con: may be biased by the response styles of the three websites (MathStackExchange, ScienceStackExchange, Socratic);
+ # the question answering styles are also not diverse enough in terms of instruction constraints;
+ # the answers may still have some errors (10% based on the paper).
+ # TODO: need to ablate the effect.
+ ai2-adapt-dev/WebInstructSub-reformat: 100000
+ # ------------------------------------------------------
+ # Codefeedback Filtered Instruction, a mix of existing datasets, for coding.
+ # The data mix includes:
+ # Magicoder-OSS-Instruct
+ # Python code subset of ShareGPT
+ # Magicoder-Evol-Instruct
+ # Evol-Instruct-Code
+ # Total: 156526
+ # Pro: a decent mix of existing coding prompts.
+ # Con: curated mainly for the prompts in building the real CodeFeedback, so responses may be low quality (e.g., ShareGPT).
+ # TODO: change to individual datasets and ablate the effect. May benefit from regeneration.
+ m-a-p/CodeFeedback-Filtered-Instruction: 156526
+ # ------------------------------------------------------
+ # Codefeedback dataset, a mix of existing datasets + feedback interaction generation, for coding.
+ # Total: 66383
+ # Pro: single-turn packing + interaction simulation seems to create a good coding model that takes feedback over multiple turns.
+ # Con: not sure how diverse the feedback is and how well it can generalize.
+ # TODO: need to ablate. need to change code to downweight the intermediate responses with errors!!!
+ # m-a-p/Code-Feedback: 66383
+ # ------------------------------------------------------
+ # Table-GPT dataset, converted & synthesized, for table understanding and operations.
+ # Total: 13222
+ # Pro: a special dataset that contains 14 table-related tasks for enhancing table capabilities.
+ # Con: task types are limited. The tables may not be big enough. Response styles may be inconsistent.
+ # TODO: need to ablate.
+ ai2-adapt-dev/Table-GPT-All-train: 3000
+ # ------------------------------------------------------
+ # Coconot dataset, generated by gpt4, for non-compliance.
+ # Total: 11477
+ # Pro: a special dataset for a comprehensive list of non-compliance behaviors of models.
+ # Con: the generated queries may only reflect simple cases.
+ # TODO: need to ablate.
+ ai2-adapt-dev/coconot-sft-reformat: 11477
+ # ------------------------------------------------------
+ # NuminaMath-TIR, extracted and generated by gpt4, for tool-integrated reasoning for math.
+ # Total: 72441
+ # Pro: generally high-quality dataset with prompts mined from web corpora and verified tool-integrated reasoning trajectories.
+ # Con: mainly for solving math in a specific format, not in a format consistent with general chat.
+ # TODO: need to ablate. need to rewrite!!!
+ AI-MO/NuminaMath-TIR: 72441
+ # AI-MO/NuminaMath-CoT: 859000
+ # ------------------------------------------------------
+ # Xlam function calling dataset, synthesized and verified, for tool use.
+ # Total: 60000
+ # Pro: a special dataset for enhancing function calling capability, good performance on BFCL.
+ # Con: responses only contain the function call and arguments, not in a style consistent with general chat.
+ # TODO: need to ablate. need to rewrite!!!
+ # Salesforce/xlam-function-calling-60k: 60000
+ # ------------------------------------------------------
+ # Lmsys chatbot arena data, human queries for challenging models, for general chat.
+ # Total: 1000000
+ # Pro: real human interaction with models, with reasonable challenges.
+ # Con: may not reflect the real challenges in actual use of AI models. The interactions include those with weak models.
+ # TODO: need to ablate. need to regenerate (the last step)!! The intermediate low-quality responses need to be downweighted.
+ # lmsys/lmsys-chat-1m: 1000000
+ ai2-adapt-dev/personahub_ifdata_v1_29980: 29980
+ # ai2-adapt-dev/personahub_math_v4_149975: 149975
+max_seq_length: 4096 # Note, reduced from 8192 to fit on one GPU with DeepSpeed Stage3
+preprocessing_num_workers: 128
+per_device_train_batch_size: 1 # note, this is set up for 8 GPUs
+gradient_accumulation_steps: 4 # effective batch size 128 with 4 nodes
+learning_rate: 5.0e-06 # best LR so far
+lr_scheduler_type: linear
+warmup_ratio: 0.03
+weight_decay: 0.0
+num_train_epochs: 2
+output_dir: /output/
+with_tracking: true
+report_to:
+ - wandb
+logging_steps: 1
+checkpointing_steps: epoch
+dataset_mix_dir: /output/
diff --git a/configs/train_configs/sft/tulu3_L3.1_8b_math_faeze.yaml b/configs/train_configs/sft/tulu3_L3.1_8b_math_faeze.yaml
new file mode 100644
index 000000000..76c2c71ab
--- /dev/null
+++ b/configs/train_configs/sft/tulu3_L3.1_8b_math_faeze.yaml
@@ -0,0 +1,43 @@
+model_name_or_path: meta-llama/Meta-Llama-3.1-8B
+model_revision: main
+use_flash_attn: true
+tokenizer_name: meta-llama/Meta-Llama-3.1-8B
+use_slow_tokenizer: true
+# model_name_or_path: Qwen/Qwen2-7B
+# model_revision: main
+# use_flash_attn: true
+# tokenizer_name: Qwen/Qwen2-7B
+# use_slow_tokenizer: true
+dataset_mixer:
+ natolambert/tulu-v2-sft-mixture-flan: 50000
+ natolambert/tulu-v2-sft-mixture-cot: 49747
+ # ai2-adapt-dev/personahub_math_v1: 49990
+ # ai2-adapt-dev/personahub_math_v2_79975: 79975
+ # ai2-adapt-dev/personahub_math_v3_119975: 119975
+ ai2-adapt-dev/personahub_math_v4_149975: 149975
+ # Vtuber-plan/sharegpt-cleaned: 114046
+ # vicgalle/alpaca-gpt4: 20000
+ # HuggingFaceH4/CodeAlpaca_20K: 18000
+ # natolambert/tulu-v2-sft-mixture-lima: 1018
+ # natolambert/tulu-v2-sft-mixture-science: 7468
+ AI-MO/NuminaMath-TIR: 72441
+ # ai2-adapt-dev/numina_math_gsm8k_sampled_sft_llama3_405_regen: 8937
+ # ai2-adapt-dev/numina_math_gsm8k_sampled_sft_gold: 8937
+ # ai2-adapt-dev/numina_math_gsm8k_prefs_balance_minerva_format_v2_messages_format: 22841
+ #
ai2-adapt-dev/math_numina_balanced_none_mc_prefs_minerva_format_messages_format: 41603 +max_seq_length: 4096 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 # note, this is set up for 8 GPUs +gradient_accumulation_steps: 8 # effective batch size 128 with 4 nodes +learning_rate: 5.0e-06 # best LR so far +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 2 +output_dir: /output/ +with_tracking: true +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +dataset_mix_dir: /output/ diff --git a/mason.py b/mason.py index 9669640fc..f3bba77ae 100644 --- a/mason.py +++ b/mason.py @@ -250,6 +250,19 @@ def make_task_spec(args, command, i, beaker_secrets, whoami, resumable: bool): if not args.pure_docker_mode: setup_commands += f"cd {os.getcwd()} && " fully_command = setup_commands + " ".join(full_command) + # conda_command = [ + # os.getenv("CONDA_EXE"), + # "run", + # "--cwd", + # os.getcwd(), + # "--no-capture-output", + # "--name", + # os.getenv("CONDA_DEFAULT_ENV"), + # ] + # if not args.pure_docker_mode: + # setup_commands += f"cd {os.getcwd()} && " + # # fully_command = setup_commands + " ".join(full_command) + # fully_command = " ".join(conda_command) + " "+ " ".join(full_command) print(f"{full_command=}") diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py index 692e5358e..fc77a9996 100644 --- a/open_instruct/dataset_processor.py +++ b/open_instruct/dataset_processor.py @@ -81,6 +81,13 @@ BINARY_LABEL_KEY, ] +# prm dataset +INPUT_IDS_KEY = "input_ids" +# ATTENTION_MASK_KEY = "attention_mask" +# PRM_INPUT_KEY = "messages" +PRM_LABEL_KEY = "step_labels" +# LABELS_KEY = "labels" + # Chat templates # flake8: noqa # note we added `{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}` @@ -107,6 +114,13 @@ "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" "{% endfor %}" ), + "simple_prm": ( + "{% for message in messages %}" + "{{'\n\n' if not loop.first else ''}}" + "{{message['content']}}" + "{% if loop.last and not add_generation_prompt %}{% endif %}" + "{% endfor %}" + ), "zephyr": ( "{% for message in messages %}\n" "{% if message['role'] == 'user' %}\n" @@ -153,6 +167,10 @@ class DatasetConfig: # columns names for SFT dataset sft_messages_key: str = SFT_MESSAGE_KEY + # columns names for PRM dataset + # prm_input_key: str = PRM_INPUT_KEY + prm_label_key: str = PRM_LABEL_KEY + # columns names for binary dataset binary_messages_key: str = SFT_MESSAGE_KEY label: str = BINARY_LABEL_KEY @@ -404,6 +422,60 @@ def get_token_length_visualization(self, dataset: DatasetDict, save_path: str = ) +class PRMDatasetProcessor(DatasetProcessor): + def tokenize(self, dataset: Dataset): + def tokenize_fn(row): + # row[INPUT_IDS_PROMPT_KEY] = self.tokenizer.apply_chat_template( + # row[self.config.prm_input_key], + # add_generation_prompt=False, + # ) + row[INPUT_IDS_KEY] = self.tokenizer.apply_chat_template(row[self.config.sft_messages_key]) + row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY]) + labels = copy.deepcopy(row[PRM_LABEL_KEY]) + # if self.config.train_only_on_prompt: + # labels[: len(row[INPUT_IDS_PROMPT_KEY])] = [-100] * len(row[INPUT_IDS_PROMPT_KEY]) + row[LABELS_KEY] = labels + return row + + return dataset.map( + tokenize_fn, + num_proc=get_num_proc(len(dataset), self.config.num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU), + load_from_cache_file=self.config.load_from_cache_file, + desc="Tokenizing and reformatting PRM data", + ) + + def 
filter(self, dataset: Dataset): + def filter_fn(row): + max_prompt_token_length_ok = True + # if self.config.max_prompt_token_lenth is not None: + # max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_lenth + + max_token_length_ok = True + if self.config.max_token_length is not None: + max_token_length_ok = len(row[INPUT_IDS_KEY]) <= self.config.max_token_length + + contain_some_labels = any(x != -100 for x in row[LABELS_KEY]) + return max_prompt_token_length_ok and max_token_length_ok and contain_some_labels + + return dataset.filter( + filter_fn, + num_proc=get_num_proc(len(dataset), self.config.num_proc, FILTER_EXAMPLE_PER_SECOND_PER_CPU), + load_from_cache_file=self.config.load_from_cache_file, + desc="Filtering PRM data", + ) + + def get_token_length_stats(self, dataset: Union[Dataset, DatasetDict]): + return super().get_token_length_stats(features=[INPUT_IDS_PROMPT_KEY, INPUT_IDS_KEY], dataset=dataset) + + def get_token_length_visualization(self, dataset: DatasetDict, save_path: str = "tmp.png", bins: int = 30): + return super().get_token_length_visualization( + features=[INPUT_IDS_PROMPT_KEY, INPUT_IDS_KEY], + dataset=dataset, + save_path=save_path, + bins=bins, + ) + + def convert_preference_dataset_to_binary_dataset(ds: Dataset): binary_ds = defaultdict(list) for i in tqdm(range(len(ds))): @@ -468,6 +540,70 @@ def __call__(self, batch: List[Dict[str, int]]): } +class SimplePRMCollator: + def __init__(self, pad_token_id: int): + """Simple collator for preference dataset (always pad from the RIGHT)""" + self.pad_token_id = pad_token_id + + def __call__(self, batch: List[Dict[str, int]]): + """the input will have input_ids_chosen, input_ids_rejected""" + # Find max length in the batch + max_length_chosen = -1 + for i in range(len(batch)): + max_length_chosen = max(max_length_chosen, len(batch[i]["input_ids"])) + # max_length_rejected = max(max_length_rejected, len(batch[i]["input_ids_rejected"])) + max_length = max_length_chosen #max(max_length_chosen, max_length_rejected) + assert max_length > 0, "the dataset is empty" + + max_label_length = self.find_max_label_length(batch) + # print(max_label_length) + + # Initialize lists to store padded sequences and attention masks + padded_sequences_chosen = [] + labels = [] + # padded_sequences_rejected = [] + for i in range(len(batch)): + # Calculate padding length + pad_length_chosen = max_length - len(batch[i][INPUT_IDS_KEY]) + # pad_length_rejected = max_length - len(batch[i][INPUT_IDS_REJECTED_KEY]) + + # Pad from the right + padding_chosen = [self.pad_token_id] * pad_length_chosen + # padding_rejected = [self.pad_token_id] * pad_length_rejected + padded_sequence_chosen = batch[i][INPUT_IDS_KEY] + padding_chosen + # padded_sequence_rejected = batch[i][INPUT_IDS_REJECTED_KEY] + padding_rejected + padded_sequences_chosen.append(padded_sequence_chosen) + + # pad labels + pad_length_label = max_label_length - len(batch[i][LABELS_KEY]) + # print(pad_length_label) + padding_label = [self.pad_token_id] * pad_length_label + # print(padding_label) + padded_sequence_label = batch[i][LABELS_KEY] + padding_label + labels.append(padded_sequence_label) + # print(padded_sequence_label, "------") + # padded_sequences_rejected.append(padded_sequence_rejected) + + # Convert to tensors + padded_sequences_chosen = torch.tensor(padded_sequences_chosen) + labels = torch.tensor(labels) + # padded_sequences_rejected = torch.tensor(padded_sequences_rejected) + + return { + INPUT_IDS_KEY: padded_sequences_chosen, + 
LABELS_KEY: labels + # INPUT_IDS_REJECTED_KEY: padded_sequences_rejected, + } + + def find_max_label_length(self, batch): + max_length= -1 + for i in range(len(batch)): + max_length = max(max_length, len(batch[i]["labels"])) + # max_length_rejected = max(max_length_rejected, len(batch[i]["input_ids_rejected"])) + max_length_label = max_length #max(max_length_chosen, max_length_rejected) + assert max_length_label > 0, "the label list is empty" + return max_length_label + class SimpleGenerateCollator: """Simple collator for generation task (always pad from the LEFT)""" diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index c858c3e53..62dc402d2 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -135,7 +135,6 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor: # The returned tensor has shape (batch_size,) return torch.min(zero_or_index, dim=-1).values - def get_reward( model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -206,6 +205,95 @@ def get_reward( ) +def get_prm_reward( + model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, end_step_token_id: int ,context_length: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This function computes reward scores for a batch of query responses based on a pre-trained reward model. + + Args: + model (torch.nn.Module): The pre-trained reward model. + query_responses (torch.Tensor): Tensor containing the tokenized responses for which to compute rewards. + Shape: (batch_size, sequence_length) + pad_token_id (int): The ID used for padding tokens in the tokenized sequences. + context_length (int): The length of the prompt or context preceding the completions. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - reward_logits: The logits output from the model for all tokens in the sequences. + Shape: (batch_size, sequence_length) + - final_scores: The final reward scores, one for each sequence, after adjusting for sequence lengths. + Shape: (batch_size,) + - sequence_lengths: The lengths of each sequence (excluding padding). 
+ Shape: (batch_size,) + """ + + # Create an attention mask where tokens that are not padding have a value of 1, and padding tokens have a value of 0 + # Shape: (batch_size, sequence_length) + attention_mask = query_responses != pad_token_id + + # Calculate position IDs for each token, considering the cumulative sum of the attention mask (to exclude padding) + # Shape: (batch_size, sequence_length) + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + + # Access the LM backbone from the reward model using its base model prefix + lm_backbone = getattr(model, model.base_model_prefix) + + # Replace padding tokens with zeros in the input IDs (so padding tokens won't affect the model's processing) + # Shape: (batch_size, sequence_length) + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + output = lm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + use_cache=False, # otherwise mistral-based RM would error out + ) + reward_logits = model.score(output.hidden_states[-1]) # (batch_size, sequence_length) + + # Calculate the length of each sequence by finding the first occurrence of a padding token after the context + # sequence_lengths shape: (batch_size,) + + sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length + + # breakpoint() + all_sequence_lengths = (query_responses[:, context_length:] == end_step_token_id).nonzero(as_tuple=False) + row = all_sequence_lengths.shape[0] + # subtract one (no need for prm case whe we need logit on the label token) + # all_sequence_lengths = all_sequence_lengths - torch.cat([torch.zeros(row,1), torch.ones(row,1)], dim=-1).to(query_responses.device) #+ context_length + all_sequence_lengths = all_sequence_lengths.to(query_responses.device).type(torch.long) + step_reward_logits = reward_logits[all_sequence_lengths[:,0], all_sequence_lengths[:,1]].squeeze(-1) + +#[0, 3] +#[0, 5] +#[0, 10] +#[1, 10] +#[1, 20] + + ## no needed for num_label==2 + # assert ( + # reward_logits.shape[-1] == 1 + # ), "Reward model should output a single scalar per token. Check if you added `num_labels=1` when doing `AutoModelForSequenceClassification.from_pretrained(...)`." 
+ # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + + # Return the reward logits for all tokens, the final reward scores for each sequence, and the sequence lengths + return ( + # reward_logits shape: (batch_size, sequence_length) + reward_logits, + # final_scores shape: (batch_size,) + reward_logits[ + torch.arange(reward_logits.size(0), device=reward_logits.device), + sequence_lengths, + ].squeeze( + -1 + ), # Shape: (batch_size,) + sequence_lengths, + step_reward_logits, # Shape (batshc_size*#special_tokens across btches,) + all_sequence_lengths + ) + + def forward( model: torch.nn.Module, query_responses: torch.Tensor, diff --git a/open_instruct/process_reward_modeling.py b/open_instruct/process_reward_modeling.py new file mode 100644 index 000000000..0384dbff3 --- /dev/null +++ b/open_instruct/process_reward_modeling.py @@ -0,0 +1,444 @@ +import json +import os +import random +import time +from dataclasses import asdict, dataclass +from typing import List, Literal, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from accelerate import Accelerator +from accelerate.utils import broadcast, gather_object +from datasets import DatasetDict +from huggingface_hub import HfApi +from rich.pretty import pprint +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + get_scheduler, +) + +from open_instruct.dataset_processor import ( + CHAT_TEMPLATES, + INPUT_IDS_CHOSEN_KEY, + INPUT_IDS_REJECTED_KEY, + INPUT_IDS_PROMPT_KEY, + INPUT_IDS_KEY, + LABELS_KEY, + DatasetConfig, + PreferenceDatasetProcessor, + PRMDatasetProcessor, + SimplePRMCollator, + visualize_token, +) +from open_instruct.model_utils import ( + ModelConfig, + disable_dropout_in_model, + get_reward, + get_prm_reward, + print_rich_single_line_metrics, + print_rich_table, + push_folder_to_hub, + save_with_accelerate, +) +from open_instruct.reward_modeling_eval import evaluate, evaluate_prm +from open_instruct.utils import ( + ArgumentParserPlus, + combine_dataset, + get_wandb_tags, + is_beaker_job, + maybe_get_beaker_config, + maybe_use_ai2_wandb_entity, +) + +api = HfApi() + + +@dataclass +class Args: + # required dataset args + dataset_mixer: str = None + """A dictionary of datasets (local or HF) to sample from.""" + dataset_train_splits: List[str] = None + """The dataset splits to use for training""" + dataset_eval_mixer: Optional[str] = None + """A dictionary of datasets (local or HF) to sample from for evaluation""" + dataset_eval_splits: Optional[List[str]] = None + """The dataset splits to use for evaluation""" + dataset_mixer_dict: Optional[dict] = None + """The dataset mixer as a dictionary""" + dataset_eval_mixer_dict: Optional[dict] = None + """The dataset eval mixer as a dictionary""" + + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """The name of this experiment""" + seed: int = 1 + """Seed of the experiment""" + run_name: Optional[str] = None + """A unique name of this run""" + + # optimizer args + eps: float = 1e-5 + """The epsilon value for the optimizer""" + learning_rate: float = 2e-5 + """The initial learning rate for AdamW optimizer.""" + lr_scheduler_type: Literal[ + "linear", "cosine", "cosine_with_restarts", "polynomial", 
"constant", "constant_with_warmup" + ] = "linear" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + + # various batch sizes + num_train_epochs: int = 1 + """Number of epochs to train""" + gradient_accumulation_steps: int = 8 + """The number of gradient accumulation steps""" + per_device_train_batch_size: Optional[int] = 1 + """The forward batch size per device (local_micro_batch_size)""" + per_device_eval_batch_size: Optional[int] = 1 + """The forward batch size per device for evaluation (local_micro_batch_size)""" + total_episodes: Optional[int] = None + """The total number of episodes in the dataset""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_batch_size: Optional[int] = None + """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" + batch_size: Optional[int] = None + """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" + num_training_steps: Optional[int] = None + """The number of training_steps to train""" + num_evals: int = 1 + """The number of evaluations to run throughout training""" + eval_freq: Optional[int] = None + """The frequency of evaluation steps""" + + # wandb and HF tracking configs + with_tracking: bool = False + """If toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "open_instruct_internal" + """The wandb's project name""" + wandb_entity: Optional[str] = None + """The entity (team) of wandb's project""" + push_to_hub: bool = True + """Whether to upload the saved model to huggingface""" + hf_entity: Optional[str] = None + """The user or org name of the model repository from the Hugging Face Hub""" + hf_repo_id: Optional[str] = None + """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_revision: Optional[str] = None + """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_url: Optional[str] = None + """The url of the saved model in the Hugging Face Hub (will be autoset)""" + output_dir: Optional[str] = None + """Where to save the model""" + + resize_token_embeddings: bool = True + """Whether to resize the token embeddings to a factor of 8 for utilizing tensor cores better""" + + def __post_init__(self): + self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer) + if self.dataset_eval_mixer is not None: + self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer) + + +def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: + # if passed through cli: convert the dataset mixers to dictionaries + if isinstance(value, str): + return json.loads(value), value + # if passed through yaml: convert the dataset mixers to strings + elif isinstance(value, dict): + return value, json.dumps(value) + else: + raise ValueError("Input must be either a string or a dictionary") + + +def calculate_runtime_args_and_accelerator(args: Args, model_config: ModelConfig) -> Accelerator: + """calculate (in-place) runtime args such as the effective batch size, word size, etc.""" + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + args.world_size = accelerator.num_processes + 
args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + # set a unique run name with the current timestamp + time_int = broadcast(time_tensor, 0).item() + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + if args.push_to_hub: + if args.hf_repo_id is None: # auto-generate one + args.hf_repo_id = f"{args.exp_name}__{model_config.model_name_or_path.replace('/', '_')}" + if args.hf_entity is None: + args.hf_entity = api.whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: # auto-generate one + args.hf_repo_revision = args.run_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if args.with_tracking and accelerator.is_main_process: + if args.wandb_entity is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + return accelerator + + +def layer_init(layer: nn.Module, std: float): + torch.nn.init.normal_(layer.weight, std=std) + return layer + + +def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): + accelerator = calculate_runtime_args_and_accelerator(args, model_config) + local_seed = args.seed + accelerator.process_index + + # set up experiment tracking and seeds + all_configs = {} + if is_beaker_job(): + args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", args.output_dir) + beaker_config = maybe_get_beaker_config() + # try saving to the beaker `/output`, which will be uploaded to the beaker dataset + if len(beaker_config.beaker_dataset_id_urls) > 0: + args.output_dir = "/output" + all_configs.update(vars(beaker_config)) + all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) + if accelerator.is_main_process: + # breakpoint() + if args.with_tracking: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=all_configs, + name=args.run_name, + save_code=True, + tags=[args.exp_name] + get_wandb_tags(), + ) + writer = SummaryWriter(f"runs/{args.run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + device = accelerator.device + random.seed(local_seed) + np.random.seed(local_seed) + torch.manual_seed(local_seed) + torch.backends.cudnn.deterministic = True + + # create a tokenizer (pad from right) + config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" + ) + if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding + # tokenizer.add_special_tokens({"end_step_token": "ки"}) + tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + end_step_token_id = tokenizer.encode(" ки")[1] + + # create the dataset + dataset_dict = DatasetDict() + # dataset_processor = PreferenceDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + dataset_processor = 
PRMDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + train_dataset = combine_dataset( + args.dataset_mixer_dict, + splits=args.dataset_train_splits, + columns_to_keep=[dataset_config.sft_messages_key, dataset_config.prm_label_key], #[dataset_config.preference_chosen_key, dataset_config.preference_rejected_key], + ) + if dataset_config.sanity_check: + train_dataset = train_dataset.select( + range(0, min(len(train_dataset), dataset_config.sanity_check_max_samples)) + ) + with accelerator.main_process_first(): + train_dataset = dataset_processor.tokenize(train_dataset) + train_dataset = dataset_processor.filter(train_dataset) + dataset_dict["train"] = train_dataset + eval_dataset = None + if args.dataset_eval_mixer is not None: + eval_dataset = combine_dataset( + args.dataset_eval_mixer_dict, + splits=args.dataset_eval_splits, + columns_to_keep=[dataset_config.sft_messages_key, dataset_config.prm_label_key], #[dataset_config.preference_chosen_key, dataset_config.preference_rejected_key], + ) + eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) + with accelerator.main_process_first(): + eval_dataset = dataset_processor.tokenize(eval_dataset) + eval_dataset = dataset_processor.filter(eval_dataset) + dataset_dict["eval"] = eval_dataset + + # some more runtime logging + if args.total_episodes is None: + args.total_episodes = args.num_train_epochs * len(train_dataset) + args.num_training_steps = args.total_episodes // args.batch_size + args.eval_freq = max(1, args.total_episodes // args.micro_batch_size // args.num_evals) + if accelerator.is_main_process: + pprint([args, dataset_config, model_config]) + visualize_token(train_dataset[0][INPUT_IDS_KEY], tokenizer) #INPUT_IDS_CHOSEN_KEY + if args.with_tracking: + # upload the visualized token length + dataset_processor.get_token_length_visualization( + dataset_dict, save_path=f"runs/{args.run_name}/token_length.png" + ) + wandb.log({"token_length": wandb.Image(f"runs/{args.run_name}/token_length.png")}) + + # create the model and optimizer + model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision, num_labels=2 + ) + if args.resize_token_embeddings: # optimize for tensor core + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) + if model_config.gradient_checkpointing: + model.gradient_checkpointing_enable() + disable_dropout_in_model(model) # see p.3. in https://arxiv.org/pdf/1909.08593 + layer_init( + model.score, std=1 / np.sqrt(model.config.hidden_size + 1) + ) # see p. 
11 in https://arxiv.org/abs/2009.01325 + optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, eps=args.eps) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs, + ) + + # TODO Faeze + data_collator = SimplePRMCollator(pad_token_id=tokenizer.pad_token_id) + dataloader = DataLoader( + train_dataset, + batch_size=args.per_device_train_batch_size, + shuffle=True, + collate_fn=data_collator, + ) + eval_dataloader = DataLoader( + eval_dataset, + batch_size=args.per_device_eval_batch_size, + shuffle=False, + collate_fn=data_collator, + ) + + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + eval_dataloader = accelerator.prepare(eval_dataloader) + torch.manual_seed(local_seed) + + # set up the metrics and initial states + losses = torch.zeros((args.gradient_accumulation_steps,), device=device) + accuracies = torch.zeros((args.gradient_accumulation_steps,), device=device) + # chosen_rewards = torch.zeros((args.gradient_accumulation_steps,), device=device) + # rejected_rewards = torch.zeros((args.gradient_accumulation_steps,), device=device) + # reward_margin = torch.zeros((args.gradient_accumulation_steps,), device=device) + local_metrics = torch.zeros((5,), device=device) + training_step = 0 + gradient_accumulation_idx = 0 + episode = 0 + model.train() + + # training loop + for _ in range(args.num_train_epochs): + for data in dataloader: + episode += args.micro_batch_size + training_step += 1 + # query_responses = torch.cat((data[INPUT_IDS_CHOSEN_KEY], data[INPUT_IDS_REJECTED_KEY]), dim=0) + query_responses = data[INPUT_IDS_KEY] + with accelerator.accumulate(model): + _, predicted_reward, _, predcited_step_logits, _ = get_prm_reward(model, query_responses, tokenizer.pad_token_id, end_step_token_id, 0) + # ### add binary cross entropy + ### find label tokens and flatten them into 1D vector to compute loss wrt predicted step logits that are flatten logits for each label position + flatten_labels = data[LABELS_KEY][data[LABELS_KEY] != tokenizer.pad_token_id].flatten() # data['label'][data['label'] != -1].flatten() + assert predcited_step_logits.shape[0] == flatten_labels.shape[0], f"predicted logits and labels size do not match {predcited_step_logits.shape} vs {flatten_labels.shape}" + loss = F.cross_entropy(predcited_step_logits, flatten_labels) + accelerator.backward(loss) + # if accelerator.is_main_process: + # breakpoint() + # else: + # time.sleep(100) + optimizer.step() + optimizer.zero_grad() + # Step 1: Get the predicted labels by finding the index of the maximum logit + _, predicted_labels = torch.max(predcited_step_logits, dim=1) + # Step 2: Compare the predicted labels with the actual targets + accuracy = (predicted_labels == flatten_labels).float().mean().item() + losses[gradient_accumulation_idx] = loss + accuracies[gradient_accumulation_idx] = accuracy + gradient_accumulation_idx = (gradient_accumulation_idx + 1) % args.gradient_accumulation_steps + + if training_step % args.gradient_accumulation_steps == 0: + scheduler.step() + local_metrics[0] = accuracies.mean() + local_metrics[1] = losses.mean() + # local_metrics[2] = chosen_rewards.mean() + # local_metrics[3] = rejected_rewards.mean() + # local_metrics[4] = 
reward_margin.mean() + global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() + + metrics = { + "episode": episode, + "epoch": episode / len(train_dataset), + "train/rm/accuracy": global_metrics[0], + "train/rm/loss": global_metrics[1], + # "train/rm/chosen_rewards": global_metrics[2], + # "train/rm/rejected_rewards": global_metrics[3], + # "train/rm/reward_margin": global_metrics[4], + "train/rm/lr": scheduler.get_last_lr()[0], + } + if accelerator.is_main_process: + print_rich_single_line_metrics(metrics) + for key, value in metrics.items(): + writer.add_scalar(key, value, episode) + + # (optionally) evaluate the model + if args.num_evals > 0 and training_step > 1 and training_step % args.eval_freq == 0: + eval_metrics, table = evaluate_prm(model, eval_dataloader, tokenizer, max_sampled_texts=10) + for key in table: + table[key] = gather_object(table[key]) + # if accelerator.is_main_process: + # breakpoint() + df = pd.DataFrame(table) + if accelerator.is_main_process: + print_rich_single_line_metrics(eval_metrics) + for key, value in eval_metrics.items(): + writer.add_scalar(key, value, episode) + if args.with_tracking: + wandb.log({"preference_sample_texts": wandb.Table(dataframe=df)}) + else: + print_rich_table(df) + print_rich_table(df) + + # save model + os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) + original_tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision + ) + save_with_accelerate( + accelerator, + model, + original_tokenizer, + args.output_dir, + ) + if args.push_to_hub: + push_folder_to_hub( + accelerator, + args.output_dir, + args.hf_repo_id, + args.hf_repo_revision, + ) + + +if __name__ == "__main__": + parser = ArgumentParserPlus((Args, DatasetConfig, ModelConfig)) + main(*parser.parse()) + diff --git a/open_instruct/rejection_sampling/rejection_sampling.py b/open_instruct/rejection_sampling/rejection_sampling.py index a5f47123a..cb66638d0 100644 --- a/open_instruct/rejection_sampling/rejection_sampling.py +++ b/open_instruct/rejection_sampling/rejection_sampling.py @@ -301,7 +301,7 @@ def main(args: Args): item = result.get() scores.append(item[0]) reference_completion_scores.append(item[1]) - + breakpoint() # Combine scores from all GPUs scores = torch.cat(scores) reference_completion_scores = torch.cat(reference_completion_scores) diff --git a/open_instruct/reward_modeling_eval.py b/open_instruct/reward_modeling_eval.py index fdd85ae16..17769efb6 100644 --- a/open_instruct/reward_modeling_eval.py +++ b/open_instruct/reward_modeling_eval.py @@ -19,11 +19,16 @@ CHAT_TEMPLATES, INPUT_IDS_CHOSEN_KEY, INPUT_IDS_REJECTED_KEY, + INPUT_IDS_KEY, + LABELS_KEY, DatasetConfig, PreferenceDatasetProcessor, SimplePreferenceCollator, + PRMDatasetProcessor, + SimplePRMCollator + ) -from open_instruct.model_utils import get_reward, print_rich_table +from open_instruct.model_utils import get_reward, get_prm_reward, print_rich_table api = HfApi() @@ -104,12 +109,90 @@ def evaluate( "eval/rm/reward_margin": total_reward_margin / total_batches, }, table +decompose_list = lambda long_list, lengths: [long_list[sum(lengths[:i]):sum(lengths[:i+1])] for i in range(len(lengths))] +def evaluate_prm( + model: PreTrainedModel, dataloader: DataLoader, tokenizer: PreTrainedTokenizer, max_sampled_texts: int = 0 +) -> Tuple[dict, dict]: + model.eval() + total_loss = 0 + total_accuracy = 0 + # total_chosen_rewards = 0 + # total_rejected_rewards = 0 + # total_reward_margin = 0 + total_batches = 0 + table = 
None + if max_sampled_texts > 0: + table = defaultdict(list) + end_step_token_id = tokenizer.encode(" ки")[1] + with torch.no_grad(): + for data in tqdm(dataloader): + query_responses = data[INPUT_IDS_KEY] + flatten_labels = data[LABELS_KEY][data[LABELS_KEY] != tokenizer.pad_token_id].flatten() # data['label'][data['label'] != -1].flatten() + input_label_lengths = (data[LABELS_KEY] != tokenizer.pad_token_id).sum(dim=1).tolist() + _, predicted_reward, _, predcited_step_logits, _ = get_prm_reward(model, query_responses, tokenizer.pad_token_id, end_step_token_id, 0) + assert predcited_step_logits.shape[0] == flatten_labels.shape[0], f"predicted logits and labels size do not match {predcited_step_logits.shape} vs {flatten_labels.shape}" + _, predicted_labels = torch.max(predcited_step_logits, dim=1) + accuracy = (predicted_labels == flatten_labels).float().mean() + loss = F.cross_entropy(predcited_step_logits, flatten_labels) + # predicted_probs = predcited_step_logits[: data[LABELS_KEY].shape[0]] #TODO: debug this + + + # chosen_rewards = predicted_reward[: data[INPUT_IDS_CHOSEN_KEY].shape[0]] + # rejected_rewards = predicted_reward[data[INPUT_IDS_CHOSEN_KEY].shape[0] :] + # accuracy = (chosen_rewards > rejected_rewards).float().mean() + # loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean() + total_loss += loss.item() + total_accuracy += accuracy.item() + # total_chosen_rewards += chosen_rewards.mean().item() + # total_rejected_rewards += rejected_rewards.mean().item() + # total_reward_margin += (chosen_rewards - rejected_rewards).mean().item() + total_batches += 1 + + if table is not None and len(table["input text"]) < max_sampled_texts: + input_texts = tokenizer.batch_decode(data[INPUT_IDS_KEY]) + # remove padding + input_texts = [item.replace(tokenizer.pad_token, "") for item in input_texts] + # rewards_rounded = [ + # [round(chosen.item(), 4), round(rejected.item(), 4)] + # for chosen, rejected in zip(chosen_rewards, rejected_rewards) + # ] + correct_prediction = [ + bool((pred_label == gold_label)) for pred_label, gold_label in zip(predicted_labels, flatten_labels) + ] + correct_prediction_per_input = decompose_list(correct_prediction, input_label_lengths) + # shared_texts = [ + # find_shared_text(chosen_text, rejected_text) + # for chosen_text, rejected_text in zip(chosen_texts, rejected_texts) + # ] + # chosen_response_texts = [ + # chosen_text[len(shared_text) :] for chosen_text, shared_text in zip(chosen_texts, shared_texts) + # ] + # rejected_response_texts = [ + # rejected_text[len(shared_text) :] + # for rejected_text, shared_text in zip(rejected_texts, shared_texts) + # ] + table["input text"].extend(input_texts) + # table["chosen response text"].extend(chosen_response_texts) + # table["rejected response text"].extend(rejected_response_texts) + # table["chosen reward, rejected reward"].extend(rewards_rounded) + table["correct prediction"].extend(correct_prediction_per_input) + + model.train() + return { + "eval/rm/accuracy": total_accuracy / total_batches, + "eval/rm/loss": total_loss / total_batches, + # "eval/rm/chosen_rewards": total_chosen_rewards / total_batches, + # "eval/rm/rejected_rewards": total_rejected_rewards / total_batches, + # "eval/rm/reward_margin": total_reward_margin / total_batches, + }, table + if __name__ == "__main__": model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1) dataset_config = DatasetConfig( dataset_name="trl-internal-testing/sentiment-trl-style", chat_template="simple_chat" ) + tokenizer = 
AutoTokenizer.from_pretrained("EleutherAI/pythia-14m", padding_side="right") tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] @@ -117,14 +200,16 @@ def evaluate( dataset_processor = PreferenceDatasetProcessor( tokenizer=tokenizer, config=dataset_config, - ) + ) if 'math' not in dataset_config.dataset_name else PRMDatasetProcessor(tokenizer=tokenizer, config=dataset_config) eval_dataset = dataset_processor.tokenize(eval_dataset) dataloader = DataLoader( eval_dataset, batch_size=8, - collate_fn=SimplePreferenceCollator(tokenizer.pad_token_id), + collate_fn=SimplePreferenceCollator(tokenizer.pad_token_id) if 'math' not in dataset_config.dataset_name else SimplePRMCollator(pad_token_id=tokenizer.pad_token_id), ) - metrics, table = evaluate(model, dataloader, tokenizer, max_sampled_texts=5) + + eval_func = evaluate if 'math' not in dataset_config.dataset_name else evaluate_prm + metrics, table = eval_func(model, dataloader, tokenizer, max_sampled_texts=5) print(metrics) print_rich_table(pd.DataFrame(table)) ... diff --git a/scripts/run_process_reward.sh b/scripts/run_process_reward.sh new file mode 100644 index 000000000..99a0afd56 --- /dev/null +++ b/scripts/run_process_reward.sh @@ -0,0 +1,64 @@ +#!/bin/bash +NUM_GPU=8 +model_name=EleutherAI/llemma_34b #AI-MO/NuminaMath-7B-CoT # # #deepseek-ai/deepseek-math-7b-instruct +DESC="math_reward_model_${model_name}_on_numina_math_gsm8k_v3" + +# --image costah/open_instruct_dev --pure_docker_mode \ +# python mason.py \ +# --cluster ai2/pluto-cirrascale \ +# --priority normal \ +# --budget ai2/allennlp \ +# --workspace ai2/tulu-2-improvements \ +# --description $DESC \ +# --gpus $NUM_GPU -- accelerate launch --num_machines 1 --num_processes $NUM_GPU --config_file configs/ds_configs/deepspeed_zero3.yaml \ +# open_instruct/reward_modeling.py \ +# --dataset_mixer '{"ai2-adapt-dev/numina_math_gsm8k_minerva_RM": 1.0}' \ +# --dataset_train_splits train \ +# --dataset_eval_mixer '{"ai2-adapt-dev/numina_math_gsm8k_minerva_RM": 1.0}' \ +# --dataset_eval_splits test \ +# --model_name_or_path $model_name \ +# --chat_template simple_concat_with_space \ +# --learning_rate 3e-6 \ +# --gradient_checkpointing \ +# --per_device_train_batch_size 1 \ +# --per_device_eval_batch_size 4 \ +# --gradient_accumulation_steps 8 \ +# --max_token_length 4096 \ +# --max_prompt_token_lenth 512 \ +# --num_train_epochs 1 \ +# --output_dir outputs/rm/rm_math_7b \ +# --with_tracking \ +# # --push_to_hub + +#ai2-adapt-dev/Math-Shepherd-PRM-format +# "ai2-adapt-dev/numina_math_gsm8k_minerva_RM" + +model_name=deepseek-ai/deepseek-math-7b-instruct +DESC="math__PRM_modeling_${model_name}_on_math_shepherd_8_2epoch" +NUM_GPU=8 +# accelerate launch --num_processes $NUM_GPU --config_file configs/ds_configs/deepspeed_zero2.yaml open_instruct/reward_modeling_v2.py \ +python mason.py \ + --cluster ai2/pluto-cirrascale \ + --priority normal \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --description $DESC \ + --gpus $NUM_GPU -- accelerate launch --num_machines 1 --num_processes $NUM_GPU --config_file configs/ds_configs/deepspeed_zero3.yaml open_instruct/process_reward_modeling.py \ + --dataset_mixer '{"ai2-adapt-dev/Math-Shepherd-PRM-chat-reformatted": 1.0}' \ + --dataset_train_splits train \ + --dataset_eval_mixer '{"ai2-adapt-dev/Math-Shepherd-PRM-chat-reformatted": 1.0}' \ + --dataset_eval_splits test \ + --model_name_or_path $model_name \ + --chat_template 
simple_concat_with_space \ + --learning_rate 3e-6 \ + --gradient_checkpointing \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --max_token_length 4096 \ + --max_prompt_token_lenth 512 \ + --num_train_epochs 2 \ + --output_dir ./outputs/rm/rm_math_7b \ + --chat_template simple_prm + # --sanity_check \ +# --with_tracking \