Finetuning #255
Merged
Commits (21):
d72d4ac  temp commit (AkshitaB)
3f08572  move_to_device should work for UserDict too (AkshitaB)
9a07a1b  works (AkshitaB)
f7fc850  clean up (AkshitaB)
77ade5f  run generation with model (AkshitaB)
71c485c  causal lm (AkshitaB)
747893f  Merge branch 'main' into finetuning (AkshitaB)
ded17f4  change label (AkshitaB)
613a744  Merge branch 'main' into finetuning (AkshitaB)
87ff48e  single step finetune (AkshitaB)
ff1e6a3  docstrings, tests, cleanup (AkshitaB)
5bb6a72  Merge branch 'main' into finetuning (AkshitaB)
bfa8b24  fix bug with num tokens (AkshitaB)
dbfe36e  update changelog (AkshitaB)
0fe64a8  fix test (AkshitaB)
b04c8ff  test with different model (AkshitaB)
2ace306  simplify (AkshitaB)
da291db  limit loss calculation to actual labels (AkshitaB)
c247c48  address comments (AkshitaB)
92d84c7  Merge branch 'main' into finetuning (AkshitaB)
d9033b7  Merge branch 'main' into finetuning (dirkgr)
config.jsonnet (new file, +133 lines):
##################
# Model settings #
##################

local pretrained_model = "t5-base";
local load_with_low_cpu_mem_usage = false;

local modules_to_wrap = ["[a-zA-Z_.]+\\.[0-9]+"];  # TODO: works for t5 and gpt2; confirm with other models too.

####################
# Trainer settings #
####################

# Trainer settings, adjust to your use-case.
local training_steps = 20;  # total number of optimization steps to train for
local validate_every = 5;  # how often to validate and save checkpoints

local devices = 1;  # number of devices to train on (will use GPUs if enough are available, otherwise CPU)
local grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)
# This is the batch size per GPU, ignoring gradient accumulation:
local batch_size = 2;
# So the effective batch size is `batch_size * grad_accum * devices`.

local activation_checkpointing = false;  # use activation/gradient checkpointing (probably needed for GPT-J 6B, but not for gpt2)
local amp = false;  # use PyTorch's native automatic mixed precision
local fsdp = false;  # use FairScale's FullyShardedDataParallel (probably needed for GPT-J 6B, but not for gpt2)
local cpu_offloading = false;  # can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow

######################
# Optimizer settings #
######################

local warmup_steps = 20;
local learning_rate = 0.00005;  # you can probably use a higher LR for a small model like "gpt2"


assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp";

# FullyShardedDataParallel config:
local fsdp_config = if fsdp then {
    reshard_after_forward: true,
    move_params_to_cpu: cpu_offloading,
    move_grads_to_cpu: cpu_offloading,
    mixed_precision: amp,
} else null;

local training_engine = {
    type: if fsdp then "fairscale" else "torch",
    optimizer: {
        type: "torch::AdamW",
        lr: learning_rate,
        betas: [0.9, 0.95],
        eps: 1e-6,
    },
    lr_scheduler: {
        type: "transformers::linear",
        num_warmup_steps: warmup_steps,
        num_training_steps: training_steps,
    },
    amp: amp,
    [if fsdp then "fsdp_config" else null]: fsdp_config,
};

local distributed_dataloader = {
    batch_size: batch_size,
    sampler: {
        type: "torch::DistributedSampler",
        shuffle: true,
        drop_last: true,
    },
};

local single_device_dataloader = {
    shuffle: true,
    batch_size: batch_size,
};

local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader;

{
    steps: {
        raw_data: {
            type: "datasets::load",
            path: "snli",
        },
        /*"subset_data": {
            type: "subset-data",
            data: { type: "ref", ref: "raw_data" },
            max_samples: 10,
        },*/
        processed_data: {
            type: "snli-text2text",
            data: { type: "ref", ref: "raw_data" },
        },
        trained_model: {
            type: "transformers::finetune",
            model: {
                type: "fairscale::with_wrapped_modules",
                model: {
                    type: "transformers::finetune::from_pretrained",
                    pretrained_model_name_or_path: pretrained_model,
                    low_cpu_mem_usage: load_with_low_cpu_mem_usage,
                },
                modules_to_wrap: modules_to_wrap,  # tell FairScale to wrap the transformer's blocks individually
                fsdp_config: fsdp_config,
                activation_checkpointing: activation_checkpointing,
            },
            tokenizer: {
                pretrained_model_name_or_path: pretrained_model
            },
            dataset_dict: { type: "ref", ref: "processed_data" },
            train_dataloader: dataloader,
            validation_split: "validation",
            grad_accum: grad_accum,
            train_steps: training_steps,
            validate_every: validate_every,
            checkpoint_every: validate_every,
            log_every: 1,
            device_count: devices,
            training_engine: training_engine,
        },
        generations: {
            type: "transformers::run_generation_dataset",
            max_length: 5,
            input: {"type": "ref", "ref": "processed_data"},
            batch_size: batch_size,
            model: {"type": "ref", "ref": "trained_model"},
            prompt_field: "source",
            output_field: "generation",
            splits: ["validation"]
        }
    }
}
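As a quick usage sketch (not part of this diff), the config above can be run programmatically the same way the test at the bottom of this PR does, via Tango's run_experiment helper; in normal use you would launch it with the tango run CLI instead. The tiny model override below is an illustrative choice, not something the PR prescribes.

# Sketch: run the example pipeline end to end on a tiny model,
# mirroring the pattern of the test further down in this PR.
# Assumes the working directory contains config.jsonnet and snli_steps.py.
from tango.common import Params
from tango.common.testing import run_experiment

tiny_model = "sshleifer/tiny-gpt2"  # illustrative override to keep the run small
overrides = {
    "steps.trained_model.model.model.pretrained_model_name_or_path": tiny_model,
    "steps.trained_model.tokenizer.pretrained_model_name_or_path": tiny_model,
}
config = Params.from_file("config.jsonnet", params_overrides=overrides)

with run_experiment(config, include_package=["snli_steps.py"]) as run_dir:
    # Each step's result is written to its own subdirectory of the run.
    assert (run_dir / "trained_model").is_dir()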
snli_steps.py (new file, +125 lines):
from typing import Union

import datasets as ds

from tango.integrations.datasets import DatasetsFormat
from tango.step import Step


@Step.register("subset-data")
class SubsetData(Step):
    """
    Creates a subset of the data; mostly to be used for testing/debugging.
    """

    DETERMINISTIC = True
    CACHEABLE = True
    VERSION = "001"

    FORMAT = DatasetsFormat()

    def run(  # type: ignore
        self,
        data: Union[ds.DatasetDict, ds.Dataset],
        max_samples: int = 5,
    ) -> Union[ds.DatasetDict, ds.Dataset]:
        """
        Returns a copy of `data` with the number of samples limited to `max_samples`
        for each split.

        :param data:
            The dataset or dataset dict object.
        :param max_samples:
            The maximum number of samples to return per split.
        """

        # Unlike `ds.Dataset.select`, this works on both `ds.Dataset` and `ds.DatasetDict`.
        def filter_fn(example, indices):
            return indices < max_samples

        return data.filter(filter_fn, with_indices=True)


@Step.register("snli-text2text")
class SnliText2Text(Step):
    """
    Converts the snli dataset to a text-to-text format.

    Examples
    --------

    original_instance = {
        "premise": "Two cats are sitting on a wall.",
        "hypothesis": "The cats are chasing a mouse.",
        "label": 2  # contradiction
    }

    returned_instance = {
        "source": "nli premise: Two cats are sitting on a wall. hypothesis: The cats are chasing a mouse. label: ",
        "target": "contradiction"
    }

    """

    DETERMINISTIC = True
    CACHEABLE = True
    VERSION = "001"

    FORMAT = DatasetsFormat()

    def run(  # type: ignore
        self,
        data: Union[ds.DatasetDict, ds.Dataset],
        source_prefix: str = "nli",
        premise_prefix: str = "premise",
        hypothesis_prefix: str = "hypothesis",
        label_prefix: str = "label",
        num_workers: int = 1,
    ) -> Union[ds.DatasetDict, ds.Dataset]:
        """
        :param data:
            The snli `Dataset` or `DatasetDict` object.
        :param source_prefix:
            The str to add before the start of the source sequence.
        :param premise_prefix:
            The str to add before the start of the `premise` in the source sequence.
        :param hypothesis_prefix:
            The str to add before the start of the `hypothesis` in the source sequence.
        :param label_prefix:
            The str to add as the prompt for the label.
        :param num_workers:
            The number of workers to use for processing the data.
        """

        # Drop examples without a gold label (snli marks these with label == -1).
        def filter_no_gold(example, indices):
            if example["label"] == -1:
                return False
            return True

        data = data.filter(filter_no_gold, with_indices=True)

        label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

        def _mapper(example):
            return {
                "source": (
                    f'{source_prefix} {premise_prefix}: {example["premise"]} '
                    f'{hypothesis_prefix}: {example["hypothesis"]} {label_prefix}: '
                ),
                "target": f'{label_map[example["label"]]}',
            }

        if isinstance(data, ds.Dataset):
            old_cols = data.column_names
        else:
            old_cols = list(data.column_names.values())[0]

        dataset = data.map(
            _mapper,
            batched=False,
            num_proc=num_workers,
            remove_columns=old_cols,  # remove all old columns
            desc="Converting data to text-to-text format",
        )

        return dataset
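For intuition about what the snli-text2text step produces, here is a small sanity-check sketch (not part of the diff). It assumes a Step subclass can be instantiated with no arguments and its run() method called directly outside of a Tango workflow, and that the module is importable as snli_steps (the name used by the test below).

# Hypothetical quick check of SnliText2Text on a one-row in-memory dataset.
import datasets as ds

from snli_steps import SnliText2Text  # module name assumed from the test below

toy = ds.Dataset.from_dict(
    {
        "premise": ["Two cats are sitting on a wall."],
        "hypothesis": ["The cats are chasing a mouse."],
        "label": [2],  # contradiction
    }
)

converted = SnliText2Text().run(toy)
print(converted[0]["source"])  # -> "nli premise: ... hypothesis: ... label: "
print(converted[0]["target"])  # -> "contradiction"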
Test for the finetuning example (new file, +39 lines):
import datasets as ds
import pytest

from tango.common import Params
from tango.common.testing import TangoTestCase, run_experiment


class TestFinetuneSNLI(TangoTestCase):
    @pytest.mark.parametrize(
        "model, model_type",
        [("patrickvonplaten/t5-tiny-random", "t5"), ("sshleifer/tiny-gpt2", "gpt2")],
    )
    def test_config(self, model: str, model_type: str):
        overrides = {
            "steps.trained_model.model.model.pretrained_model_name_or_path": model,
            "steps.trained_model.tokenizer.pretrained_model_name_or_path": model,
            "steps.subset_data": {
                "type": "subset-data",
                "data": {"type": "ref", "ref": "raw_data"},
                "max_samples": 10,
            },
            "steps.processed_data.data.ref": "subset_data",
        }
        config = Params.from_file("config.jsonnet", params_overrides=overrides)
        # Make sure we've overridden the model entirely.
        flattened = config.as_flat_dict()
        for key, value in flattened.items():
            if "model_name" in key or (isinstance(value, str) and model_type in value):
                assert value == model

        with run_experiment(config, include_package=["snli_steps.py"]) as run_dir:
            assert (run_dir / "processed_data").is_dir()
            processed = ds.load_from_disk(run_dir / "processed_data" / "data")
            assert len(processed["train"][0].keys()) == 2
            assert "source" in processed["train"][0].keys()
            assert "target" in processed["train"][0].keys()
            assert processed["train"][0]["source"].startswith("nli premise:")

            assert (run_dir / "trained_model").is_dir()
Review comment:
We have the DatasetRemix step for Tango's DatasetDict. Can we have the same for HF's datasets?

Reply:
Will work on it separately: #268. This is technically unrelated to finetuning.
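For context only, a rough sketch of what a DatasetRemix-style step for HF datasets could look like; this is purely hypothetical (the registered name, class, and interface below are illustrative), and the actual work is tracked in #268.

# Hypothetical sketch of a DatasetRemix-like step for HF datasets.
# Not part of this PR; the real implementation is deferred to #268.
from typing import Dict, List, Tuple

import datasets as ds

from tango.integrations.datasets import DatasetsFormat
from tango.step import Step


@Step.register("hf-dataset-remix")  # hypothetical registered name
class HFDatasetRemix(Step):
    """
    Builds new splits out of slices of existing splits in a `ds.DatasetDict`.
    """

    DETERMINISTIC = True
    CACHEABLE = True
    FORMAT = DatasetsFormat()

    def run(  # type: ignore
        self,
        data: ds.DatasetDict,
        new_splits: Dict[str, List[Tuple[str, int, int]]],
    ) -> ds.DatasetDict:
        """
        `new_splits` maps a new split name to a list of (source_split, start, end)
        slices that are concatenated together, e.g.
        {"train": [("train", 0, 9000)], "dev": [("train", 9000, 10000)]}.
        """
        remixed = {}
        for name, pieces in new_splits.items():
            parts = [data[split].select(range(start, end)) for split, start, end in pieces]
            remixed[name] = ds.concatenate_datasets(parts)
        return ds.DatasetDict(remixed)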