Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
adb3b58
initial commit
ant-mand Apr 16, 2024
1596ae7
changed sciq to boolq
ant-mand Apr 17, 2024
5158f4f
just run ground truth
ant-mand Apr 17, 2024
1f8ef9a
Expand datasets.py list.
elezovic-natalia Apr 17, 2024
d78fa03
changed formatting of mctaco
ant-mand Apr 17, 2024
6c3d8cc
deleted comma & extra paran
ant-mand Apr 17, 2024
37a4131
changed boolq to mctaco
ant-mand Apr 17, 2024
667dcb1
mctaco to mc_taco
ant-mand Apr 17, 2024
91709ec
mctaco to mc_taco
ant-mand Apr 17, 2024
c5fcff5
openbookQA register
ant-mand Apr 17, 2024
243456a
mctaco to openbookqa
ant-mand Apr 17, 2024
4bea1f1
Update datasets.py
ant-mand Apr 17, 2024
ca3a79b
Update datasets.py
ant-mand Apr 17, 2024
bd3a5e1
Update datasets.py
ant-mand Apr 17, 2024
1ca814c
hard label to 1.
ant-mand Apr 17, 2024
11afe51
making changes to openbookqa
ant-mand Apr 17, 2024
62ff348
Remove unvalid string.
elezovic-natalia Apr 17, 2024
f9e3b23
Update dataset.
elezovic-natalia Apr 17, 2024
d11f88b
Update paws.
elezovic-natalia Apr 17, 2024
3beb040
update datasets.py
elezovic-natalia Apr 17, 2024
910b0bf
update paws
elezovic-natalia Apr 17, 2024
78fd325
my previous changes
ant-mand Apr 17, 2024
e34751c
Merge branch 'main' of https://github.com/ant-mand/antrita-weak-to-st…
ant-mand Apr 17, 2024
2d1fdb8
boolq validation
ant-mand Apr 17, 2024
bffa1a9
updated datasets.py with glue cola.
elezovic-natalia Apr 17, 2024
b645d2d
glue_cola set up.
elezovic-natalia Apr 17, 2024
94d856f
troubleshooting ethics format in datasets.py.
elezovic-natalia Apr 17, 2024
e2ecfff
tweak training model with ethics dataset.
elezovic-natalia Apr 17, 2024
6aa3950
Update train_simple.py
elezovic-natalia Apr 27, 2024
44c3512
Update train_simple.py
elezovic-natalia Apr 30, 2024
0c5e0af
Update train_simple.py
elezovic-natalia Apr 30, 2024
c27c359
Fixed syntax
elezovic-natalia Apr 30, 2024
0eea4b6
Update train_simple.py
elezovic-natalia Apr 30, 2024
867cbdf
Update train_simple.py
elezovic-natalia Apr 30, 2024
a25a6fd
Update datasets.py
elezovic-natalia Apr 30, 2024
2d83062
Update datasets.py
elezovic-natalia Apr 30, 2024
3c82209
Update datasets.py
elezovic-natalia Apr 30, 2024
cc6d06b
Update datasets.py
elezovic-natalia Apr 30, 2024
2b80e71
Update datasets.py
elezovic-natalia Apr 30, 2024
9fe5e9a
Update datasets.py
elezovic-natalia Apr 30, 2024
74ffa68
Update train_simple.py
elezovic-natalia Apr 30, 2024
8aa1643
Update datasets.py
elezovic-natalia Apr 30, 2024
5d33d01
Corrected syntax
elezovic-natalia Apr 30, 2024
a6629bd
Update datasets.py
elezovic-natalia Apr 30, 2024
d6066a7
Update datasets.py
elezovic-natalia Apr 30, 2024
8cf0a39
Standardize training steps.
elezovic-natalia Apr 30, 2024
67d27aa
Revert to base
elezovic-natalia Apr 30, 2024
71e91cb
Update train_simple.py
elezovic-natalia Apr 30, 2024
36b41e8
Create test_sweep.py
elezovic-natalia May 1, 2024
58e3ad8
Update train_simple.py
elezovic-natalia May 1, 2024
9cf1a2c
Update train_simple.py
elezovic-natalia May 1, 2024
0a12965
Create test_train_simple.py
elezovic-natalia May 1, 2024
d568841
Update test_sweep.py
elezovic-natalia May 1, 2024
3520a6c
Update test_train_simple.py
elezovic-natalia May 1, 2024
7b2908e
Update test_sweep.py
elezovic-natalia May 1, 2024
29322b8
Update test_train_simple.py
elezovic-natalia May 1, 2024
7bb1f50
Update test_train_simple.py
elezovic-natalia May 1, 2024
f742a37
Update train_simple.py
elezovic-natalia May 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ def main(model_sizes: Union[List[str], str], **kwargs):
for model_size in model_sizes:
subprocess.run(basic_args + ["--model_size", model_size], check=True)

