Commit 439039c

Backend paddle: DeepONetCartesianProd supports multi outputs (#1799)
1 parent 6e76854 commit 439039c

File tree: 2 files changed (+99, -30 lines)
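For orientation, here is a minimal usage sketch of the interface this commit adds; the layer sizes, activation, and initializer are illustrative assumptions, not values taken from the diff:

import deepxde as dde  # assumes the paddle backend is selected, e.g. DDE_BACKEND=paddle

# Branch net encodes a function sampled at 100 sensors; trunk net encodes 2-D
# query points. Both nets end in width 128, which the docstring below requires
# for every strategy except "split_branch" and "split_trunk".
net = dde.nn.DeepONetCartesianProd(
    [100, 128, 128],
    [2, 128, 128],
    "tanh",
    "Glorot normal",
    num_outputs=2,
    multi_output_strategy="independent",
)

With the defaults (num_outputs=1, multi_output_strategy=None) the class keeps its original single-output behavior.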

deepxde/nn/paddle/deeponet.py (+98, -29)
@@ -4,6 +4,13 @@
 from .nn import NN
 from .. import activations
 from .. import initializers
+from ..deeponet_strategy import (
+    SingleOutputStrategy,
+    IndependentStrategy,
+    SplitBothStrategy,
+    SplitBranchStrategy,
+    SplitTrunkStrategy,
+)
 
 
 class DeepONet(NN):
@@ -89,14 +96,40 @@ class DeepONetCartesianProd(NN):
     Args:
         layer_sizes_branch: A list of integers as the width of a fully connected network,
             or `(dim, f)` where `dim` is the input dimension and `f` is a network
-            function. The width of the last layer in the branch and trunk net should be
-            equal.
+            function. The width of the last layer in the branch and trunk net
+            should be the same for all strategies except "split_branch" and "split_trunk".
         layer_sizes_trunk (list): A list of integers as the width of a fully connected
             network.
         activation: If `activation` is a ``string``, then the same activation is used in
             both trunk and branch nets. If `activation` is a ``dict``, then the trunk
             net uses the activation `activation["trunk"]`, and the branch net uses
             `activation["branch"]`.
+        num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
+            `multi_output_strategy` below should be set.
+        multi_output_strategy (str or None): ``None``, "independent", "split_both", "split_branch" or
+            "split_trunk". It makes sense to set in case of multiple outputs.
+
+            - None
+            Classical implementation of DeepONet with a single output.
+            Cannot be used with `num_outputs` > 1.
+
+            - independent
+            Use `num_outputs` independent DeepONets, and each DeepONet outputs only
+            one function.
+
+            - split_both
+            Split the outputs of both the branch net and the trunk net into `num_outputs`
+            groups, and then the kth group outputs the kth solution.
+
+            - split_branch
+            Split the branch net and share the trunk net. The width of the last layer
+            in the branch net should be equal to the one in the trunk net multiplied
+            by the number of outputs.
+
+            - split_trunk
+            Split the trunk net and share the branch net. The width of the last layer
+            in the trunk net should be equal to the one in the branch net multiplied
+            by the number of outputs.
     """
 
     def __init__(
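To make the width rules in the docstring above concrete, a small sketch with illustrative sizes (not taken from the diff): with a trunk net ending in width 64 and num_outputs = 3, "split_branch" requires the branch net to end in 64 * 3 = 192, and "split_trunk" mirrors this on the trunk side.

# Illustrative layer sizes for the "split_branch" strategy.
num_outputs = 3
width = 64
layer_sizes_branch = [100, 128, width * num_outputs]  # last layer: 192
layer_sizes_trunk = [2, 128, width]                   # last layer: 64
# "split_trunk" swaps the roles: trunk ends in 192, branch ends in 64.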
@@ -105,45 +138,81 @@ def __init__(
         layer_sizes_trunk,
         activation,
         kernel_initializer,
-        regularization=None,
+        num_outputs=1,
+        multi_output_strategy=None,
     ):
         super().__init__()
         if isinstance(activation, dict):
-            activation_branch = activation["branch"]
+            self.activation_branch = activation["branch"]
             self.activation_trunk = activations.get(activation["trunk"])
         else:
-            activation_branch = self.activation_trunk = activations.get(activation)
-        if callable(layer_sizes_branch[1]):
-            # User-defined network
-            self.branch = layer_sizes_branch[1]
-        else:
-            # Fully connected network
-            self.branch = FNN(layer_sizes_branch, activation_branch, kernel_initializer)
-        self.trunk = FNN(layer_sizes_trunk, self.activation_trunk, kernel_initializer)
-        # register bias to parameter for updating in optimizer and storage
-        self.b = self.create_parameter(
-            shape=(1,), default_initializer=initializers.get("zeros")
+            self.activation_branch = self.activation_trunk = activations.get(activation)
+        self.kernel_initializer = kernel_initializer
+
+        self.num_outputs = num_outputs
+        if self.num_outputs == 1:
+            if multi_output_strategy is not None:
+                raise ValueError(
+                    "num_outputs is set to 1, but multi_output_strategy is not None."
+                )
+        elif multi_output_strategy is None:
+            multi_output_strategy = "independent"
+            print(
+                f"Warning: There are {num_outputs} outputs, but no multi_output_strategy selected. "
+                'Use "independent" as the multi_output_strategy.'
+            )
+        self.multi_output_strategy = {
+            None: SingleOutputStrategy,
+            "independent": IndependentStrategy,
+            "split_both": SplitBothStrategy,
+            "split_branch": SplitBranchStrategy,
+            "split_trunk": SplitTrunkStrategy,
+        }[multi_output_strategy](self)
+
+        self.branch, self.trunk = self.multi_output_strategy.build(
+            layer_sizes_branch, layer_sizes_trunk
+        )
+        if isinstance(self.branch, list):
+            self.branch = paddle.nn.LayerList(self.branch)
+        if isinstance(self.trunk, list):
+            self.trunk = paddle.nn.LayerList(self.trunk)
+        self.b = paddle.nn.ParameterList(
+            [
+                paddle.create_parameter(
+                    shape=[1,],
+                    dtype=paddle.get_default_dtype(),
+                    default_initializer=paddle.nn.initializer.Constant(value=0),
+                )
+                for _ in range(self.num_outputs)
+            ]
         )
-        self.regularizer = regularization
+
+    def build_branch_net(self, layer_sizes_branch):
+        # User-defined network
+        if callable(layer_sizes_branch[1]):
+            return layer_sizes_branch[1]
+        # Fully connected network
+        return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
+
+    def build_trunk_net(self, layer_sizes_trunk):
+        return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
+
+    def merge_branch_trunk(self, x_func, x_loc, index):
+        y = x_func @ x_loc.T
+        y += self.b[index]
+        return y
+
+    @staticmethod
+    def concatenate_outputs(ys):
+        return paddle.stack(ys, axis=2)
 
     def forward(self, inputs):
         x_func = inputs[0]
         x_loc = inputs[1]
-        # Branch net to encode the input function
-        x_func = self.branch(x_func)
-        # Trunk net to encode the domain of the output function
+        # Trunk net input transform
        if self._input_transform is not None:
             x_loc = self._input_transform(x_loc)
-        x_loc = self.activation_trunk(self.trunk(x_loc))
-        # Dot product
-        if x_func.shape[-1] != x_loc.shape[-1]:
-            raise AssertionError(
-                "Output sizes of branch net and trunk net do not match."
-            )
-        x = x_func @ x_loc.T
-        # Add bias
-        x += self.b
-
+        x = self.multi_output_strategy.call(x_func, x_loc)
         if self._output_transform is not None:
             x = self._output_transform(inputs, x)
         return x
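The shape contract of the new helper methods follows from the code above: merge_branch_trunk maps x_func of shape (N, w) and x_loc of shape (M, w) to an (N, M) tensor via x_func @ x_loc.T plus a per-output bias, and concatenate_outputs stacks num_outputs such tensors along axis 2. A standalone shape check, as a sketch that bypasses the class itself:

import paddle

N, M, w, num_outputs = 5, 7, 16, 2  # illustrative sizes
x_func = paddle.randn([N, w])       # stands in for the branch output
x_loc = paddle.randn([M, w])        # stands in for the trunk output
ys = [x_func @ x_loc.T for _ in range(num_outputs)]  # each (N, M)
y = paddle.stack(ys, axis=2)
print(y.shape)                      # [5, 7, 2]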

examples/operator/stokes_aligned_zcs_pideeponet.py (+1, -1)
@@ -1,4 +1,4 @@
-"""Backend supported: tensorflow, pytorch"""
+"""Backend supported: tensorflow, pytorch, paddle"""
 import deepxde as dde
 import matplotlib.pyplot as plt
 import numpy as np
