
Commit 1bef748

[doc][c10d] fixup fsdp tutorial (pytorch#1297)
Summary: Fix up the FSDP tutorial to get it functional again.

1. Add the missing import for `load_dataset`.
2. Use `checkpoint` instead of `_shard.checkpoint` to get rid of a warning.
3. Add `nlp` to requirements.txt.
4. Get rid of `load_metric`, as this function does not exist in the new `datasets` module.
5. Add `legacy=False` to get rid of tokenizer warnings.

Test Plan: Ran the tutorial as follows and ensured that it ran successfully:

```
torchrun --nnodes=1 --nproc_per_node=2 T5_training.py
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793]
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] *****************************************
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] *****************************************
dict_keys(['train', 'validation', 'test'])
Size of train dataset: (157252, 3)
Size of Validation dataset: (5599, 3)
dict_keys(['train', 'validation', 'test'])
Size of train dataset: (157252, 3)
Size of Validation dataset: (5599, 3)
bFloat16 enabled for mixed precision - using bfSixteen policy
```
1 parent 47d0c2e commit 1bef748

File tree: 5 files changed (+42 -41 lines)


distributed/FSDP/T5_training.py (+12 -11)

```diff
@@ -14,6 +14,7 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from transformers.models.t5.modeling_t5 import T5Block
+from nlp import load_dataset

 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -86,11 +87,11 @@ def fsdp_main(args):
     print("Size of train dataset: ", dataset['train'].shape)
     print("Size of Validation dataset: ", dataset['validation'].shape)

-
+
     #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
-    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
+    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
     val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
-
+
     sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
     sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

@@ -107,20 +108,20 @@ def fsdp_main(args):

     train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
     val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
-
+
     torch.cuda.set_device(local_rank)
-
+
     # Set up FSDP parameters
     mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank)
-
+
     # Apply FSDP wrapping to the model
     model = FSDP(model,
         auto_wrap_policy=t5_auto_wrap_policy,
         mixed_precision=mixed_precision_policy,
         sharding_strategy=fsdp_config.sharding_strategy,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=fsdp_config.limit_all_gathers)
-
+
     # Enabling this causes https://github.com/pytorch/examples/issues/1210
     if fsdp_config.fsdp_activation_checkpointing:
         policies.apply_fsdp_checkpointing(model)
@@ -150,7 +151,7 @@ def fsdp_main(args):
         if args.run_validation:
             curr_val_loss = validation(model, rank, world_size, val_loader)
         scheduler.step()
-
+
         if rank == 0:

             print(f"--> epoch {epoch} completed...entering save and stats zone")
@@ -170,7 +171,7 @@ def fsdp_main(args):
             )

         if train_config.save_model and curr_val_loss < best_val_loss:
-
+
             if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                 model_checkpointing.save_model_checkpoint(
                     model, optimizer, rank, fsdp_config, epoch=1
@@ -183,7 +184,7 @@ def fsdp_main(args):
             if fsdp_config.save_optimizer:
                 model_checkpointing.save_optimizer_checkpoint(
                     model, optimizer, rank, fsdp_config, epoch=1
-                )
+                )
         if curr_val_loss < best_val_loss:

             best_val_loss = curr_val_loss
@@ -212,5 +213,5 @@ def fsdp_main(args):
     args = parser.parse_args()

     torch.manual_seed(args.seed)
-
+
     fsdp_main(args)
```
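For reference, the hunks above leave the FSDP wrap itself unchanged, but it is the core of this file. A minimal sketch of the same wrapping pattern (auto-wrap on `T5Block` plus a bfloat16 `MixedPrecision` policy, the "bfSixteen" policy the test log reports) is shown below. The helper name `wrap_t5` and the concrete policy values are illustrative, not read from the tutorial's `train_config`/`fsdp_config` objects, and an initialized process group plus a CUDA device are assumed.

```python
# Minimal sketch of the FSDP wrapping used in T5_training.py above.
# Assumptions: torch.distributed is already initialized, a CUDA device is set,
# and the policy values below are illustrative rather than read from fsdp_config.
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block


def wrap_t5(model):
    # Shard at T5Block granularity, i.e. one FSDP unit per transformer block.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={T5Block},
    )
    # bfloat16 for params, gradient reduction, and buffers ("bfSixteen" in the log).
    bf16_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=bf16_policy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
    )
```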

distributed/FSDP/model_checkpointing/checkpoint_handler.py (+8 -8)

```diff
@@ -11,7 +11,7 @@
 # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
 )

-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
     save_state_dict,
@@ -24,7 +24,7 @@


 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-import torch.distributed._shard.checkpoint as dist_cp
+import torch.distributed.checkpoint as dist_cp
 import torch.distributed as dist


@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
-
+
         dist_cp.load_state_dict(
             state_dict=checkpoint,
             storage_reader=reader,
@@ -108,7 +108,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
             state_dict=state_dict,
             storage_writer=distributed_writer,
             planner=DefaultSavePlanner(),
-
+
         )
         dist.barrier()
         t1 = time.perf_counter()
@@ -117,7 +117,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
         print(
             f"Checkpoint Time = {t1-t0:.4f}\n using {cfg.save_using_num_threads=} total threads"
         )
-
+
 def save_model_checkpoint(
     model,
     optimizer,
@@ -138,7 +138,7 @@ def save_model_checkpoint(

     if cfg.verbose:
         print(f"saving process: rank {rank} done w model state_dict\n")
-
+

     if rank == 0:
         print(f"--> saving model ...")
@@ -153,7 +153,7 @@ def save_model_checkpoint(

     if cfg.verbose:
         print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
-
+


 def load_model_checkpoint(model, rank, cfg, verbose=True):
@@ -299,7 +299,7 @@ def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
         StateDictType.LOCAL_STATE_DICT,
     ):
         state_dict = model.state_dict()
-
+

         # write out distributed checkpoint
         save_state_dict(state_dict, writer)
```
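The substantive change in this file is the move off the deprecated `torch.distributed._shard.checkpoint` namespace onto the public `torch.distributed.checkpoint` one; the remaining hunks are whitespace only. A minimal sketch of a sharded save/load round trip against the new namespace, assuming an already-initialized process group, an FSDP-wrapped `model`, and an illustrative checkpoint directory, could look like this:

```python
# Sketch of a sharded save/load round trip with the public
# torch.distributed.checkpoint namespace used above.
# Assumptions: dist.init_process_group() has run, `model` is FSDP-wrapped,
# and CKPT_DIR is an illustrative path, not the tutorial's configured folder.
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

CKPT_DIR = "checkpoints/t5-sharded"


def save_sharded(model):
    # Each rank writes only its own shards to CKPT_DIR.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=FileSystemWriter(CKPT_DIR),
        )


def load_sharded(model):
    # Read the shards back and load them into the wrapped model.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=FileSystemReader(CKPT_DIR),
        )
        model.load_state_dict(state_dict["model"])
```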

distributed/FSDP/requirements.txt (+1)

```diff
@@ -3,3 +3,4 @@ datasets
 tqdm
 protobuf
 SentencePiece
+nlp
```

distributed/FSDP/summarization_dataset.py (+19 -20)

```diff
@@ -14,8 +14,7 @@
 import torch
 from torch.utils.data import Dataset, DataLoader

-from datasets import load_dataset, load_metric
-
+from nlp import load_dataset

 from transformers import (
     AdamW,
@@ -25,59 +24,59 @@
 )

 class wikihow(Dataset):
-    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
+    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
         self.dataset = load_dataset('wikihow', 'all', data_dir='data/', split=type_path)
         if num_samples:
             self.dataset = self.dataset.select(list(range(0, num_samples)))
         self.input_length = input_length
         self.tokenizer = tokenizer
         self.output_length = output_length
         self.print_text = print_text
-
+
     def __len__(self):
         return self.dataset.shape[0]
-
+
     def clean_text(self, text):
         text = text.replace('Example of text:', '')
         text = text.replace('Example of Summary:', '')
         text = text.replace('\n','')
         text = text.replace('``', '')
         text = text.replace('"', '')
-
+
         return text
-
-
+
+
     def convert_to_features(self, example_batch):
         # Tokenize contexts and questions (as pairs of inputs)
-
+
         if self.print_text:
             print("Input Text: ", self.clean_text(example_batch['text']))
 #         input_ = self.clean_text(example_batch['text']) + " </s>"
 #         target_ = self.clean_text(example_batch['headline']) + " </s>"
-
+
         input_ = self.clean_text(example_batch['text'])
         target_ = self.clean_text(example_batch['headline'])
-
-        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,
+
+        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,
                                                   padding='max_length', truncation=True, return_tensors="pt")
-
-        targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,
+
+        targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,
                                                    padding='max_length', truncation=True, return_tensors="pt")
-
-
+
+
         return source, targets
-
+
     def __getitem__(self, index):
         source, targets = self.convert_to_features(self.dataset[index])
-
+
         source_ids = source["input_ids"].squeeze()
         target_ids = targets["input_ids"].squeeze()

         src_mask = source["attention_mask"].squeeze()
         target_mask = targets["attention_mask"].squeeze()

         return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}
-
+
 def get_dataset(tokenizer, type_path, num_samples, args):
-    return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples, input_length=max_input_length,
+    return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples, input_length=max_input_length,
                    output_length=max_output_length)
```
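With `load_dataset` now coming from `nlp`, the `wikihow` Dataset above is still driven from `T5_training.py` with the same positional arguments. A hypothetical usage sketch, assuming the wikihow CSV files have already been downloaded into `data/` as the loader expects, and using `t5-base` purely as an illustrative tokenizer choice:

```python
# Hypothetical usage of the wikihow Dataset defined above.
# Assumptions: the wikihow CSV files are already present under data/ (the nlp
# loader expects a manual download), and "t5-base" is an illustrative tokenizer.
from torch.utils.data import DataLoader
from transformers import T5Tokenizer

from summarization_dataset import wikihow

tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)

# Same positional arguments T5_training.py passes for the training split:
# 1500 samples, 512-token inputs, 150-token targets, no debug printing.
train_dataset = wikihow(tokenizer, "train", 1500, 512, 150, False)
loader = DataLoader(train_dataset, batch_size=4)

batch = next(iter(loader))
print(batch["source_ids"].shape)  # torch.Size([4, 512])
print(batch["target_ids"].shape)  # torch.Size([4, 150])
```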

distributed/FSDP/utils/train_utils.py (+2 -2)

```diff
@@ -36,7 +36,7 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler
     model.train()
     local_rank = int(os.environ['LOCAL_RANK'])
     fsdp_loss = torch.zeros(2).to(local_rank)
-
+
     if sampler:
         sampler.set_epoch(epoch)
     if rank==0:
@@ -98,5 +98,5 @@ def validation(model, rank, world_size, val_loader):

 def setup_model(model_name):
     model = T5ForConditionalGeneration.from_pretrained(model_name)
-    tokenizer = T5Tokenizer.from_pretrained(model_name)
+    tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False)
     return model, tokenizer
```
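Passing `legacy=False` opts `T5Tokenizer` into the updated SentencePiece handling and silences the legacy-behaviour warning mentioned in the summary. A hypothetical smoke test for the changed `setup_model()` path, with `t5-small` as an illustrative model name:

```python
# Hypothetical smoke test for the tokenizer change in setup_model().
# "t5-small" is an illustrative choice, not the tutorial's configured model.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)  # no legacy warning

inputs = tokenizer(
    "summarize: FSDP shards model parameters, gradients, and optimizer state.",
    return_tensors="pt",
)
with torch.no_grad():
    ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(ids[0], skip_special_tokens=True))
```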
