diff --git a/deepxde/gradients/gradients.py b/deepxde/gradients/gradients.py index 742796b50..b819405a5 100644 --- a/deepxde/gradients/gradients.py +++ b/deepxde/gradients/gradients.py @@ -19,7 +19,7 @@ def jacobian(ys, xs, i=None, j=None): computation. Args: - ys: Output Tensor of shape (batch_size, dim_y). + ys: Output Tensor of shape (batch_size, dim_y) or (batch_size1, batch_size2, dim_y). xs: Input Tensor of shape (batch_size, dim_x). i (int or None): `i`th row. If `i` is ``None``, returns the `j`th column J[:, `j`]. diff --git a/deepxde/gradients/gradients_forward.py b/deepxde/gradients/gradients_forward.py index b58aa7621..8fc385912 100644 --- a/deepxde/gradients/gradients_forward.py +++ b/deepxde/gradients/gradients_forward.py @@ -87,14 +87,14 @@ def grad_fn(x): # Compute J[i, j] if (i, j) not in self.J: if backend_name == "tensorflow.compat.v1": - self.J[i, j] = self.J[j][:, i : i + 1] + self.J[i, j] = self.J[j][..., i : i + 1] elif backend_name in ["tensorflow", "pytorch", "jax"]: # In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array # and a callable is returned, so that it is consistent with the argument, # which is also a tuple. This is useful for further computation, e.g., # Hessian. self.J[i, j] = ( - self.J[j][0][:, i : i + 1], + self.J[j][0][..., i : i + 1], lambda x: self.J[j][1](x)[i : i + 1], ) return self.J[i, j] diff --git a/deepxde/gradients/jacobian.py b/deepxde/gradients/jacobian.py index 5b0af016d..fd3004abe 100644 --- a/deepxde/gradients/jacobian.py +++ b/deepxde/gradients/jacobian.py @@ -20,22 +20,22 @@ def __init__(self, ys, xs): self.xs = xs if backend_name in ["tensorflow.compat.v1", "paddle"]: - self.dim_y = ys.shape[1] + self.dim_y = ys.shape[-1] elif backend_name in ["tensorflow", "pytorch"]: if config.autodiff == "reverse": # For reverse-mode AD, only a tensor is passed. - self.dim_y = ys.shape[1] + self.dim_y = ys.shape[-1] elif config.autodiff == "forward": # For forward-mode AD, a tuple of a tensor and a callable is passed, # similar to backend jax. - self.dim_y = ys[0].shape[1] + self.dim_y = ys[0].shape[-1] elif backend_name == "jax": # For backend jax, a tuple of a jax array and a callable is passed as one of # the arguments, since jax does not support computational graph explicitly. # The array is used to control the dimensions and the callable is used to # obtain the derivative function, which can be used to compute the # derivatives. - self.dim_y = ys[0].shape[1] + self.dim_y = ys[0].shape[-1] self.dim_x = xs.shape[1] self.J = {}