Skip to content

Commit 3b4d777

Browse files
committed
support yolov5s train on tpu
1 parent 597ff16 commit 3b4d777

File tree

3 files changed

+119
-31
lines changed

3 files changed

+119
-31
lines changed

models/common.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -333,11 +333,12 @@ def __init__(self, c1, c2, k=5):
333333
def forward(self, x):
334334
"""Processes input through a series of convolutions and max pooling operations for feature extraction."""
335335
x = self.cv1(x)
336-
with warnings.catch_warnings():
337-
warnings.simplefilter("ignore") # suppress torch 1.9.0 max_pool2d() warning
338-
y1 = self.m(x)
339-
y2 = self.m(y1)
340-
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
336+
# wangxuec: We need to comment this out, otherwise we'll end up with a very fragmented portion of the captured graph
337+
# with warnings.catch_warnings():
338+
# warnings.simplefilter("ignore") # suppress torch 1.9.0 max_pool2d() warning
339+
y1 = self.m(x)
340+
y2 = self.m(y1)
341+
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
341342

342343

343344
class Focus(nn.Module):

train.py

Lines changed: 103 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Datasets: https://github.com/ultralytics/yolov5/tree/master/data
1414
Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
1515
"""
16-
16+
import torch_tpu
1717
import argparse
1818
import math
1919
import os
@@ -99,14 +99,11 @@
9999
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
100100
GIT_INFO = check_git_info()
101101

102+
# from tpu_mlir import aot_backend
103+
from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend
102104
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
103-
import torch.optim as optim
104-
from compile.FxGraphConvertor import fx2mlir
105-
import torchvision.models as models
106-
import argparse
107-
import numpy as np
108-
from torch.fx import Interpreter
109-
import torch._dynamo
105+
106+
import torch._dynamo.config
110107

111108
class JitNet(nn.Module):
112109
def __init__(self, net, loss_fn):
@@ -116,9 +113,10 @@ def __init__(self, net, loss_fn):
116113

117114
def forward(self, x, y):
118115
predict = self.net(x)
119-
loss,loss_item = self.loss_fn(self.net(x), y)
116+
loss,loss_item = self.loss_fn(predict, y)
120117
return loss, loss_item.detach()
121118

119+
122120
def _get_disc_decomp():
123121
from torch._decomp import get_decompositions
124122
aten = torch.ops.aten
@@ -144,6 +142,15 @@ def _get_disc_decomp():
144142
)
145143
return decompositions_dict
146144

145+
tensor_idx = 0
146+
features_out_hook = {}
147+
def hook(module, fea_in, fea_out):
148+
global features_out_hook, tensor_idx
149+
150+
if isinstance(fea_out, torch.Tensor):
151+
features_out_hook[f'f_{tensor_idx}'] = fea_out.detach().numpy()
152+
tensor_idx += 1
153+
return None
147154

148155
def convert_module_fx(
149156
submodule_name: str,
@@ -367,7 +374,7 @@ def lf(x):
367374
image_weights=opt.image_weights,
368375
quad=opt.quad,
369376
prefix=colorstr("train: "),
370-
shuffle=True,
377+
shuffle=False,
371378
seed=opt.seed,
372379
)
373380
labels = np.concatenate(dataset.labels, 0)
@@ -422,7 +429,12 @@ def lf(x):
422429
maps = np.zeros(nc) # mAP per class
423430
results = (0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls)
424431
scheduler.last_epoch = start_epoch - 1 # do not move
425-
scaler = torch.cuda.amp.GradScaler(enabled=amp)
432+
if opt.device == 'tpu':
433+
scaler = torch_tpu.tpu.amp.GradScaler(enabled=True, allow_fp16=True)
434+
elif opt.device == 'cpu':
435+
scaler = None
436+
else:
437+
scaler = torch.cuda.amp.GradScaler(enabled=amp)
426438
stopper, stop = EarlyStopping(patience=opt.patience), False
427439
compute_loss = ComputeLoss(model) # init loss class
428440
callbacks.run("on_train_start")
@@ -432,6 +444,15 @@ def lf(x):
432444
f"Logging results to {colorstr('bold', save_dir)}\n"
433445
f"Starting training for {epochs} epochs..."
434446
)
447+
448+
hook_handles = []
449+
# dump_cuda_ref = True
450+
dump_cuda_ref = False
451+
if dump_cuda_ref:
452+
for name, child in model.named_modules():
453+
hd = child.register_forward_hook(hook=hook)
454+
hook_handles.append(hd)
455+
435456
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
436457
callbacks.run("on_train_epoch_start")
437458
model.train()
@@ -454,7 +475,14 @@ def lf(x):
454475
if RANK in {-1, 0}:
455476
pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT) # progress bar
456477
optimizer.zero_grad()
457-
zwyjit = JitNet(model, compute_loss)
478+
compiled = True
479+
joint = False
480+
if compiled:
481+
if joint:
482+
zwyjit = JitNet(model, compute_loss)
483+
model_opt = torch.compile(zwyjit, backend=aot_backend, dynamic = None, fullgraph = False)
484+
else:
485+
model_opt = torch.compile(model, backend=aot_backend, dynamic = None, fullgraph = False)
458486
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
459487
callbacks.run("on_train_batch_start")
460488
ni = i + nb * epoch # number integrated batches (since train start)
@@ -481,22 +509,54 @@ def lf(x):
481509

482510
# Forward
483511
with torch.cuda.amp.autocast(amp):
484-
# print(1)
485-
# zwy = SophonJointCompile(model, [imgs, targets], trace_joint=True, output_loss_index=0, args=None)
486-
# pred = model(imgs) # forward
487-
# loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
488-
# loss, loss_items = zwyjit(imgs, targets.to(device))
489-
fx_g, signature = aot_export_module(
490-
zwyjit, [imgs, targets], trace_joint=True, output_loss_index=0, decompositions=_get_disc_decomp()
491-
)
492-
print(fx_g)
512+
# from torch._subclasses import FakeTensorMode
513+
# fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
514+
# with fake_mode:
515+
# zwyjit = JitNet(model, compute_loss)
516+
# fx_g, signature = aot_export_module(
517+
# zwyjit, [imgs, targets], trace_joint=True, output_loss_index=0, decompositions=_get_disc_decomp()
518+
# )
519+
# print('fx_g:')
520+
# fx_g.graph.print_tabular()
521+
# print('signature:', signature)
522+
# exit(0)
523+
if dump_cuda_ref:
524+
pred = model(imgs)
525+
global features_out_hook
526+
features_out_hook['data'] = imgs.detach().numpy()
527+
for name, param in model.named_parameters():
528+
features_out_hook[name] = param.detach().numpy()
529+
np.savez('layer_outputs.npz', **features_out_hook)
530+
for hd in hook_handles:
531+
hd.remove()
532+
exit(0)
533+
else:
534+
if compiled:
535+
if joint:
536+
loss, loss_items = model_opt(imgs, targets.to(device)) # forward
537+
else:
538+
pred = model_opt(imgs)
539+
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
540+
else:
541+
pred = model(imgs)
542+
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
543+
493544
if RANK != -1:
494545
loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
495546
if opt.quad:
496547
loss *= 4.0
497548

498549
# Backward
499-
scaler.scale(loss).backward()
550+
if opt.device == 'tpu':
551+
loss = loss.to(device)
552+
#print('old loss:', loss, loss.device)
553+
#total_loss = 0.2666
554+
#loss.data.copy_(total_loss)
555+
print('loss:', loss, loss.device)
556+
if scaler is None:
557+
loss.backward()
558+
else:
559+
scaler.scale(loss).backward()
500560

501561
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
502562
if ni - last_opt_step >= accumulate:
@@ -663,7 +723,7 @@ def parse_opt(known=False):
663723
parser.add_argument("--bucket", type=str, default="", help="gsutil bucket")
664724
parser.add_argument("--cache", type=str, nargs="?", const="ram", help="image --cache ram/disk")
665725
parser.add_argument("--image-weights", action="store_true", help="use weighted image selection for training")
666-
parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
726+
parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu or tpu")
667727
parser.add_argument("--multi-scale", action="store_true", help="vary img-size +/- 50%%")
668728
parser.add_argument("--single-cls", action="store_true", help="train multi-class data as single-class")
669729
parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "AdamW"], default="SGD", help="optimizer")
@@ -690,6 +750,7 @@ def parse_opt(known=False):
690750
# NDJSON logging
691751
parser.add_argument("--ndjson-console", action="store_true", help="Log ndjson to console")
692752
parser.add_argument("--ndjson-file", action="store_true", help="Log ndjson to file")
753+
parser.add_argument("--debug_cmd", type=str, default="", help="debug_cmd")
693754

694755
return parser.parse_known_args()[0] if known else parser.parse_args()
695756

@@ -1058,7 +1119,26 @@ def run(**kwargs):
10581119
main(opt)
10591120
return opt
10601121

1122+
# import torch._dynamo
1123+
# import logging
1124+
# logger = logging.getLogger("torch._dynamo")
1125+
# logger.setLevel(logging.DEBUG)
1126+
# console_handler = logging.StreamHandler()
1127+
# console_handler.setLevel(logging.DEBUG)
1128+
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
1129+
# console_handler.setFormatter(formatter)
1130+
# logger.addHandler(console_handler)
10611131

10621132
if __name__ == "__main__":
10631133
opt = parse_opt()
1134+
#print_ori_fx_graph/dump_fx_graph/skip_tpu_compile/dump_bmodel_input
1135+
import tpu_mlir
1136+
tpu_mlir.python.tools.train.config.debug_cmd = opt.debug_cmd
1137+
tpu_mlir.python.tools.train.config.compile_opt = 2
1138+
# tpu_mlir.python.tools.train.config.only_compile_graph_id = 1
1139+
# tpu_mlir.python.tools.train.config.run_on_cmodel = False if opt.device == 'tpu' else True
1140+
tpu_mlir.python.tools.train.config.run_on_cmodel = True
1141+
tpu_mlir.python.tools.train.config.print_config_info()
1142+
# torch._dynamo.config.suppress_errors = True
1143+
10641144
main(opt)

utils/torch_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,17 @@ def select_device(device="", batch_size=0, newline=True):
116116
s = f"YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} "
117117
device = str(device).strip().lower().replace("cuda:", "").replace("none", "") # to string, 'cuda:0' to '0'
118118
cpu = device == "cpu"
119+
tpu = device == "tpu"
119120
mps = device == "mps" # Apple Metal Performance Shaders (MPS)
120-
if cpu or mps:
121+
if cpu or mps or tpu:
121122
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
122123
elif device: # non-cpu device requested
123124
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
124125
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", "")), (
125126
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
126127
)
127128

128-
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
129+
if not cpu and not mps and not tpu and torch.cuda.is_available(): # prefer GPU if available
129130
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
130131
n = len(devices) # device count
131132
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
@@ -138,6 +139,9 @@ def select_device(device="", batch_size=0, newline=True):
138139
elif mps and getattr(torch, "has_mps", False) and torch.backends.mps.is_available(): # prefer MPS if available
139140
s += "MPS\n"
140141
arg = "mps"
142+
elif tpu:
143+
s += "TPU\n"
144+
arg = "tpu"
141145
else: # revert to CPU
142146
s += "CPU\n"
143147
arg = "cpu"
@@ -457,7 +461,10 @@ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
457461
"""Initializes EMA with model parameters, decay rate, tau for decay adjustment, and update count; sets model to
458462
evaluation mode.
459463
"""
460-
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
464+
# self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
465+
self.ema = de_parallel(model) # FP32 EMA
466+
self.ema.load_state_dict(model.state_dict())
467+
self.ema = self.ema.eval()
461468
self.updates = updates # number of EMA updates
462469
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
463470
for p in self.ema.parameters():

0 commit comments

Comments
 (0)