4
4
from .nn import NN
5
5
from .. import activations
6
6
from .. import initializers
7
+ from ..deeponet_strategy import (
8
+ SingleOutputStrategy ,
9
+ IndependentStrategy ,
10
+ SplitBothStrategy ,
11
+ SplitBranchStrategy ,
12
+ SplitTrunkStrategy ,
13
+ )
7
14
8
15
9
16
class DeepONet (NN ):
    Args:
        layer_sizes_branch: A list of integers as the width of a fully connected network,
            or `(dim, f)` where `dim` is the input dimension and `f` is a network
            function. The width of the last layer in the branch and trunk net
            should be the same for all strategies except "split_branch" and "split_trunk".
        layer_sizes_trunk (list): A list of integers as the width of a fully connected
            network.
        activation: If `activation` is a ``string``, then the same activation is used in
            both trunk and branch nets. If `activation` is a ``dict``, then the trunk
            net uses the activation `activation["trunk"]`, and the branch net uses
            `activation["branch"]`.
        num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
            `multi_output_strategy` below should be set.
        multi_output_strategy (str or None): ``None``, "independent", "split_both", "split_branch" or
            "split_trunk". It makes sense to set in case of multiple outputs.

            - None
            Classical implementation of DeepONet with a single output.
            Cannot be used with `num_outputs` > 1.

            - independent
            Use `num_outputs` independent DeepONets, and each DeepONet outputs only
            one function.

            - split_both
            Split the outputs of both the branch net and the trunk net into `num_outputs`
            groups, and then the kth group outputs the kth solution.

            - split_branch
            Split the branch net and share the trunk net. The width of the last layer
            in the branch net should be equal to the one in the trunk net multiplied
            by the number of outputs.

            - split_trunk
            Split the trunk net and share the branch net. The width of the last layer
            in the trunk net should be equal to the one in the branch net multiplied
            by the number of outputs.
    """
101
134
102
135
def __init__(
    self,
    layer_sizes_branch,
    layer_sizes_trunk,
    activation,
    kernel_initializer,
    num_outputs=1,
    multi_output_strategy=None,
):
    """Build the branch/trunk nets via the selected multi-output strategy.

    Args:
        layer_sizes_branch: Layer widths of the branch net, or `(dim, f)` with a
            user-defined network callable `f`.
        layer_sizes_trunk: Layer widths of the trunk net.
        activation: Activation name, or a dict with "branch" and "trunk" keys.
        kernel_initializer: Initializer name for the FNN weights.
        num_outputs: Number of output functions (default 1).
        multi_output_strategy: None, "independent", "split_both", "split_branch"
            or "split_trunk"; only meaningful when `num_outputs` > 1.

    Raises:
        ValueError: If `num_outputs` == 1 but a `multi_output_strategy` is given.
    """
    super().__init__()
    # A dict activation selects separate activations for branch and trunk nets.
    if isinstance(activation, dict):
        self.activation_branch = activation["branch"]
        self.activation_trunk = activations.get(activation["trunk"])
    else:
        self.activation_branch = self.activation_trunk = activations.get(activation)
    self.kernel_initializer = kernel_initializer

    self.num_outputs = num_outputs
    if self.num_outputs == 1:
        if multi_output_strategy is not None:
            raise ValueError(
                "num_outputs is set to 1, but multi_output_strategy is not None."
            )
    elif multi_output_strategy is None:
        # Multiple outputs but no strategy chosen: warn and default to "independent".
        multi_output_strategy = "independent"
        print(
            f"Warning: There are {num_outputs} outputs, but no multi_output_strategy selected. "
            'Use "independent" as the multi_output_strategy.'
        )
    # Map the strategy name to its implementation; each strategy holds a
    # back-reference to this network so it can call build_*_net below.
    self.multi_output_strategy = {
        None: SingleOutputStrategy,
        "independent": IndependentStrategy,
        "split_both": SplitBothStrategy,
        "split_branch": SplitBranchStrategy,
        "split_trunk": SplitTrunkStrategy,
    }[multi_output_strategy](self)

    self.branch, self.trunk = self.multi_output_strategy.build(
        layer_sizes_branch, layer_sizes_trunk
    )
    # Wrap lists of sub-networks so paddle tracks their parameters.
    if isinstance(self.branch, list):
        self.branch = paddle.nn.LayerList(self.branch)
    if isinstance(self.trunk, list):
        self.trunk = paddle.nn.LayerList(self.trunk)
    # One trainable scalar bias per output, registered for the optimizer.
    self.b = paddle.nn.ParameterList(
        [
            paddle.create_parameter(
                shape=[1],
                dtype=paddle.get_default_dtype(),
                default_initializer=paddle.nn.initializer.Constant(value=0),
            )
            for _ in range(self.num_outputs)
        ]
    )
190
def build_branch_net(self, layer_sizes_branch):
    """Return the branch net: a user-supplied callable or a fully connected net.

    Args:
        layer_sizes_branch: `(dim, f)` with a network callable `f`, or a list of
            layer widths for an FNN.
    """
    # User-defined network: `(dim, f)` where f is already a network callable.
    if callable(layer_sizes_branch[1]):
        return layer_sizes_branch[1]
    # Fully connected network built from the given layer widths.
    return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
196
+
197
def build_trunk_net(self, layer_sizes_trunk):
    """Return the trunk net as a fully connected network with the trunk activation."""
    return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
199
+
200
def merge_branch_trunk(self, x_func, x_loc, index):
    """Combine one branch output with one trunk output for output `index`.

    Args:
        x_func: Branch-net output; assumed (n_func, width) — confirm with caller.
        x_loc: Trunk-net output; assumed (n_loc, width) — confirm with caller.
        index: Which output's bias term to add.

    Returns:
        The Cartesian-product inner products, shape (n_func, n_loc), plus bias.
    """
    # Inner product over the shared feature dimension for every (func, loc) pair.
    y = x_func @ x_loc.T
    # Add the trainable scalar bias for this output.
    y += self.b[index]
    return y
204
+
205
@staticmethod
def concatenate_outputs(ys):
    """Stack the per-output tensors along a new trailing axis (axis=2)."""
    return paddle.stack(ys, axis=2)
128
208
129
209
def forward(self, inputs):
    """Forward pass.

    Args:
        inputs: A pair `(x_func, x_loc)` — branch-net input and trunk-net input.

    Returns:
        The network output, after the optional output transform.
    """
    x_func = inputs[0]
    x_loc = inputs[1]
    # Trunk net input transform
    if self._input_transform is not None:
        x_loc = self._input_transform(x_loc)
    # The strategy runs branch/trunk nets and merges their outputs.
    x = self.multi_output_strategy.call(x_func, x_loc)
    if self._output_transform is not None:
        x = self._output_transform(inputs, x)
    return x
0 commit comments