Backend paddle: support regularizer (#1896)
lijialin03 authored Dec 18, 2024
1 parent ec4bdd3 commit 3aca6f7
Showing 3 changed files with 33 additions and 3 deletions.
deepxde/model.py (6 changes: 5 additions & 1 deletion)

@@ -518,7 +518,11 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
             list(self.net.parameters()) + self.external_trainable_variables
         )
         self.opt = optimizers.get(
-            trainable_variables, self.opt_name, learning_rate=lr, decay=decay
+            trainable_variables,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=self.net.regularizer,
         )
 
     def train_step(inputs, targets, auxiliary_vars):
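
In practice, this means Model.compile now forwards whatever is stored on the network's regularizer attribute to the Paddle optimizer as its weight_decay. A minimal sketch of the intended usage, assuming the Paddle backend is selected (e.g. DDE_BACKEND=paddle) and that `data` is a DeepXDE data object defined elsewhere; the coefficient 1e-4 and the direct attribute assignment are illustrative, not taken from this commit:

import paddle
import deepxde as dde

# Any network built on the Paddle NN base class; layer sizes are arbitrary here.
net = dde.nn.FNN([2, 32, 32, 1], "tanh", "Glorot uniform")

# Illustrative: attach an L2 penalty so Model.compile passes it on as weight_decay.
net.regularizer = paddle.regularizer.L2Decay(coeff=1e-4)

model = dde.Model(data, net)    # `data` is assumed to be defined elsewhere
model.compile("adam", lr=1e-3)  # now calls optimizers.get(..., weight_decay=net.regularizer)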
deepxde/nn/paddle/nn.py (1 change: 1 addition & 0 deletions)

@@ -6,6 +6,7 @@ class NN(paddle.nn.Layer):
 
     def __init__(self):
         super().__init__()
+        self.regularizer = None
         self._input_transform = None
         self._output_transform = None
 
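
Because the attribute is added on the shared NN base class, every Paddle network in DeepXDE inherits it and defaults to no weight decay. A quick check, assuming the Paddle backend is active; the FNN layer sizes are arbitrary:

import deepxde as dde

net = dde.nn.FNN([2, 32, 32, 1], "tanh", "Glorot uniform")
print(net.regularizer)  # None by default, so optimizers.get receives weight_decay=None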
deepxde/optimizers/paddle/optimizers.py (29 changes: 27 additions & 2 deletions)

@@ -19,14 +19,16 @@ def is_external_optimizer(optimizer):
     return optimizer in ["L-BFGS", "L-BFGS-B"]
 
 
-def get(params, optimizer, learning_rate=None, decay=None):
+def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     """Retrieves an Optimizer instance."""
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer
 
     if optimizer in ["L-BFGS", "L-BFGS-B"]:
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
+        if weight_decay is not None:
+            raise ValueError("L-BFGS optimizer doesn't support weight_decay")
         optim = paddle.optimizer.LBFGS(
             learning_rate=1,
             max_iter=LBFGS_options["iter_per_step"],
@@ -46,5 +48,28 @@ def get(params, optimizer, learning_rate=None, decay=None):
         learning_rate = _get_lr_scheduler(learning_rate, decay)
 
     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
+        return paddle.optimizer.Adam(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "sgd":
+        return paddle.optimizer.SGD(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "rmsprop":
+        return paddle.optimizer.RMSProp(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay,
+        )
+    if optimizer == "adamw":
+        if (
+            not isinstance(weight_decay, paddle.regularizer.L2Decay)
+            or weight_decay._coeff == 0
+        ):
+            raise ValueError("AdamW optimizer requires non-zero L2 regularizer")
+        return paddle.optimizer.AdamW(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay._coeff,
+        )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
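
To summarize the new behavior: Adam, SGD, and RMSProp pass weight_decay straight through to Paddle (which accepts either a float or a regularizer object for that argument), AdamW requires a non-zero paddle.regularizer.L2Decay and unwraps its coefficient, and L-BFGS rejects any weight_decay. A sketch of calling get directly, assuming the Paddle backend is active and using a stand-in parameter list:

import paddle
from deepxde.optimizers.paddle.optimizers import get

params = paddle.nn.Linear(2, 1).parameters()  # stand-in parameter list
reg = paddle.regularizer.L2Decay(coeff=1e-4)

adam = get(params, "adam", learning_rate=1e-3, weight_decay=reg)    # regularizer passed through
adamw = get(params, "adamw", learning_rate=1e-3, weight_decay=reg)  # coefficient 1e-4 unwrapped
# get(params, "adamw", learning_rate=1e-3, weight_decay=1e-4)   # ValueError: needs L2Decay
# get(params, "L-BFGS", weight_decay=reg)                       # ValueError: not supported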
