Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend PyTorch: Add L1 and L1+L2 regularizers #1905

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,14 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
False, inputs, targets, auxiliary_vars, self.data.losses_test
)

weight_decay = 0
l1_factor, l2_factor = 0, 0
if self.net.regularizer is not None:
if self.net.regularizer[0] != "l2":
raise NotImplementedError(
f"{self.net.regularizer[0]} regularization to be implemented for "
"backend pytorch"
)
weight_decay = self.net.regularizer[1]
if self.net.regularizer[0] == "l1":
l1_factor = self.net.regularizer[1]
elif self.net.regularizer[0] == "l2":
l2_factor = self.net.regularizer[1]
else:
raise ValueError(f"Unknown regularizer name: {self.net.regularizer[0]}")

optimizer_params = self.net.parameters()
if self.external_trainable_variables:
Expand All @@ -347,7 +347,7 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
optimizer_params = (
list(optimizer_params) + self.external_trainable_variables
)
if weight_decay > 0:
if l2_factor > 0:
print(
"Warning: L2 regularization will also be applied to external_trainable_variables. "
"Ensure this is intended behavior."
Expand All @@ -363,13 +363,18 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
self.opt_name,
learning_rate=lr,
decay=decay,
weight_decay=weight_decay,
weight_decay=l2_factor,
)

def train_step(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
if l1_factor:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For l1+l2 regularization, this might not be the correct way. weight_decay in the optimizers is implemented not as the L2 loss function. We should only consider either L1 or L2, not L1 + L2 now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

l1_loss = torch.sum(
torch.stack([torch.sum(p.abs()) for p in self.net.parameters()])
)
total_loss += l1_factor * l1_loss
self.opt.zero_grad()
total_loss.backward()
return total_loss
Expand Down
Loading