
Commit 1083049

AkshitaB and dirkgr authored
Finetuning (#255)
* temp commit
* move_to_device should work for UserDict too
* works
* clean up
* run generation with model
* causal lm
* change label
* single step finetune
* docstrings, tests, cleanup
* fix bug with num tokens
* update changelog
* fix test
* test with different model
* simplify
* limit loss calculation to actual labels
* address comments

Co-authored-by: Dirk Groeneveld <[email protected]>
1 parent 42b1dba · commit 1083049

File tree

14 files changed: +1027 -57 lines changed

CHANGELOG.md (+2)

@@ -18,13 +18,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `StepGraph` now prints itself in a readable way.
 - Tango now automatically detects when it's running under a debugger, and disables multicore support accordingly. Many debuggers can't properly follow sub-processes, so this is a convenience for people who love debuggers.
 - Added more models to the stuff we can import from the transformers library.
+- Added a new example for finetuning text-to-text models.

 ### Changed

 - Renamed `click_logger` to `cli_logger`, and we now use [rich](https://github.com/Textualize/rich)'s logging `Handler` as the default handler, which means prettier output, better tracebacks, and you can use rich's markup syntax with the `cli_logger` to easily add style to text.
 - Refactored `tango.step_graph.StepGraph` to allow initialization from a `Dict[str, Step]`.
 - `Executor.execute_step_graph()` now attempts to execute all steps and summarizes success/failures.
 - Upgraded PyTorch version in `tango` Docker image to latest `v1.11.0+cu113`.
+- `RunGeneration` now allows a model object as input.

 ### Fixed

examples/eval_p3/config.jsonnet (+1 -1)

The `model_name` argument is renamed to `model`, matching the `RunGeneration` change above: the argument now accepts either a model name or a model object.

@@ -30,7 +30,7 @@ local dataset_steps = std.foldl(
       "max_length": 200,
       "input": {"ref": "dataset_" + dataset_name},
       "batch_size": batch_size,
-      "model_name": model,
+      "model": model,
       "prompt_field": "inputs_pretokenized",
       "output_field": "generation",
       "splits": ["validation"]

examples/finetune/__init__.py

Whitespace-only changes.

examples/finetune/config.jsonnet (new file, +133)

##################
# Model settings #
##################

local pretrained_model = "t5-base";
local load_with_low_cpu_mem_usage = false;

local modules_to_wrap = ["[a-zA-Z_.]+\\.[0-9]+"];  # TODO: works for t5 and gpt2. confirm with other models too.

####################
# Trainer settings #
####################

# Trainer settings, adjust to your use-case.
local training_steps = 20;  # total number of optimization steps to train for
local validate_every = 5;  # how often to validate and save checkpoints

local devices = 1;  # number of devices to train on (will use GPUs if enough are available, otherwise CPU)
local grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)
# This is the batch size per GPU, ignoring gradient accumulation:
local batch_size = 2;
# So the effective batch size is `batch_size * grad_accum * devices`.

local activation_checkpointing = false;  # use activation/gradient checkpointing (probably needed for GPT-J 6B, but not gpt2)
local amp = false;  # use PyTorch's native automatic mixed precision
local fsdp = false;  # use FairScale's FullyShardedDataParallel (probably needed for GPT-J 6B, but not gpt2)
local cpu_offloading = false;  # can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow

######################
# Optimizer settings #
######################

local warmup_steps = 20;
local learning_rate = 0.00005;  # you can probably use a higher LR for a small model like "gpt2"


assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp";

# FullyShardedDataParallel config:
local fsdp_config = if fsdp then {
    reshard_after_forward: true,
    move_params_to_cpu: cpu_offloading,
    move_grads_to_cpu: cpu_offloading,
    mixed_precision: amp,
} else null;

local training_engine = {
    type: if fsdp then "fairscale" else "torch",
    optimizer: {
        type: "torch::AdamW",
        lr: learning_rate,
        betas: [0.9, 0.95],
        eps: 1e-6,
    },
    lr_scheduler: {
        type: "transformers::linear",
        num_warmup_steps: warmup_steps,
        num_training_steps: training_steps,
    },
    amp: amp,
    [if fsdp then "fsdp_config" else null]: fsdp_config,
};

local distributed_dataloader = {
    batch_size: batch_size,
    sampler: {
        type: "torch::DistributedSampler",
        shuffle: true,
        drop_last: true,
    },
};

local single_device_dataloader = {
    shuffle: true,
    batch_size: batch_size,
};

local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader;

{
    steps: {
        raw_data: {
            type: "datasets::load",
            path: "snli",
        },
        /*"subset_data": {
            type: "subset-data",
            data: { type: "ref", ref: "raw_data" },
            max_samples: 10,
        },*/
        processed_data: {
            type: "snli-text2text",
            data: { type: "ref", ref: "raw_data" },
        },
        trained_model: {
            type: "transformers::finetune",
            model: {
                type: "fairscale::with_wrapped_modules",
                model: {
                    type: "transformers::finetune::from_pretrained",
                    pretrained_model_name_or_path: pretrained_model,
                    low_cpu_mem_usage: load_with_low_cpu_mem_usage,
                },
                modules_to_wrap: modules_to_wrap,  # tell FairScale to wrap the transformer's blocks individually
                fsdp_config: fsdp_config,
                activation_checkpointing: activation_checkpointing,
            },
            tokenizer: {
                pretrained_model_name_or_path: pretrained_model
            },
            dataset_dict: { type: "ref", ref: "processed_data" },
            train_dataloader: dataloader,
            validation_split: "validation",
            grad_accum: grad_accum,
            train_steps: training_steps,
            validate_every: validate_every,
            checkpoint_every: validate_every,
            log_every: 1,
            device_count: devices,
            training_engine: training_engine,
        },
        generations: {
            type: "transformers::run_generation_dataset",
            max_length: 5,
            input: {"type": "ref", "ref": "processed_data"},
            batch_size: batch_size,
            model: {"type": "ref", "ref": "trained_model"},
            prompt_field: "source",
            output_field: "generation",
            splits: ["validation"]
        }
    }
}
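With the defaults above, the effective batch size works out to batch_size * grad_accum * devices = 2 * 1 * 1 = 2. For reference, here is a minimal sketch of running this config programmatically, mirroring what examples/finetune/test.py below does with the `run_experiment` test helper; it assumes you run from examples/finetune/ so that config.jsonnet and snli_steps.py (which registers the custom steps) resolve.

    from tango.common import Params
    from tango.common.testing import run_experiment

    config = Params.from_file("config.jsonnet")

    # run_experiment executes the step graph in a temporary workspace and
    # yields the run directory, with one subdirectory per cached step result.
    with run_experiment(config, include_package=["snli_steps.py"]) as run_dir:
        print((run_dir / "trained_model").is_dir())  # True once training finishes

Outside of tests, the `tango run` CLI is the usual entry point; the helper above is just the in-process equivalent used by the test suite.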

examples/finetune/snli_steps.py (new file, +125)

from typing import Union

import datasets as ds

from tango.integrations.datasets import DatasetsFormat
from tango.step import Step


@Step.register("subset-data")
class SubsetData(Step):
    """
    Creates a subset of the data; mostly to be used for testing/debugging.
    """

    DETERMINISTIC = True
    CACHEABLE = True
    VERSION = "001"

    FORMAT = DatasetsFormat()

    def run(  # type: ignore
        self,
        data: Union[ds.DatasetDict, ds.Dataset],
        max_samples: int = 5,
    ) -> Union[ds.DatasetDict, ds.Dataset]:
        """
        Returns a copy of the `data` with the number of samples limited to
        `max_samples` for each split.

        :param data:
            The dataset or dataset dict object.
        :param max_samples:
            The maximum number of samples to return per split.
        """

        # Unlike `ds.Dataset.select`, this works on both `ds.Dataset` and `ds.DatasetDict`.
        def filter_fn(example, indices):
            return indices < max_samples

        return data.filter(filter_fn, with_indices=True)


@Step.register("snli-text2text")
class SnliText2Text(Step):
    """
    Converts the snli dataset to a text-to-text format.

    Examples
    --------

    original_instance = {
        "premise": "Two cats are sitting on a wall.",
        "hypothesis": "The cats are chasing a mouse.",
        "label": 2,  # contradiction
    }

    returned_instance = {
        "source": "nli premise: Two cats are sitting on a wall. hypothesis: The cats are chasing a mouse. label: ",
        "target": "contradiction",
    }
    """

    DETERMINISTIC = True
    CACHEABLE = True
    VERSION = "001"

    FORMAT = DatasetsFormat()

    def run(  # type: ignore
        self,
        data: Union[ds.DatasetDict, ds.Dataset],
        source_prefix: str = "nli",
        premise_prefix: str = "premise",
        hypothesis_prefix: str = "hypothesis",
        label_prefix: str = "label",
        num_workers: int = 1,
    ) -> Union[ds.DatasetDict, ds.Dataset]:
        """
        :param data:
            The snli `Dataset` or `DatasetDict` object.
        :param source_prefix:
            The string to add before the start of the source sequence.
        :param premise_prefix:
            The string to add before the `premise` in the source sequence.
        :param hypothesis_prefix:
            The string to add before the `hypothesis` in the source sequence.
        :param label_prefix:
            The string to add as the prompt for the label.
        :param num_workers:
            The number of workers to use for processing the data.
        """

        # Instances without a gold label have label == -1; drop them.
        def filter_no_gold(example, indices):
            return example["label"] != -1

        data = data.filter(filter_no_gold, with_indices=True)

        label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

        def _mapper(example):
            return {
                "source": (
                    f'{source_prefix} {premise_prefix}: {example["premise"]} '
                    f'{hypothesis_prefix}: {example["hypothesis"]} {label_prefix}: '
                ),
                "target": f'{label_map[example["label"]]}',
            }

        if isinstance(data, ds.Dataset):
            old_cols = data.column_names
        else:
            old_cols = list(data.column_names.values())[0]

        dataset = data.map(
            _mapper,
            batched=False,
            num_proc=num_workers,
            remove_columns=old_cols,  # remove all old columns
            desc="Converting data to text-to-text format",
        )

        return dataset
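To see what `snli-text2text` produces without running the whole pipeline, here is a small self-contained sketch that applies the same mapping logic to one hand-written example (plain Python, no Tango machinery; the field names follow the snli schema):

    label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

    example = {
        "premise": "Two cats are sitting on a wall.",
        "hypothesis": "The cats are chasing a mouse.",
        "label": 2,
    }

    # Mirrors SnliText2Text._mapper with the default prefixes.
    source = (
        f'nli premise: {example["premise"]} '
        f'hypothesis: {example["hypothesis"]} label: '
    )
    target = label_map[example["label"]]

    print(source)  # nli premise: Two cats are sitting on a wall. hypothesis: ... label:
    print(target)  # contradiction

Note that the source string ends with the label prompt and a trailing space, so the model only has to generate the label word itself.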

examples/finetune/test.py (new file, +39)

import datasets as ds
import pytest

from tango.common import Params
from tango.common.testing import TangoTestCase, run_experiment


class TestFinetuneSNLI(TangoTestCase):
    @pytest.mark.parametrize(
        "model, model_type",
        [("patrickvonplaten/t5-tiny-random", "t5"), ("sshleifer/tiny-gpt2", "gpt2")],
    )
    def test_config(self, model: str, model_type: str):
        overrides = {
            "steps.trained_model.model.model.pretrained_model_name_or_path": model,
            "steps.trained_model.tokenizer.pretrained_model_name_or_path": model,
            "steps.subset_data": {
                "type": "subset-data",
                "data": {"type": "ref", "ref": "raw_data"},
                "max_samples": 10,
            },
            "steps.processed_data.data.ref": "subset_data",
        }
        config = Params.from_file("config.jsonnet", params_overrides=overrides)
        # Make sure we've overridden the model entirely.
        flattened = config.as_flat_dict()
        for key, value in flattened.items():
            if "model_name" in key or (isinstance(value, str) and model_type in value):
                assert value == model

        with run_experiment(config, include_package=["snli_steps.py"]) as run_dir:
            assert (run_dir / "processed_data").is_dir()
            processed = ds.load_from_disk(run_dir / "processed_data" / "data")
            assert len(processed["train"][0].keys()) == 2
            assert "source" in processed["train"][0].keys()
            assert "target" in processed["train"][0].keys()
            assert processed["train"][0]["source"].startswith("nli premise:")

            assert (run_dir / "trained_model").is_dir()
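The overrides work through `Params`' dotted-path syntax: each key addresses a nested value in the parsed Jsonnet, so the test can swap in a tiny model and splice the (commented-out) `subset_data` step into the graph without touching the config file. A minimal sketch of that mechanism, using only the `Params` calls that appear in the test and assuming it runs next to config.jsonnet; whether `as_flat_dict()` keys are exactly dot-joined like this is an assumption, though the test's `"model_name" in key` check suggests it:

    from tango.common import Params

    overrides = {
        # Replace the t5-base model with a tiny one.
        "steps.trained_model.model.model.pretrained_model_name_or_path": "sshleifer/tiny-gpt2",
        # Add the subset_data step and repoint processed_data at it.
        "steps.subset_data": {
            "type": "subset-data",
            "data": {"type": "ref", "ref": "raw_data"},
            "max_samples": 10,
        },
        "steps.processed_data.data.ref": "subset_data",
    }
    config = Params.from_file("config.jsonnet", params_overrides=overrides)

    # as_flat_dict() exposes the same dotted view of the nested config.
    flat = config.as_flat_dict()
    print(flat["steps.subset_data.max_samples"])  # 10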

tango/common/lazy.py (+2 -2)

@@ -83,5 +83,5 @@ def construct(self, **kwargs) -> T:
         """
         # If there are duplicate keys between self._constructor_extras and kwargs,
         # this will overwrite the ones in self._constructor_extras with what's in kwargs.
-        contructor_kwargs = {**self._constructor_extras, **kwargs}
-        return self.constructor(**contructor_kwargs)
+        constructor_kwargs = {**self._constructor_extras, **kwargs}
+        return self.constructor(**constructor_kwargs)
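Beyond the typo fix, the merge order here is what the comment promises: keys in `kwargs` win over `self._constructor_extras`. A tiny standalone sketch of that dict-merge semantics (plain Python, no Tango imports; the names are illustrative):

    constructor_extras = {"lr": 1e-4, "eps": 1e-6}
    kwargs = {"lr": 5e-5}

    # Later entries win in a dict merge, so explicit call-site kwargs
    # override the stored extras.
    constructor_kwargs = {**constructor_extras, **kwargs}
    assert constructor_kwargs == {"lr": 5e-5, "eps": 1e-6}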

tango/integrations/torch/training_engine.py (+3 -15)

The private `_move_to_device` helper is removed in favor of the shared `move_to_device` from `tango.integrations.torch.util`, which is updated below to handle `UserDict` inputs as well.

@@ -13,6 +13,7 @@
 from .model import Model
 from .optim import LRScheduler, Optimizer
 from .train_config import TrainConfig
+from .util import move_to_device


 class TrainingEngine(Registrable):

@@ -60,19 +61,6 @@ def _construct_lr_scheduler(self, lr_scheduler: Lazy[LRScheduler]) -> LRScheduler:
         lr_scheduler: LRScheduler = lr_scheduler.construct(optimizer=self.optimizer)
         return lr_scheduler

-    @classmethod
-    def _move_to_device(cls, o: Any, device: torch.device) -> Any:
-        if isinstance(o, torch.Tensor):
-            return o.to(device)
-        elif isinstance(o, dict):
-            return {k: cls._move_to_device(v, device) for k, v in o.items()}
-        elif isinstance(o, list):
-            return [cls._move_to_device(x, device) for x in o]
-        elif isinstance(o, tuple):
-            return tuple((cls._move_to_device(x, device) for x in o))
-        else:
-            return o
-
     @abstractmethod
     def forward_train(
         self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int

@@ -207,7 +195,7 @@ def forward_train(
         self.optimizer.zero_grad(set_to_none=True)

         # Move tensors to right device.
-        micro_batch = self._move_to_device(micro_batch, self.device)
+        micro_batch = move_to_device(micro_batch, self.device)

         with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype):
             outputs = self.model(**micro_batch)

@@ -217,7 +205,7 @@
     def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         # Move tensors to right device.
-        batch = self._move_to_device(batch, self.device)
+        batch = move_to_device(batch, self.device)

         with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype):
             with torch.inference_mode():

tango/integrations/torch/util.py (+2 -1)

@@ -1,5 +1,6 @@
 import random
 import warnings
+from collections import UserDict
 from typing import Dict, Optional, TypeVar, Union

 import numpy as np

@@ -15,7 +16,7 @@
 def move_to_device(o: T, device: torch.device) -> T:
     if isinstance(o, torch.Tensor):
         return o.to(device)  # type: ignore[return-value]
-    elif isinstance(o, dict):
+    elif isinstance(o, dict) or isinstance(o, UserDict):
         return {k: move_to_device(v, device) for k, v in o.items()}  # type: ignore[return-value]
     elif isinstance(o, list):
         return [move_to_device(x, device) for x in o]  # type: ignore[return-value]
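This matters because Hugging Face tokenizers return a `BatchEncoding`, which subclasses `collections.UserDict` and so previously fell through to the catch-all branch, leaving its tensors wherever they already were. A minimal sketch of the new behavior; note that, as written, a `UserDict` input comes back as a plain `dict`, since the dict branch rebuilds one:

    from collections import UserDict

    import torch

    from tango.integrations.torch.util import move_to_device

    # Stand-in for a tokenizer's BatchEncoding, which is a UserDict subclass.
    batch = UserDict({"input_ids": torch.tensor([[101, 2023, 102]])})

    moved = move_to_device(batch, torch.device("cpu"))
    print(type(moved))                # <class 'dict'>
    print(moved["input_ids"].device)  # cpu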
