diff --git a/sweep.py b/sweep.py
index d414b9a..9da4420 100644
--- a/sweep.py
+++ b/sweep.py
@@ -22,17 +22,17 @@ def main(model_sizes: Union[List[str], str], **kwargs):
     for model_size in model_sizes:
         subprocess.run(basic_args + ["--model_size", model_size], check=True)
 
-    print("Running transfer models")
-    for i in range(len(model_sizes)):
-        for j in range(i, len(model_sizes)):
-            weak_model_size = model_sizes[i]
-            strong_model_size = model_sizes[j]
-            print(f"Running weak {weak_model_size} to strong {strong_model_size}")
-            subprocess.run(
-                basic_args
-                + ["--weak_model_size", weak_model_size, "--model_size", strong_model_size],
-                check=True,
-            )
+    # print("Running transfer models")
+    # for i in range(len(model_sizes)):
+    #     for j in range(i, len(model_sizes)):
+    #         weak_model_size = model_sizes[i]
+    #         strong_model_size = model_sizes[j]
+    #         print(f"Running weak {weak_model_size} to strong {strong_model_size}")
+    #         subprocess.run(
+    #             basic_args
+    #             + ["--weak_model_size", weak_model_size, "--model_size", strong_model_size],
+    #             check=True,
+    #         )
 
 
 if __name__ == "__main__":
