 # =============================================================================

 import torch
-from torch_scatter import scatter
+from typing import Optional

 # =============================================================================
 # LOSS FUNCTIONS
@@ -82,6 +82,37 @@ def forward(
             descriptors_derivatives=self.descriptors_derivatives
         )

+def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+    """Broadcast util, from torch_scatter."""
+    if dim < 0:
+        dim = other.dim() + dim
+    if src.dim() == 1:
+        for _ in range(0, dim):
+            src = src.unsqueeze(0)
+    for _ in range(src.dim(), other.dim()):
+        src = src.unsqueeze(-1)
+    src = src.expand(other.size())
+    return src
+
+def scatter_sum(src: torch.Tensor,
+                index: torch.Tensor,
+                dim: int = -1,
+                out: Optional[torch.Tensor] = None,
+                dim_size: Optional[int] = None) -> torch.Tensor:
+    """Scatter sum function, from the torch_scatter module
+    (https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py)."""
+    index = broadcast(index, src, dim)
+    if out is None:
+        size = list(src.size())
+        if dim_size is not None:
+            size[dim] = dim_size
+        elif index.numel() == 0:
+            size[dim] = 0
+        else:
+            size[dim] = int(index.max()) + 1
+        out = torch.zeros(size, dtype=src.dtype, device=src.device)
+    return out.scatter_add_(dim, index, src)

 def committor_loss(x: torch.Tensor,
                    q: torch.Tensor,
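
For context, a minimal sketch of how the vendored `scatter_sum` behaves once this patch is applied; for the default arguments it matches the `torch_scatter.scatter(..., reduce='sum')` call it replaces (the tensors below are illustrative):

```python
import torch

# six values grouped into three segments by index
src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 0, 1, 1, 1, 2])

# entries sharing an index are accumulated along the last dim:
# out[0] = 1 + 2, out[1] = 3 + 4 + 5, out[2] = 6
out = scatter_sum(src, index)              # tensor([ 3., 12.,  6.])

# dim_size pads the result with empty (zero) segments
out = scatter_sum(src, index, dim_size=5)  # tensor([ 3., 12.,  6.,  0.,  0.])
```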
@@ -291,7 +322,7 @@ def _get_scatter_indices(self, batch_ind, atom_ind, dim_ind):

         # get the number of elements in each batch
         # e.g. [17, 18, 18, 18]
-        batch_elements = scatter(torch.ones_like(batch_ind), batch_ind, reduce='sum')
+        batch_elements = scatter_sum(torch.ones_like(batch_ind), batch_ind)
         batch_elements[0] -= 1  # to make the later indexing consistent

         # compute the pointer idxs to the beginning of each batch by summing the number of elements in each batch
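
To make the counting step concrete, a small illustrative sketch (the tensors are made up, and the `pointers` line anticipates the cumulative sum that the comment above refers to):

```python
import torch

# batch label of every element: three in batch 0, two in batch 1, one in batch 2
batch_ind = torch.tensor([0, 0, 0, 1, 1, 2])

# scatter-summing a tensor of ones counts the elements per batch
batch_elements = scatter_sum(torch.ones_like(batch_ind), batch_ind)
print(batch_elements)  # tensor([3, 2, 1])

# subtracting 1 from the first count shifts every cumulative value down by one,
# turning the cumulative sums into zero-based element indices
batch_elements[0] -= 1
pointers = torch.cumsum(batch_elements, dim=0)
print(pointers)  # tensor([2, 4, 5])
```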
@@ -344,7 +375,7 @@ def _compute_square_modulus(self, x : torch.Tensor, indeces : torch.Tensor, n_at |
         indeces = indeces.long().to(x.device)

         # this sums the elements of x according to the indeces; this way we get the contributions of different descriptors to the same atom
-        out = scatter(x, indeces.long())
+        out = scatter_sum(x, indeces.long())
         # now square the result
         out = out.pow(2)
         # reshape; this needs to have the correct number of atoms, as we need to multiply it by the mass vector later
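
Finally, a minimal sketch of the accumulate-then-square pattern in this last hunk (shapes and values are illustrative):

```python
import torch

# contributions of five descriptors, mapped onto two atoms
x = torch.tensor([0.5, 1.5, -1.0, 2.0, 1.0])
indeces = torch.tensor([0, 0, 0, 1, 1])

# sum the contributions landing on the same atom: tensor([1., 3.])
out = scatter_sum(x, indeces)

# square the per-atom totals: tensor([1., 9.])
out = out.pow(2)
```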