Skip to content

Commit

Permalink
modify bc loss code
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry-Jzy committed Dec 18, 2024
1 parent 85c8d2e commit 5ef4f4f
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions deepxde/data/pde_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,25 +277,30 @@ def forward_call(trunk_input):
error_f = [fi[:, bcs_start[-1] :] for fi in f]
# Each error has the shape (N1, ~N2)
losses = [loss_fn(bkd.zeros_like(error), error) for error in error_f]

# BC loss
for j, bc in enumerate(self.pde.bcs):
beg, end = bcs_start[j], bcs_start[j + 1]
error = []
for i in range(num_func):
out = outputs[i]
if bkd.ndim(out) == 1:
out = out[:, None]
error_i = bc.error(
error_bc = []
for i in range(num_func):
error_i = []
out = outputs[i]
if bkd.ndim(out) == 1:
out = out[:, None]
for j, bc in enumerate(self.pde.bcs):
beg, end = bcs_start[j], bcs_start[j + 1]
error = bc.error(
self.train_x[1],
inputs[1],
out,
beg,
end,
aux_var=model.net.auxiliary_vars[i][:, None],
)
error.append(loss_fn(bkd.zeros_like(error_i), error_i))
losses.append(bkd.reduce_mean(bkd.stack(error, 0)))
error_i.append(loss_fn(bkd.zeros_like(error), error))

error_bc.append(error_i)

error_bc = zip(*error_bc)
error_bc = [bkd.reduce_mean(bkd.stack(error, 0)) for error in error_bc]
losses.append(error_bc)
return losses

def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
Expand Down

0 comments on commit 5ef4f4f

Please sign in to comment.