Description
1. As #606 mentioned, why was `loss.backward()` commented out? How was that issue resolved?
2. L358~L360: The model-saving logic sits inside the batch loop, so after every batch the code immediately checks whether to save the model.

However, at that point `epoch_loss` is only the running sum of the batch losses seen so far in the current epoch, which seems unreasonable. `epoch_loss` keeps growing as the epoch progresses, so after just a few batches it is still small, and the partial sums at the start of an epoch are almost always much smaller than the total accumulated over a fully trained epoch. As a result, the check ends up saving the model state from early in the epoch rather than the state after the model has trained through the whole epoch.
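For illustration, here is a small standalone sketch (with made-up loss numbers, not taken from any actual run) of how comparing the running sum inside the batch loop goes wrong:

```python
# Hypothetical per-batch losses: epoch 2's model is clearly better than epoch 1's.
epoch_1_batch_losses = [0.9, 0.8, 0.8, 0.7]
epoch_2_batch_losses = [0.3, 0.3, 0.2, 0.2]

best_loss = float("inf")
for batch_losses in (epoch_1_batch_losses, epoch_2_batch_losses):
    epoch_loss = 0.0
    for loss in batch_losses:
        epoch_loss += loss
        # Check inside the batch loop, as in the current code: the partial sum
        # after the first batch of an epoch is always the smallest value the
        # running sum takes within that epoch.
        if epoch_loss < best_loss:
            best_loss = epoch_loss

print(best_loss)  # 0.3 -> the first batch of epoch 2, not the full-epoch total of 1.0
```

The "best" value here corresponds to the model state right after the first batch of epoch 2, so the snapshot would be taken before most of that epoch's training has happened.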
The check should instead run after an epoch of training completes:
```python
for epoch in range(self.epochs):
    self.model_.train()
    epoch_loss = 0
    for batch_x, _ in dataloader:
        optimizer.zero_grad()
        outputs = self.model_(batch_x)
        dist = torch.sum((outputs - self.c) ** 2, dim=-1)
        if self.use_ae:
            loss = torch.mean(dist) + w_d + torch.mean(torch.square(outputs - batch_x))
        else:
            loss = torch.mean(dist) + w_d
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # After the epoch: compare the full-epoch loss and snapshot the best state
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_model_dict = self.model_.state_dict()
```
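One related caveat on the proposed fix, based on standard PyTorch behaviour rather than anything specific to this repository: `state_dict()` returns references to the live parameter tensors, so an uncopied `best_model_dict` keeps changing as training continues. A minimal standalone sketch (toy `nn.Linear`, not the library's `model_`):

```python
import copy
import torch

# Toy model standing in for self.model_ (an assumption for illustration only).
model = torch.nn.Linear(4, 2)

snapshot_ref = model.state_dict()                  # values share storage with the live parameters
snapshot_copy = copy.deepcopy(model.state_dict())  # independent copy

# Simulate further training updates (optimizer.step() also modifies parameters in place).
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

print(torch.equal(snapshot_ref["weight"], model.weight))   # True  -> snapshot drifted with the model
print(torch.equal(snapshot_copy["weight"], model.weight))  # False -> snapshot preserved

# Restore the preserved snapshot once training is done.
model.load_state_dict(snapshot_copy)
```

So storing `copy.deepcopy(self.model_.state_dict())` and restoring it with `load_state_dict` after the last epoch would keep the best-epoch weights intact.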