diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 18b569b8b8f5..01951f823850 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1804,7 +1804,10 @@ def _inner_training_loop(
             profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
             profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
             profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
+
+            self.num_compilations = 0  # XLA 'CompileTime' count observed so far (int; see step loop)
             for step, inputs in enumerate(epoch_iterator):
+                self.last_time_stamp = time.time()  # wall-clock start of this training step
                 if step == 0 and epoch == 0:
                     print('input sharding', {k: (v.shape, torch_xla._XLAC._get_xla_sharding_spec(v)) for k, v in inputs.items()})
                 total_batched_samples += 1
@@ -1896,5 +1899,17 @@ def _inner_training_loop(
                     if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                         self.lr_scheduler.step()
+
+                xm.mark_step()  # force execution of the pending XLA graph before timing
+                if self.num_compilations != met.metric_data('CompileTime')[0]:
+                    self.num_compilations = met.metric_data('CompileTime')[0]  # step recompiled: skip the skewed timing
+                else:
+                    xm.rendezvous('step')  # align hosts so the wall-clock step time is comparable
+                    step_time = time.time() - self.last_time_stamp
+                    data, fsdp, mdl = self.args.spmd_mesh.ici_mesh_shape
+                    num_devices = data * fsdp * mdl  # total devices in the SPMD mesh
+                    num_tokens = inputs["input_ids"].numel() / num_devices  # per-device tokens this step
+                    xm.master_print(f"Step time: {step_time}: Model TFLOPS: {self.model_flops(step_time, num_tokens)}")
+
                 model.zero_grad()
                 self.state.global_step += 1
                 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
@@ -1905,4 +1920,5 @@ def _inner_training_loop(
             else:
                 self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
             if self.control.should_epoch_stop or self.control.should_training_stop:
                 break
@@ -2694,5 +2710,18 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         else:
             self.accelerator.backward(loss)
+
+        # TODO: implement memory info for PJRT
+        #xm.master_print(f"Memory Info: {xm.get_memory_info(xm.xla_device())}")
+
+        return loss.detach() / self.args.gradient_accumulation_steps
+
+    def model_flops(self, step_time, num_tokens):  # returns estimated model TFLOPS/s for one step
+        num_trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+        model_flops = 6 * num_trainable_params * num_tokens  # ~6 FLOPs per param per token (fwd + bwd)
+        model_tflops_per_second = model_flops / step_time / 1e12
+        return model_tflops_per_second
+
+
     def compute_loss(self, model, inputs, return_outputs=False):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -3169,5 +3198,8 @@ def evaluation_loop(
             if is_torch_tpu_available():
                 xm.mark_step()
+
+
+
             # Update containers on host
             if loss is not None:
                 losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))