1111from typing import Callable , Literal
1212import wandb
1313
14- class DiscriminativeTrainer :
14+ class Trainer :
1515 """Trainer Class that trains 1 model instance on 1 device."""
1616 def __init__ (
1717 self ,
@@ -74,31 +74,18 @@ def _setup_model(self, model):
7474 def _setup_dataloader (self , dataset ):
7575 return DataLoader (dataset , batch_size = self .batch_size , pin_memory = True , shuffle = False , sampler = DistributedSampler (dataset ))
7676
77- def _run_batch (self , source , targets ):
78- self .optimizer .zero_grad ()
79- pred = self .model (source )
80- loss = self .loss_func (pred , targets )
81- loss .backward ()
82- self .optimizer .step ()
83- return loss .item ()
84-
85- def _run_epoch_nonCuda (self , epoch ):
86- epoch_losses = []
87- time1 = time ()
88- for source , targets in self .train_data :
89- source , targets = source .to (self .device_type ), targets .to (self .device_type )
90- batch_loss = self ._run_batch (source , targets )
91- epoch_losses .append (batch_loss )
92- if self .log_wandb :
93- wandb .log ({"epoch" : epoch , "loss" : np .mean (epoch_losses ), "epoch_time" : time ()- time1 })
94- print (f"[{ self .device_type } { self .gpu_id } ] Epoch { epoch } | Batchsize: { self .batch_size } | Steps: { len (self .train_data )} | Loss: { np .mean (epoch_losses )} | Time: { time ()- time1 :.2f} s" )
77+ def _run_batch (self , data ):
78+ raise NotImplementedError ("use dedicated subclass" )
9579
9680 def _run_epoch (self , epoch ):
9781 epoch_losses = []
9882 time1 = time ()
9983 for source , targets in self .train_data :
100- source , targets = source .to (self .gpu_id ), targets .to (self .gpu_id )
101- batch_loss = self ._run_batch (source , targets )
84+ if self .device_type == "cuda" :
85+ data = map (lambda x : x .to (self .gpu_id ))
86+ else :
87+ data = map (lambda x : x .to (self .device_type ))
88+ batch_loss = self ._run_batch (data )
10289 epoch_losses .append (batch_loss )
10390 if self .log_wandb :
10491 wandb .log ({"epoch" : epoch , "loss" : np .mean (epoch_losses ), "epoch_time" : time ()- time1 })
@@ -126,4 +113,25 @@ def train(self, max_epochs: int):
126113 else :
127114 self ._run_epoch (epoch )
128115 if (self .gpu_id == 0 ) and (epoch % self .save_every == 0 ) and (epoch != 0 ):
129- self ._save_checkpoint (epoch )
116+ self ._save_checkpoint (epoch )
117+
class DiscriminativeTrainer(Trainer):
    """Trainer for supervised models: forward pass, loss, backward, step."""

    def __init__(
        self,
        model: Module,
        train_data: Dataset,
        loss_func: Callable[..., Any],
        optimizer: Optimizer,
        gpu_id: int,
        batch_size: int,
        save_every: int,
        checkpoint_folder: str,
        device_type: Literal["cuda", "mps", "cpu"],
        log_wandb: bool = True,
    ) -> None:
        super().__init__(model, train_data, loss_func, optimizer, gpu_id,
                         batch_size, save_every, checkpoint_folder,
                         device_type, log_wandb)

    def _run_batch(self, data):
        """Run one optimisation step on a ``(source, targets)`` pair.

        Returns:
            The batch loss as a plain Python float.
        """
        source, targets = data
        self.optimizer.zero_grad()
        loss = self.loss_func(self.model(source), targets)
        loss.backward()
        self.optimizer.step()
        return loss.item()
130+
class GenerativeTrainer(Trainer):
    """Trainer for generative models.

    Work in progress: the per-batch training step is not implemented yet.
    """

    def __init__(
        self,
        model: Module,
        train_data: Dataset,
        loss_func: Callable[..., Any],
        optimizer: Optimizer,
        gpu_id: int,
        batch_size: int,
        save_every: int,
        checkpoint_folder: str,
        device_type: Literal["cuda", "mps", "cpu"],
        log_wandb: bool = True,
    ) -> None:
        super().__init__(model, train_data, loss_func, optimizer, gpu_id,
                         batch_size, save_every, checkpoint_folder,
                         device_type, log_wandb)

    def _run_batch(self, data):
        # TODO: implement the generative training step.
        self.optimizer.zero_grad()
        raise NotImplementedError("not finished yet")
0 commit comments