 Datasets: https://github.com/ultralytics/yolov5/tree/master/data
 Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
 """
-
+import torch_tpu
 import argparse
 import math
 import os
@@ -99,14 +99,11 @@
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
 GIT_INFO = check_git_info()

+# from tpu_mlir import aot_backend
+from tpu_mlir.python.tools.train.tpu_mlir_jit import aot_backend
 from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
-import torch.optim as optim
-from compile.FxGraphConvertor import fx2mlir
-import torchvision.models as models
-import argparse
-import numpy as np
-from torch.fx import Interpreter
-import torch._dynamo
+
+import torch._dynamo.config

 class JitNet(nn.Module):
     def __init__(self, net, loss_fn):
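
The replaced imports drop the earlier export helpers and instead pull in tpu_mlir's `aot_backend`, the TorchDynamo backend that the training loop later hands to `torch.compile`. For readers unfamiliar with that hook, a minimal sketch of the backend contract (the `inspecting_backend` name and body are illustrative, not tpu_mlir's implementation): a backend is any callable that receives the captured `torch.fx.GraphModule` plus example inputs and returns a callable.

    import torch

    def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
        # A real backend (here, tpu_mlir's aot_backend) would lower this FX graph
        # to device code; returning gm.forward just runs it eagerly.
        gm.graph.print_tabular()
        return gm.forward

    compiled = torch.compile(torch.nn.Linear(4, 2), backend=inspecting_backend)
    compiled(torch.randn(1, 4))
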
@@ -116,9 +113,10 @@ def __init__(self, net, loss_fn):

     def forward(self, x, y):
         predict = self.net(x)
-        loss, loss_item = self.loss_fn(self.net(x), y)
+        loss, loss_item = self.loss_fn(predict, y)
         return loss, loss_item.detach()

+
 def _get_disc_decomp():
     from torch._decomp import get_decompositions
     aten = torch.ops.aten
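
The one-line change in `JitNet.forward` fixes a duplicated forward pass: the loss is now computed on the `predict` already produced instead of calling `self.net(x)` a second time, so the traced joint graph no longer contains the model twice. Bundling model and loss in one `nn.Module` is what lets AOTAutograd export a single forward-plus-loss graph; a self-contained sketch of that usage with a toy model (mirroring the `aot_export_module` call that appears, commented out, later in this diff):

    import torch
    import torch.nn as nn
    from torch._functorch.aot_autograd import aot_export_module

    class ToyJitNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(8, 3)
            self.loss_fn = nn.CrossEntropyLoss()

        def forward(self, x, y):
            predict = self.net(x)            # single forward pass, reused for the loss
            return self.loss_fn(predict, y)

    x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))
    # trace_joint=True captures forward and backward; output 0 is the loss.
    fx_g, signature = aot_export_module(ToyJitNet(), [x, y], trace_joint=True, output_loss_index=0)
    print(fx_g.graph)
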
@@ -144,6 +142,15 @@ def _get_disc_decomp():
     )
     return decompositions_dict

+tensor_idx = 0
+features_out_hook = {}
+def hook(module, fea_in, fea_out):
+    global features_out_hook, tensor_idx
+
+    if isinstance(fea_out, torch.Tensor):
+        features_out_hook[f'f_{tensor_idx}'] = fea_out.detach().numpy()
+        tensor_idx += 1
+    return None

 def convert_module_fx(
     submodule_name: str,
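
`hook` is registered further down (in the `dump_cuda_ref` block) on every submodule so that a reference run can dump each layer's output, plus the input batch and all parameters, into `layer_outputs.npz` for later comparison with the TPU path. The same pattern in isolation, on a toy model (names are illustrative):

    import numpy as np
    import torch
    import torch.nn as nn

    outputs, idx = {}, 0

    def dump_hook(module, inputs, output):
        global idx
        if isinstance(output, torch.Tensor):
            outputs[f"f_{idx}"] = output.detach().cpu().numpy()
            idx += 1

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    handles = [m.register_forward_hook(dump_hook) for m in model.modules()]
    model(torch.randn(2, 4))
    np.savez("layer_outputs.npz", **outputs)   # one array per hooked module
    for h in handles:
        h.remove()
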
@@ -367,7 +374,7 @@ def lf(x):
         image_weights=opt.image_weights,
         quad=opt.quad,
         prefix=colorstr("train: "),
-        shuffle=True,
+        shuffle=False,
         seed=opt.seed,
     )
     labels = np.concatenate(dataset.labels, 0)
@@ -422,7 +429,12 @@ def lf(x):
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, [email protected], [email protected], val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
-    scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    if opt.device == 'tpu':
+        scaler = torch_tpu.tpu.amp.GradScaler(enabled=True, allow_fp16=True)
+    elif opt.device == 'cpu':
+        scaler = None
+    else:
+        scaler = torch.cuda.amp.GradScaler(enabled=amp)
     stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run("on_train_start")
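
GradScaler construction now branches on the device: torch_tpu's scaler for `tpu` (the `allow_fp16` flag appears to be specific to that extension, not stock PyTorch), no scaler at all on CPU, and the usual `torch.cuda.amp.GradScaler` otherwise. Downstream code therefore has to cope with `scaler is None`; the conventional scaler-optional step looks roughly like this (a sketch, not the exact train.py code, which spreads these calls across its Backward and Optimize sections):

    def scaled_step(loss, optimizer, scaler=None):
        if scaler is None:                      # plain FP32 path (CPU)
            loss.backward()
            optimizer.step()
        else:                                   # AMP path (CUDA or TPU scaler)
            scaler.scale(loss).backward()
            scaler.step(optimizer)              # unscales grads, skips step on inf/NaN
            scaler.update()
        optimizer.zero_grad()
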
@@ -432,6 +444,15 @@ def lf(x):
432444 f"Logging results to { colorstr ('bold' , save_dir )} \n "
433445 f"Starting training for { epochs } epochs..."
434446 )
447+
448+ hook_handles = []
449+ # dump_cuda_ref = True
450+ dump_cuda_ref = False
451+ if dump_cuda_ref :
452+ for name , child in model .named_modules ():
453+ hd = child .register_forward_hook (hook = hook )
454+ hook_handles .append (hd )
455+
435456 for epoch in range (start_epoch , epochs ): # epoch ------------------------------------------------------------------
436457 callbacks .run ("on_train_epoch_start" )
437458 model .train ()
@@ -454,7 +475,14 @@ def lf(x):
         if RANK in {-1, 0}:
             pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
         optimizer.zero_grad()
-        zwyjit = JitNet(model, compute_loss)
+        compiled = True
+        joint = False
+        if compiled:
+            if joint:
+                zwyjit = JitNet(model, compute_loss)
+                model_opt = torch.compile(zwyjit, backend=aot_backend, dynamic=None, fullgraph=False)
+            else:
+                model_opt = torch.compile(model, backend=aot_backend, dynamic=None, fullgraph=False)
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             callbacks.run("on_train_batch_start")
             ni = i + nb * epoch  # number integrated batches (since train start)
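
Two compile modes are wired in here: with `joint=True` the `JitNet` wrapper (model plus loss) is handed to `torch.compile`, so the loss computation is captured in the same graph the backend sees; with the default `joint=False` only the detection model is compiled and `compute_loss` stays eager. The difference in miniature, using the built-in `"eager"` backend so the snippet runs without tpu_mlir:

    net, loss_fn = torch.nn.Linear(8, 3), torch.nn.CrossEntropyLoss()
    x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))

    # joint: model and loss captured in one compiled callable
    joint_opt = torch.compile(lambda a, b: loss_fn(net(a), b), backend="eager")
    loss = joint_opt(x, y)

    # model-only: just the network is compiled, the loss runs eagerly
    net_opt = torch.compile(net, backend="eager")
    loss = loss_fn(net_opt(x), y)
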
@@ -481,22 +509,54 @@ def lf(x):

             # Forward
             with torch.cuda.amp.autocast(amp):
-                # print(1)
-                # zwy = SophonJointCompile(model, [imgs, targets], trace_joint=True, output_loss_index=0, args=None)
-                # pred = model(imgs)  # forward
-                # loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
-                # loss, loss_items = zwyjit(imgs, targets.to(device))
-                fx_g, signature = aot_export_module(
-                    zwyjit, [imgs, targets], trace_joint=True, output_loss_index=0, decompositions=_get_disc_decomp()
-                )
-                print(fx_g)
+                # from torch._subclasses import FakeTensorMode
+                # fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
+                # with fake_mode:
+                # zwyjit = JitNet(model, compute_loss)
+                # fx_g, signature = aot_export_module(
+                # zwyjit, [imgs, targets], trace_joint=True, output_loss_index=0, decompositions=_get_disc_decomp()
+                # )
+                # print('fx_g:')
+                # fx_g.graph.print_tabular()
+                # print('signature:', signature)
+                # exit(0)
+                if dump_cuda_ref:
+                    pred = model(imgs)
+                    global features_out_hook
+                    features_out_hook['data'] = imgs.detach().numpy()
+                    for name, param in model.named_parameters():
+                        features_out_hook[name] = param.detach().numpy()
+                    np.savez('layer_outputs.npz', **features_out_hook)
+                    for hd in hook_handles:
+                        hd.remove()
+                    exit(0)
+                else:
+                    if compiled:
+                        if joint:
+                            loss, loss_items = model_opt(imgs, targets.to(device))  # forward
+                        else:
+                            pred = model_opt(imgs)
+                            loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
+                    else:
+                        pred = model(imgs)
+                        loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
+
                 if RANK != -1:
                     loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                 if opt.quad:
                     loss *= 4.0

             # Backward
-            scaler.scale(loss).backward()
+            if opt.device == 'tpu':
+                loss = loss.to(device)
+                #print('old loss:', loss, loss.device)
+                #total_loss = 0.2666
+                #loss.data.copy_(total_loss)
+                print('loss:', loss, loss.device)
+            if scaler is None:
+                loss.backward()
+            else:
+                scaler.scale(loss).backward()

             # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
             if ni - last_opt_step >= accumulate:
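
The forward pass now dispatches three ways: dump a CPU/CUDA reference (then exit), run the compiled module (joint or model-only), or fall back to the original eager path; the backward likewise tolerates the missing scaler on CPU. The `layer_outputs.npz` reference is only useful together with a comparison step against a second dump from the device under test; a minimal sketch of such a check (the second file name and the tolerance are assumptions):

    import numpy as np

    ref = np.load("layer_outputs.npz")        # CPU/CUDA reference dump
    tpu = np.load("layer_outputs_tpu.npz")    # hypothetical dump from the TPU run

    for key in ref.files:
        if key not in tpu.files:
            continue
        a, b = ref[key], tpu[key]
        err = np.abs(a - b).max() if a.shape == b.shape else float("inf")
        print(f"{'OK  ' if err < 1e-2 else 'DIFF'} {key}: max abs err {err:.3e}")
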
@@ -663,7 +723,7 @@ def parse_opt(known=False):
     parser.add_argument("--bucket", type=str, default="", help="gsutil bucket")
     parser.add_argument("--cache", type=str, nargs="?", const="ram", help="image --cache ram/disk")
     parser.add_argument("--image-weights", action="store_true", help="use weighted image selection for training")
-    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
+    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu or tpu")
     parser.add_argument("--multi-scale", action="store_true", help="vary img-size +/- 50%%")
     parser.add_argument("--single-cls", action="store_true", help="train multi-class data as single-class")
     parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "AdamW"], default="SGD", help="optimizer")
@@ -690,6 +750,7 @@ def parse_opt(known=False):
     # NDJSON logging
     parser.add_argument("--ndjson-console", action="store_true", help="Log ndjson to console")
     parser.add_argument("--ndjson-file", action="store_true", help="Log ndjson to file")
+    parser.add_argument("--debug_cmd", type=str, default="", help="debug_cmd")

     return parser.parse_known_args()[0] if known else parser.parse_args()

@@ -1058,7 +1119,26 @@ def run(**kwargs):
     main(opt)
     return opt

+# import torch._dynamo
+# import logging
+# logger = logging.getLogger("torch._dynamo")
+# logger.setLevel(logging.DEBUG)
+# console_handler = logging.StreamHandler()
+# console_handler.setLevel(logging.DEBUG)
+# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+# console_handler.setFormatter(formatter)
+# logger.addHandler(console_handler)

 if __name__ == "__main__":
     opt = parse_opt()
+    # print_ori_fx_graph/dump_fx_graph/skip_tpu_compile/dump_bmodel_input
+    import tpu_mlir
+    tpu_mlir.python.tools.train.config.debug_cmd = opt.debug_cmd
+    tpu_mlir.python.tools.train.config.compile_opt = 2
+    # tpu_mlir.python.tools.train.config.only_compile_graph_id = 1
+    # tpu_mlir.python.tools.train.config.run_on_cmodel = False if opt.device == 'tpu' else True
+    tpu_mlir.python.tools.train.config.run_on_cmodel = True
+    tpu_mlir.python.tools.train.config.print_config_info()
+    # torch._dynamo.config.suppress_errors = True
+
     main(opt)
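
Before `main(opt)` runs, the tpu_mlir training config is set up: `debug_cmd` forwards the new `--debug_cmd` flag (the preceding comment lists the recognized tokens), `compile_opt = 2` picks an optimization level, and `run_on_cmodel = True` keeps execution on the software C-model instead of real TPU silicon. A hedged usage sketch; the example arguments are illustrative and the token semantics are defined inside tpu_mlir, not here:

    # Illustrative launch of this modified train.py (dataset/weights are the
    # stock YOLOv5 examples, not something this commit ships):
    #
    #   python train.py --data coco128.yaml --weights yolov5s.pt --imgsz 640 \
    #          --device tpu --debug_cmd dump_fx_graph
    #
    # --debug_cmd accepts the tokens named in the comment above
    # (print_ori_fx_graph / dump_fx_graph / skip_tpu_compile / dump_bmodel_input);
    # their exact behaviour lives in tpu_mlir.python.tools.train.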