diff --git a/test_sweep.py b/test_sweep.py
new file mode 100644
index 0000000..e18935d
--- /dev/null
+++ b/test_sweep.py
@@ -0,0 +1,18 @@
+import subprocess
+import sys
+
+def main(dataset_name, model_sizes, num_steps=600):  # Set default steps to 600
+    if isinstance(model_sizes, str):
+        model_sizes = model_sizes.split(',')
+
+    # Base command: the script to run, the fixed number of steps, and the dataset name
+    basic_args = [sys.executable, "train_simple.py", "--num_steps", str(num_steps), "--ds_name", dataset_name]
+
+    # Loop over all specified model sizes and run the training script for each one
+    for model_size in model_sizes:
+        model_specific_args = basic_args + ["--model_size", model_size]
+        subprocess.run(model_specific_args, check=True)
+
+if __name__ == "__main__":
+    import fire
+    fire.Fire(main)
diff --git a/test_train_simple.py b/test_train_simple.py
new file mode 100644
index 0000000..b1bca4d
--- /dev/null
+++ b/test_train_simple.py
@@ -0,0 +1,166 @@
+import json
+import os
+import argparse
+import torch
+from datasets import load_dataset, load_from_disk
+import numpy as np
+import fire
+
+# Custom imports for model configuration and training; replace these with your actual modules
+from weak_to_strong.train import ModelConfig, train_and_save_model
+from weak_to_strong.datasets import VALID_DATASETS, tokenize_dataset
+from weak_to_strong.common import get_tokenizer
+from weak_to_strong.loss import logconf_loss_fn, product_loss_fn, xent_loss
+import weak_to_strong.logger as logger
+
+# Model configurations - you should replace this with your actual configuration
+MODEL_CONFIGS = [
+    ModelConfig(
+        name="gpt2",
+        default_lr=5e-5,
+        eval_batch_size=32,
+    ),
+    ModelConfig(
+        name="gpt2-medium",
+        default_lr=5e-5,
+        eval_batch_size=32,
+    ),
+    ModelConfig(
+        name="gpt2-large",
+        default_lr=1e-5,
+        eval_batch_size=32,
+    ),
+    ModelConfig(
+        name="gpt2-xl",
+        default_lr=1e-5,
+        eval_batch_size=2,
+        gradient_checkpointing=True,
+        # Should use model_parallel on V100s (note: ironically if you have a single V100 it should run,
+        # but if you have multiple it won't run without model_parallel because of the overhead of data
+        # parallel training).
+        model_parallel=(
+            torch.cuda.get_device_properties(0).total_memory < 35e9
+            and torch.cuda.device_count() > 1
+        ),
+    ),
+    ModelConfig(
+        name="Qwen/Qwen-1_8B",
+        default_lr=1e-5,
+        eval_batch_size=2,
+        gradient_checkpointing=True,
+        model_parallel=(
+            torch.cuda.get_device_properties(0).total_memory < 35e9
+            and torch.cuda.device_count() > 1
+        ),
+        custom_kwargs={
+            "trust_remote_code": True,
+            "bf16": torch.cuda.is_bf16_supported(),
+            "fp32": not torch.cuda.is_bf16_supported(),
+            "revision": "5fde88dff770a7d036847211f5d9d9705f0caa69",
+        },
+    ),
+    ModelConfig(
+        name="Qwen/Qwen-7B",
+        default_lr=1e-5,
+        eval_batch_size=2,
+        gradient_checkpointing=True,
+        model_parallel=True,
+        # note: you will probably not be able to run this without many gpus
+        custom_kwargs={
+            "trust_remote_code": True,
+            "bf16": torch.cuda.is_bf16_supported(),
+            "fp32": not torch.cuda.is_bf16_supported(),
+            "revision": "d4efd21e866b9cb3466cb65b963933f5e98016d1",
+        },
+    ),
+    ModelConfig(
+        name="Qwen/Qwen-14B",
+        default_lr=1e-5,
+        eval_batch_size=2,
+        gradient_checkpointing=True,
+        model_parallel=True,
+        # note: you will probably not be able to run this without bf16 support and many gpus
+        custom_kwargs={
+            "trust_remote_code": True,
+            "bf16": torch.cuda.is_bf16_supported(),
+            "fp32": not torch.cuda.is_bf16_supported(),
+            "revision": "8be2854218fea9054331e217fd26a06f3fd02004",
+        },
+    ),
+    ModelConfig(
+        name="Qwen/Qwen-72B",
+        default_lr=1e-5,
+        eval_batch_size=1,
+        gradient_checkpointing=True,
+        model_parallel=True,
+        # note: you will probably not be able to run this without bf16 support and many gpus
+        custom_kwargs={
+            "trust_remote_code": True,
+            "bf16": torch.cuda.is_bf16_supported(),
+            "fp32": not torch.cuda.is_bf16_supported(),
+            "revision": "fec78c0e3b3b10dd9f0ce775c34a686a3255a7d1",
+        },
+        # This model is really big; save space by using adafactor.
+        # Note that even then it will take up ~60GB per GPU on an 8-GPU machine.
+        default_optimizer="adafactor",
+    ),
+]
+
+# Construct a dictionary from model configurations for easy access
+MODELS_DICT = {config.name: config for config in MODEL_CONFIGS}
+
+# Define available losses
+loss_dict = {
+    "logconf": logconf_loss_fn(),
+    "product": product_loss_fn(),
+    "xent": xent_loss(),
+}
+
+def main():
+    parser = argparse.ArgumentParser(description="Train models on various datasets")
+    parser.add_argument('--num_steps', type=int, default=10000, help='Number of fixed training steps')
+    parser.add_argument('--ds_name', type=str, required=True, help='Dataset name')
+    parser.add_argument('--model_size', type=str, required=True, help='Model size to use for training')
+    args = parser.parse_args()
+
+    assert args.ds_name in VALID_DATASETS, f"Dataset {args.ds_name} is not recognized. Valid datasets are: {VALID_DATASETS}"
+    model_config = MODELS_DICT.get(args.model_size, None)
+    assert model_config is not None, f"Model size {args.model_size} is not recognized."
+
+    # Load dataset
+    dataset = load_dataset(args.ds_name)
+    train_dataset, test_dataset = dataset['train'], dataset['test']
+
+    # Tokenize datasets
+    tokenizer = get_tokenizer(model_config.name)
+    train_dataset = tokenize_dataset(train_dataset, tokenizer, max_ctx=1024)
+    test_dataset = tokenize_dataset(test_dataset, tokenizer, max_ctx=1024)
+
+    # Setup DataLoader
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
+
+    # Configure logger; replace with your logging setup
+    logger.configure_logging()
+
+    # Start training loop
+    train_iter = iter(train_loader)
+    for step in range(args.num_steps):
+        try:
+            # Each step processes one batch of data
+            batch = next(train_iter)
+        except StopIteration:
+            # If the DataLoader runs out of data, restart it
+            train_iter = iter(train_loader)
+            batch = next(train_iter)
+        loss = train_and_save_model(
+            model_config=model_config,
+            batch=batch,
+            loss_fn=loss_dict['xent']  # Example using cross-entropy loss
+        )
+        if step % 100 == 0:
+            print(f"Step {step}: Loss {loss}")
+
+    print("Training complete.")
+
+if __name__ == "__main__":
+    main()  # argparse handles the CLI, so there is no need to wrap this in fire.Fire
diff --git a/train_simple.py b/train_simple.py
index e406d4f..2e342db 100644
--- a/train_simple.py
+++ b/train_simple.py
@@ -144,7 +144,7 @@ def shorten_value(value) -> str:
 def main(
     batch_size: int = 32,
     max_ctx: int = 1024,
-    ds_name: str = "sciq",
+    ds_name: str = "glue_cola",
     loss: str = "xent",
     n_docs: int = 20000,
     n_test_docs: int = 10000,
@@ -193,16 +193,16 @@ def main(
     # The commented out terms are the ones that should not change final results
     config = {
-        "batch_size": batch_size,
+        "batch_size": batch_size,  ## INTERESTED
         "max_ctx": max_ctx,
         "ds_name": ds_name,
-        "loss": loss,
+        "loss": loss,  ## INTERESTED
         "n_docs": n_docs,
         "n_test_docs": n_test_docs,
-        "model_size": model_size,
-        "lr": lr,
+        "model_size": model_size,  ## INTERESTED
+        "lr": lr,  ## INTERESTED
         "optim": optim,
-        "epochs": epochs,
+        "epochs": epochs,  ## INTERESTED
         # "force_retrain": force_retrain,
         "seed": seed,
         # "minibatch_size_per_device": minibatch_size_per_device,
diff --git a/weak_to_strong/datasets.py b/weak_to_strong/datasets.py
index 28d589f..3244b7a 100644
--- a/weak_to_strong/datasets.py
+++ b/weak_to_strong/datasets.py
@@ -153,7 +153,6 @@ def format_boolq(ex, rng):
     txt = f"Passage: {ex['passage']}\nQuestion: {ex['question']}"
     return dict(txt=txt, hard_label=hard_label)
 
-
 register_dataset(
     "boolq",
     DatasetConfig(
@@ -162,9 +161,66 @@ def format_boolq(ex, rng):
     ),
 )
 
+
+def format_openbookQA(ex, rng):
+    question_stem = ex["question_stem"]
+    choices_text = ex['choices']['text']
+    choices_labels = ex['choices']['label']
+    correct_label = ex['answerKey']
+
+    choices_formatted = ' '.join([f"{label}: {text}" for label, text in zip(choices_labels, choices_text)])
+
+    correct_answer_index = choices_labels.index(correct_label)
+    correct_answer_text = choices_text[correct_answer_index]
+
+    txt = f"Question: {question_stem}\nChoices: {choices_formatted}\nCorrect Answer: {correct_answer_text}"
+    return dict(txt=txt, hard_label=1)  # TODO: change how the hard label is coded; every example is currently labeled 1
+
+register_dataset(
+    "openbookqa",
+    DatasetConfig(
+        loader=hf_loader("allenai/openbookqa", "main", split_names=dict(test="validation")), formatter=format_openbookQA
+    ),
+)
+
+def format_paws(ex, rng):
+    txt = f"Sentence 1: {ex['sentence1']} Sentence 2: {ex['sentence2']}"
+    hard_label = int(ex['label'])
+    return dict(txt=txt, hard_label=hard_label)
+
+register_dataset(
+    "paws_labeled_final",  # Unique name for the dataset registration.
+    DatasetConfig(
+        loader=hf_loader("paws", "labeled_final", split_names=dict(test="validation")),
+        formatter=format_paws
+    ),
+)
+
+def format_glue_cola(ex, rng):
+    return dict(txt=ex['sentence'], hard_label=ex['label'])
+
+register_dataset(
+    "glue_cola",
+    DatasetConfig(loader=hf_loader("glue", "cola"), formatter=format_glue_cola),
+)
+
 VALID_DATASETS: list[str] = list(_REGISTRY.keys())
 
 """
+def format_mctaco(ex, rng):
+    sentence = ex['sentence']
+    question = ex['question']
+    answer = ex['answer']
+    label = int(ex['label'] == 'yes')  # Convert 'yes'/'no' to binary label (1/0)
+    txt = f"Context: {sentence}\nQuestion: {question}\nAnswer: {answer}"
+    return dict(txt=txt, hard_label=label)
+
+register_dataset(
+    "mc_taco",
+    DatasetConfig(
+        loader=hf_loader("mc_taco", split_names=dict(test="validation")), formatter=format_mctaco
+    ),
+)
+
 from datasets import disable_caching
 disable_caching()