print("Running transfer models")
for i in range(len(model_sizes)):
for j in range(i, len(model_sizes)):
weak_model_size = model_sizes[i]
strong_model_size = model_sizes[j]
print(f"Running weak {weak_model_size} to strong {strong_model_size}")
subprocess.run(
basic_args
+ ["--weak_model_size", weak_model_size, "--model_size", strong_model_size],
check=True,
)
# print("Running transfer models")
# for i in range(len(model_sizes)):
# for j in range(i, len(model_sizes)):
# weak_model_size = model_sizes[i]
# strong_model_size = model_sizes[j]
# print(f"Running weak {weak_model_size} to strong {strong_model_size}")
# subprocess.run(
# basic_args
# + ["--weak_model_size", weak_model_size, "--model_size", strong_model_size],
# check=True,
# )


if __name__ == "__main__":
Expand Down
18 changes: 18 additions & 0 deletions test_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import subprocess
import sys

def main(dataset_name, model_sizes, num_steps=600): # Set default steps to 600
if isinstance(model_sizes, str):
model_sizes = model_sizes.split(',')

# Base command that includes the script to run and the fixed number of steps
basic_args = [sys.executable, "train_simple.py", "--num_steps", str(num_steps), "--ds_name", dataset_name]

# Loop over all specified model sizes and run the training script for each one
for model_size in model_sizes:
model_specific_args = basic_args + ["--model_size", model_size]
subprocess.run(model_specific_args, check=True)

if __name__ == "__main__":
import fire
fire.Fire(main)
166 changes: 166 additions & 0 deletions test_train_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import json
import os
import argparse
import torch
from datasets import load_dataset, load_from_disk
import numpy as np
import fire

# Custom imports for model configuration and training, replace these with your actual modules
from weak_to_strong.train import ModelConfig, train_and_save_model
from weak_to_strong.datasets import VALID_DATASETS, tokenize_dataset
from weak_to_strong.common import get_tokenizer
from weak_to_strong.loss import logconf_loss_fn, product_loss_fn, xent_loss
import weak_to_strong.logger as logger

# Model configurations - You should replace this with your actual configuration
MODEL_CONFIGS = [

ModelConfig(
name="gpt2",
default_lr=5e-5,
eval_batch_size=32,
),

ModelConfig(
name="gpt2-medium",
default_lr=5e-5,
eval_batch_size=32,
),
ModelConfig(
name="gpt2-large",
default_lr=1e-5,
eval_batch_size=32,
),
ModelConfig(
name="gpt2-xl",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
# Should use model_parallel on V100s (note: ironically if you have a single V100 it should run,
# but if you have multiple it won't run without model_parallel because of the overhead of data
# parallel training).
model_parallel=(
torch.cuda.get_device_properties(0).total_memory < 35e9
and torch.cuda.device_count() > 1
),
),
ModelConfig(
name="Qwen/Qwen-1_8B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=(
torch.cuda.get_device_properties(0).total_memory < 35e9
and torch.cuda.device_count() > 1
),
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"revision": "5fde88dff770a7d036847211f5d9d9705f0caa69",
},
),
ModelConfig(
name="Qwen/Qwen-7B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
# note: you will probably not be able to run this without many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"revision": "d4efd21e866b9cb3466cb65b963933f5e98016d1",
},
),
ModelConfig(
name="Qwen/Qwen-14B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
# note: you will probably not be able to run this bf16 support and without many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"revision": "8be2854218fea9054331e217fd26a06f3fd02004",
},
),
ModelConfig(
name="Qwen/Qwen-72B",
default_lr=1e-5,
eval_batch_size=1,
gradient_checkpointing=True,
model_parallel=True,
# note: you will probably not be able to run this without bf16 support and many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"revision": "fec78c0e3b3b10dd9f0ce775c34a686a3255a7d1",
},
# This model is really big, save space by using adafactor.
# Note that even then it will take up ~60GB per GPU on an 8-GPU machine.
default_optimizer="adafactor",
),
]

# Construct a dictionary from model configurations for easy access
MODELS_DICT = {config.name: config for config in MODEL_CONFIGS}

# Define available losses
loss_dict = {
"logconf": logconf_loss_fn(),
"product": product_loss_fn(),
"xent": xent_loss(),
}

def main():
parser = argparse.ArgumentParser(description="Train models on various datasets")
parser.add_argument('--num_steps', type=int, default=10000, help='Number of fixed training steps')
parser.add_argument('--ds_name', type=str, required=True, help='Dataset name')
parser.add_argument('--model_size', type=str, required=True, help='Model size to use for training')
args = parser.parse_args()

