diff --git a/diffusion_models/models/diffusion.py b/diffusion_models/models/diffusion.py
index ac5fd3f..7c58a69 100644
--- a/diffusion_models/models/diffusion.py
+++ b/diffusion_models/models/diffusion.py
@@ -219,12 +219,12 @@ def denoise_singlestep(
         x
             tensor representing a batch of noisy pictures (may be of different timesteps)
         t
-            tensor representing the t timesteps for the batch
+            tensor representing the current timesteps t for the batch (the timestep each image is currently at)
 
         Returns
         -------
         out
-            less noisy version (by one timestep)
+            less noisy version (by one timestep, i.e. at t-1 relative to the input timesteps)
         """
         self.model.eval()
         with torch.no_grad():
@@ -243,9 +243,7 @@ def denoise_singlestep(
     
     def sample(
             self,
-            num_samples: int,
-            debugging: bool=False,
-            save_every: int=20
+            num_samples: int
         ) -> Float[Tensor, "batch channel height width"]:
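+        """Sample a batch of `num_samples` images by starting from noise and applying `denoise_singlestep` over all timesteps in reverse."""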
         beta = self.fwd_diff.betas[-1].view(-1,1,1,1)
         x = self.init_noise(num_samples) * torch.sqrt(beta)
@@ -253,54 +251,6 @@ def sample(
             t = i * torch.ones((num_samples), dtype=torch.long, device=list(self.model.parameters())[0].device)
             x = self.denoise_singlestep(x, t)
         return x
-    
-    def sample2(
-            self, 
-            num_samples: int, 
-            debugging: bool=False,
-            save_every: int=20
-        ) -> Union[Float[Tensor, "batch channel height width"], List[Float[Tensor, "batch channel height width"]]]:
-        """Sample a batch of images.
-
-        Parameters
-        ----------
-        num_samples
-            how big the batch should be
-        debugging
-            if true, returns list that shows the sampling process
-        save_every
-            defines how often the tensors should be saved in the denoising process
-
-        Returns
-        -------
-        out
-            either a list of tensors if debugging is true, else a single tensor with final images
-        """
-        self.model.eval()
-        device = list(self.parameters())[0].device
-        with torch.no_grad():
-            x = torch.randn((num_samples, self.model.in_channels, self.img_size, self.img_size), device=device)
-            x_list = []
-            for i in reversed(range(1, self.fwd_diff.timesteps)):
-                t_step = i * torch.ones((num_samples), dtype=torch.long, device=device)
-                t_enc = self.time_encoder.get_pos_encoding(t_step)
-                noise_pred = self.model(x, t_enc)
-
-                alpha = self.fwd_diff.alphas[t_step][:, None, None, None]
-                alpha_hat = self.fwd_diff.alphas_dash[t_step][:, None, None, None]
-                beta = self.fwd_diff.betas[t_step][:, None, None, None]
-                if i > 1:
-                    noise = torch.randn_like(x, device=device)
-                else:
-                    noise = torch.zeros_like(x, device=device)
-                # mean is predicted by NN and refactored by alphas, beta is kept constant according to scheduler
-                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * noise_pred) + torch.sqrt(beta) * noise
-                if debugging and (i % save_every == 0):
-                    x_list.append(x)
-        self.model.train()
-        if debugging:
-            return x_list
-        return x
 
     def _sample_timesteps(self, batch_size: int, device: torch.device) -> Int64[Tensor, "batch"]:
         return torch.randint(low=1, high=self.fwd_diff.timesteps, size=(batch_size,), device=device)
\ No newline at end of file
diff --git a/diffusion_models/utils/trainer.py b/diffusion_models/utils/trainer.py
index 11279d2..b4d954a 100644
--- a/diffusion_models/utils/trainer.py
+++ b/diffusion_models/utils/trainer.py
@@ -1,6 +1,7 @@
 import os
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 from torch.utils.data import Dataset, DataLoader
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
@@ -9,12 +10,15 @@
 from time import time
 import wandb
 from typing import Callable, Literal, Any, Tuple
+from torch import Tensor
 import wandb
 from torch.nn import Module
 import torchvision
+from math import isqrt
+from jaxtyping import Float
 
 class Trainer:
-    """Trainer Class that trains 1 model instance on 1 device."""
+    """Trainer Class that trains 1 model instance on 1 device, suited for distributed training."""
     def __init__(
         self,
         model: nn.Module,
@@ -22,6 +26,7 @@ def __init__(
         loss_func: Callable,
         optimizer: Optimizer,
         gpu_id: int,
+        num_gpus: int,
         batch_size: int,
         save_every: int,
         checkpoint_folder: str,
@@ -41,69 +46,102 @@ def __init__(
         optimizer
             torch.optim instance with model.parameters and learning rate passed
         gpu_id
-            int in range [0, num_GPUs]
+            int in range [0, num_gpus-1]; ignored if `device_type!="cuda"`
+        num_gpus
+            total number of GPUs used for training; ignored if `device_type!="cuda"`
         save_every
-            how often (epochs) to save model checkpoint
+            checkpoint model & upload data to wandb every `save_every` epochs
         checkpoint_folder
-            where to save checkpoint to
+            where to save checkpoints to
         device_type
             specify in case not training on a CUDA capable device
         log_wandb
-            whether to log to wandb; requires that initialization has been done
+            whether to log to wandb; requires that wandb has been initialized in the GPU 0 process (and in that process only!)
         """
         self.device_type = device_type
         self.gpu_id = gpu_id
+        self.num_gpus = num_gpus
         self.batch_size = batch_size
         if device_type != "cuda":
+            # distributed training not supported for devices other than CUDA
             self.gpu_id = 0
             self.model = model.to(torch.device(device_type))
             self.train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
         else:
-            self.model = self._setup_model(model)
-            self.train_data = self._setup_dataloader(train_data)
+            # this works for single and multi-GPU setups
+            self.model = self._setup_model(model) # self.model will be DistributedDataParallel-wrapped model
+            self.train_data = self._setup_dataloader(train_data) # self.train_data will be DataLoader with DistributedSampler
         self.loss_func = loss_func
         self.optimizer = optimizer
         self.save_every = save_every
         self.checkpoint_folder = checkpoint_folder
-        self.log_wandb = log_wandb and (self.gpu_id==0)
+        self.log_wandb = log_wandb and (self.gpu_id==0) # only log if in process for GPU 0
         if self.log_wandb:
             wandb.watch(self.model, log="all", log_freq=save_every)
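+        # keep mean loss per epoch; the most recent value is stored in checkpoints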
+        self.loss_history = []
 
-    def _setup_model(self, model):
+    def _setup_model(self, model: nn.Module):
         model = model.to(self.gpu_id)
         return DDP(model, device_ids=[self.gpu_id])
     
-    def _setup_dataloader(self, dataset):
+    def _setup_dataloader(self, dataset: Dataset):
         return DataLoader(dataset, batch_size=self.batch_size, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset))
 
-    def _run_batch(self, data):
-        raise NotImplementedError("use dedicated subclass")
+    def _run_batch(self, data: Tuple):
+        raise NotImplementedError("Use dedicated subclass (generative/discriminative) of Trainer to run a mini-batch of data.")
 
-    def _run_epoch(self, epoch):
+    def _run_epoch(self, epoch: int):
         epoch_losses = []
-        time1 = time()
+        epoch_time1 = time()
         for data in self.train_data:
+            batch_time1 = time()
             if self.device_type == "cuda":
+                # move all data inputs onto GPU
                 data = tuple(map(lambda x: x.to(self.gpu_id), data))
             else:
                 data = tuple(map(lambda x: x.to(self.device_type), data))
             batch_loss = self._run_batch(data)
             epoch_losses.append(batch_loss)
+            if self.log_wandb:
+                wandb.log({"epoch": epoch, "loss": batch_loss, "batch_time": time()-batch_time1})
         if self.log_wandb:
-            wandb.log({"epoch": epoch, "loss": np.mean(epoch_losses), "epoch_time": time()-time1})
-        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses)} | Time: {time()-time1:.2f}s")
+            wandb.log({"epoch_loss": np.mean(epoch_losses), "epoch_time": time()-epoch_time1})
+        self.loss_history.append(np.mean(epoch_losses))
+        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses)} | Time: {time()-epoch_time1:.2f}s")
 
-    def _save_checkpoint(self, epoch):
+    def _save_checkpoint(self, epoch: int):
         if self.device_type == "cuda":
+            # for DistributedDataParallel-wrapped model (nn.Module)
             ckp = self.model.module.state_dict()
         else:
             ckp = self.model.state_dict()
         if not os.path.isdir(self.checkpoint_folder):
             os.makedirs(self.checkpoint_folder)
         path = os.path.join(self.checkpoint_folder, f"checkpoint{epoch}.pt")
-        torch.save(ckp, path)
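+        # store full training state (model, optimizer, epoch, last epoch loss) so runs can be resumed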
+        torch.save(
+            {
+                "epoch": epoch,
+                "model_state_dict": ckp,
+                "optimizer_state_dict": self.optimizer.state_dict(),
+                "loss": self.loss_history[-1],
+                "device_type": self.device_type
+            },
+            path
+        )
         print(f"Epoch {epoch} | Training checkpoint saved at {path}")
 
+    def load_checkpoint(self, checkpoint_path: str):
+        # remap tensors onto the current device in case the checkpoint was written on a different device type
+        if self.device_type == "cuda":
+            map_location = torch.device(f"cuda:{self.gpu_id}")
+        else:
+            map_location = torch.device(self.device_type)
+        ckp = torch.load(checkpoint_path, map_location=map_location)
+        if self.device_type == "cuda":
+            self.model.module.load_state_dict(ckp["model_state_dict"])
+        else:
+            self.model.load_state_dict(ckp["model_state_dict"])
+        self.optimizer.load_state_dict(ckp["optimizer_state_dict"])
+        self.loss_history.append(ckp["loss"])
+
     def train(self, max_epochs: int):
         """Train method of Trainer class.
         
@@ -117,15 +155,9 @@ def train(self, max_epochs: int):
             if (self.gpu_id == 0) and (epoch % self.save_every == 0) and (epoch != 0):
                 self._save_checkpoint(epoch)
 
-    def load_checkpoint(self, checkpoint_path: str):
-        if self.device_type == "cuda":
-            self.model.module.load_state_dict(torch.load(checkpoint_path))
-        else:
-            self.model.load_state_dict(torch.load(checkpoint_path))
-
 class DiscriminativeTrainer(Trainer):
-    def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[..., Any], optimizer: Optimizer, gpu_id: int, batch_size: int, save_every: int, checkpoint_folder: str, device_type: Literal['cuda', 'mps', 'cpu'], log_wandb: bool = True) -> None:
-        super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
+    def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[..., Any], optimizer: Optimizer, gpu_id: int, num_gpus: int, batch_size: int, save_every: int, checkpoint_folder: str, device_type: Literal['cuda', 'mps', 'cpu'], log_wandb: bool = True) -> None:
+        super().__init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
 
     def _run_batch(self, data):
         """Run a data batch.
@@ -133,11 +165,11 @@ def _run_batch(self, data):
         Parameters
         ----------
         data
-            tuple of training batch and targets
+            tuple of input data and targets; several inputs are possible: (input1, input2, ..., inputN, target)
         """
-        source, targets = data
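+        # last element of the tuple is the target, all preceding elements are model inputs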
+        *source, targets = data
         self.optimizer.zero_grad()
-        pred = self.model(source)
+        pred = self.model(*source)
         loss = self.loss_func(pred, targets)
         loss.backward()
         self.optimizer.step()
@@ -150,22 +182,36 @@ def __init__(
             train_data: Dataset, 
             loss_func: Callable[..., Any], 
             optimizer: Optimizer, 
-            gpu_id: int, 
+            gpu_id: int,
+            num_gpus: int,
             batch_size: int, 
             save_every: int, 
             checkpoint_folder: str, 
             device_type: Literal['cuda', 'mps', 'cpu'], 
             log_wandb: bool,
-            num_samples: int,
-            show_denoising_process: bool,
-            show_denoising_every: int
+            num_samples: int
         ) -> None:
-        super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
+        """Constructor of GenerativeTrainer class.
+
+        Parameters
+        ----------
+        model
+            instance of nn.Module, must implement a `sample(num_samples: int)` method
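+        num_samples
+            number of images to sample at every checkpoint; adjusted to a square number
+            divisible by `num_gpus` if the given value is not one already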
+        """
+        super().__init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
+
+        def is_square(i: int) -> bool:
+            return i == isqrt(i) ** 2
+            
+        def closest_square_divisible_by(num_samples: int, div: int):
+            counter = 1
+            while (counter**2 % div != 0) and (counter**2 < num_samples):
+                counter += 1
+            return counter**2
+        
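+        # num_samples must be a square (samples are saved as a square image grid) and
+        # divisible by num_gpus (every process generates an equal share of the samples)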
+        if (num_samples % self.num_gpus != 0) or (not is_square(num_samples)):
+            num_samples = closest_square_divisible_by(num_samples, self.num_gpus)
         self.num_samples = num_samples
-        if not np.sqrt(num_samples).is_integer():
-            raise ValueError("Please choose a num_samples value with integer sqrt.")
-        self.show_denoising_process = show_denoising_process
-        self.show_denoising_every = show_denoising_every
 
     def _run_batch(self, data):
         """Run a data batch.
@@ -173,44 +219,41 @@ def _run_batch(self, data):
         Parameters
         ----------
         data
-            single item tuple of training batch
+            tuple containing training batch
         """
         self.optimizer.zero_grad()
-        ### to be changed!
-        pred = self.model(data[0])
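+        # the model returns a tuple whose elements are passed directly to the loss function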
+        pred = self.model(*data)
         loss = self.loss_func(*pred)
         loss.backward()
         self.optimizer.step()
         return loss.item()
     
+    def _wandb_log_sample(self, sample: Float[Tensor, "channels height width"], epoch: int):
+        images = wandb.Image(sample, caption=f"Samples Epoch {epoch}")
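+        # commit=False attaches the images to the next wandb.log call instead of creating a separate step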
+        wandb.log({"examples": images}, commit=False)
+    
+    def _save_samples(self, samples: Float[Tensor, "samples channels height width"], storage_folder: str, epoch: int):
+        samples = torchvision.utils.make_grid(samples, nrow=int(np.sqrt(self.num_samples)))
+        path = os.path.join(storage_folder, f"samples_epoch{epoch}.png")
+        torchvision.utils.save_image(samples, path)
+        print(f"Epoch {epoch} | Samples saved at {path}")
+        if self.log_wandb:
+            self._wandb_log_sample(samples, epoch)
+
+    def get_samples(self, num_samples: int):
+        if (self.device_type == "cuda") and (self.num_gpus == 1):
+            samples = self.model.module.sample(num_samples)
+        elif (self.device_type == "cuda") and (self.num_gpus > 1):
+            # every process samples its share, then the full batch is gathered across GPUs
+            samples = self.model.module.sample(num_samples // self.num_gpus)
+            total_samples = torch.zeros((samples.shape[0]*self.num_gpus, *samples.shape[1:]), dtype=samples.dtype, device=samples.device)
+            dist.all_gather_into_tensor(total_samples, samples)
+            samples = total_samples
+        else:
+            samples = self.model.sample(num_samples)
+        return samples
+    
     def _save_checkpoint(self, epoch: int):
         """Overwriting original method - Checkpoint model and generate samples."""
         super()._save_checkpoint(epoch)
-        if self.device_type == "cuda":
-            samples = self.model.module.sample(self.num_samples, debugging=self.show_denoising_process, save_every=self.show_denoising_every)
-        else:
-            samples = self.model.sample(self.num_samples, debugging=self.show_denoising_process, save_every=self.show_denoising_every)
-
-        if not self.show_denoising_process:
-            samples = torchvision.utils.make_grid(samples, nrow=int(np.sqrt(self.num_samples)))
-            if self.log_wandb:
-                images = wandb.Image(
-                    samples, 
-                    caption=f"Samples Epoch {epoch}"
-                )
-                wandb.log({"examples": images})
-            path = os.path.join(self.checkpoint_folder, f"samples_epoch{epoch}.png")
-            torchvision.utils.save_image(samples, path)
-            print(f"Epoch {epoch} | Samples saved at {path}")
-        else:
-            path = os.path.join(self.checkpoint_folder, f"samples_epoch{epoch}")
-            if not os.path.isdir(path):
-                os.makedirs(path)
-            for i, sample in enumerate(samples):
-                grid = torchvision.utils.make_grid(sample, nrow=int(np.sqrt(self.num_samples)))
-                img_path = os.path.join(path, f"samples_step{i * self.show_denoising_every}.png")
-                torchvision.utils.save_image(grid, img_path)
-                if self.log_wandb:
-                    images = wandb.Image(grid)
-                    wandb.log({f"examples_epoch{epoch}": images})
-            print(f"Epoch {epoch} | Samples saved at {path}")
\ No newline at end of file
+        samples = self.get_samples(self.num_samples)
+        self._save_samples(samples, self.checkpoint_folder, epoch)
\ No newline at end of file