assert args.ds_name in VALID_DATASETS, f"Dataset {args.ds_name} is not recognized. Valid datasets are: {VALID_DATASETS}"
model_config = MODELS_DICT.get(args.model_size, None)
assert model_config is not None, f"Model size {args.model_size} is not recognized."

# Load dataset
dataset = load_dataset(args.ds_name)
train_dataset, test_dataset = dataset['train'], dataset['test']

# Tokenize datasets
tokenizer = get_tokenizer(model_config.name)
train_dataset = tokenize_dataset(train_dataset, tokenizer, max_ctx=1024)
test_dataset = tokenize_dataset(test_dataset, tokenizer, max_ctx=1024)

# Setup DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# Configure logger, replace with your logging setup
logger.configure_logging()

# Start training loop
for step in range(args.num_steps):
try:
# Assume each step processes one batch of data
batch = next(iter(train_loader))
loss = train_and_save_model(
model_config=model_config,
batch=batch,
loss_fn=loss_dict['xent'] # Example using cross-entropy loss
)
if step % 100 == 0:
print(f"Step {step}: Loss {loss}")
except StopIteration:
# If the dataset runs out of data, restart the DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

print("Training complete.")

if __name__ == "__main__":
fire.Fire(main)
12 changes: 6 additions & 6 deletions train_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def shorten_value(value) -> str:
def main(
batch_size: int = 32,
max_ctx: int = 1024,
ds_name: str = "sciq",
ds_name: str = "glue_cola",
loss: str = "xent",
n_docs: int = 20000,
n_test_docs: int = 10000,
Expand Down Expand Up @@ -193,16 +193,16 @@ def main(

# The commented out terms are the ones that should not change final results
config = {
"batch_size": batch_size,
"batch_size": batch_size, ## INTERESTED
"max_ctx": max_ctx,
"ds_name": ds_name,
"loss": loss,
"loss": loss, ## INTERESTED
"n_docs": n_docs,
"n_test_docs": n_test_docs,
"model_size": model_size,
"lr": lr,
"model_size": model_size, ## INTERESTED
"lr": lr, ## INTERESTED
"optim": optim,
"epochs": epochs,
"epochs": epochs, ## INTERESTED
# "force_retrain": force_retrain,
"seed": seed,
# "minibatch_size_per_device": minibatch_size_per_device,
Expand Down
58 changes: 57 additions & 1 deletion weak_to_strong/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def format_boolq(ex, rng):
txt = f"Passage: {ex['passage']}\nQuestion: {ex['question']}"
return dict(txt=txt, hard_label=hard_label)


register_dataset(
"boolq",
DatasetConfig(
Expand All @@ -162,9 +161,66 @@ def format_boolq(ex, rng):
)


def format_openbookQA(ex, rng):
id = ex["id"]
question_stem = ex["question_stem"]
choices_text = ex['choices']['text']
choices_labels = ex['choices']['label']
correct_label = ex['answerKey']

choices_formatted = ' '.join([f"{label}: {text}" for label, text in zip(choices_labels, choices_text)])

correct_answer_index = choices_labels.index(correct_label)
correct_answer_text = choices_text[correct_answer_index]

txt = f"Question: {question_stem}\nChoices: {choices_formatted}\nCorrect Answer: {correct_answer_text}"
return dict(txt=txt, hard_label=1) # have to change how hard label is coded.

register_dataset(
"openbookqa",
DatasetConfig(
loader=hf_loader("allenai/openbookqa", "main", split_names=dict(test="validation")), formatter=format_openbookQA
),
)
def format_paws(ex, rng):
txt = f"Sentence 1: {ex['sentence1']} Sentence 2: {ex['sentence2']}"
hard_label = int(ex['label'])
return dict(txt=txt, hard_label=hard_label)

register_dataset(
"paws_labeled_final", # Unique name for the dataset registration.
DatasetConfig(
loader=hf_loader("paws", "labeled_final", split_names=dict(test="validation")),
formatter=format_paws
),
)

def format_glue_cola(ex, rng):
return dict(txt=ex['sentence'], hard_label=ex['label'])

register_dataset(
"glue_cola",
DatasetConfig(loader=hf_loader("glue", "cola"), formatter=format_glue_cola),
)

VALID_DATASETS: list[str] = list(_REGISTRY.keys())

"""
def format_mctaco(ex, rng):
sentence = ex['sentence']
question = ex['question']
answer = ex['answer']
label = int(ex['label'] == 'yes') # Convert 'yes'/'no' to binary label (1/0)
txt = f"Context: {sentence}\nQuestion: {question}\nAnswer: {answer}"
return dict(txt=txt, hard_label=label)

register_dataset(
"mc_taco",
DatasetConfig(
loader=hf_loader("mc_taco", split_names=dict(test="validation")), formatter=format_mctaco
),
)

from datasets import disable_caching
disable_caching()

Expand Down