Commit 5a3f6dc

Merge pull request #166 from luigibonati/fix_scatter

Remove torch-scatter dependency (vendor only the scatter_sum function the code needs)

2 parents 00d639b + 90c6fa5

File tree

5 files changed (+36, -7 lines)
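Editorial note, not part of the commit: the PR drops the external torch-scatter package and instead vendors a pure-PyTorch scatter_sum helper into committor_loss.py. A minimal before/after sketch of the call-site change, with made-up tensors:

    import torch
    from mlcolvar.core.loss.committor_loss import scatter_sum  # helper added by this PR

    src = torch.rand(8)                # made-up values
    index = torch.randint(0, 3, (8,))  # made-up group indices

    # before: out = scatter(src, index, reduce='sum')   (from torch_scatter)
    # after, same result with plain torch:
    out = scatter_sum(src, index)      # shape (index.max() + 1,)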

devtools/conda-envs/test_env.yaml

Lines changed: 0 additions & 1 deletion
@@ -29,5 +29,4 @@ dependencies:
   - pip:
     - KDEpy
     - nbmake
-    - torch-scatter
 

docs/notebooks/examples/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+*.pt

docs/requirements.yaml

Lines changed: 0 additions & 1 deletion
@@ -26,7 +26,6 @@ dependencies:
   - ipykernel
   - scikit-learn
   - scipy
-  - torch-scatter
 
   # Pip-only installs
   - pip:

mlcolvar/core/loss/committor_loss.py

Lines changed: 34 additions & 3 deletions
@@ -15,7 +15,7 @@
 # =============================================================================
 
 import torch
-from torch_scatter import scatter
+from typing import Optional
 
 # =============================================================================
 # LOSS FUNCTIONS
@@ -82,6 +82,37 @@ def forward(
             descriptors_derivatives=self.descriptors_derivatives
         )
 
+def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+    """Broadcast util, from torch_scatter"""
+    if dim < 0:
+        dim = other.dim() + dim
+    if src.dim() == 1:
+        for _ in range(0, dim):
+            src = src.unsqueeze(0)
+    for _ in range(src.dim(), other.dim()):
+        src = src.unsqueeze(-1)
+    src = src.expand(other.size())
+    return src
+
+def scatter_sum(src: torch.Tensor,
+                index: torch.Tensor,
+                dim: int = -1,
+                out: Optional[torch.Tensor] = None,
+                dim_size: Optional[int] = None) -> torch.Tensor:
+    """Scatter sum function, from the torch_scatter module (https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py)"""
+    index = broadcast(index, src, dim)
+    if out is None:
+        size = list(src.size())
+        if dim_size is not None:
+            size[dim] = dim_size
+        elif index.numel() == 0:
+            size[dim] = 0
+        else:
+            size[dim] = int(index.max()) + 1
+        out = torch.zeros(size, dtype=src.dtype, device=src.device)
+        return out.scatter_add_(dim, index, src)
+    else:
+        return out.scatter_add_(dim, index, src)
 
 def committor_loss(x: torch.Tensor,
                    q: torch.Tensor,
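A quick illustration of what the vendored helper computes (an editorial sketch with made-up tensors, not part of the diff): entries of src that share an index are summed into the same output slot, i.e. plain torch scatter_add_ with an automatically sized output.

    import torch
    from mlcolvar.core.loss.committor_loss import scatter_sum

    src = torch.tensor([1., 2., 3., 4.])
    index = torch.tensor([0, 0, 1, 1])
    scatter_sum(src, index)              # tensor([3., 7.])
    scatter_sum(src, index, dim_size=3)  # tensor([3., 7., 0.]), padded to dim_size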
@@ -291,7 +322,7 @@ def _get_scatter_indices(self, batch_ind, atom_ind, dim_ind):
 
         # get the number of elements in each batch
         # e.g. [17, 18, 18, 18]
-        batch_elements = scatter(torch.ones_like(batch_ind), batch_ind, reduce='sum')
+        batch_elements = scatter_sum(torch.ones_like(batch_ind), batch_ind)
         batch_elements[0] -= 1 # to make the later indexing consistent
 
         # compute the pointer idxs to the beginning of each batch by summing the number of elements in each batch
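The call above, summing a tensor of ones over the batch indices, is just a per-batch element count; a toy example with invented indices (not from the repository):

    import torch
    from mlcolvar.core.loss.committor_loss import scatter_sum

    batch_ind = torch.tensor([0, 0, 0, 1, 1, 2])
    scatter_sum(torch.ones_like(batch_ind), batch_ind)  # tensor([3, 2, 1])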
@@ -344,7 +375,7 @@ def _compute_square_modulus(self, x : torch.Tensor, indeces : torch.Tensor, n_at
         indeces = indeces.long().to(x.device)
 
         # this sums the elements of x according to the indeces, this way we get the contributions of different descriptors to the same atom
-        out = scatter(x, indeces.long())
+        out = scatter_sum(x, indeces.long())
         # now make the square
         out = out.pow(2)
         # reshape, this needs to have the correct number of atoms as we need to multiply it by the mass vector later
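In the hunk above, scatter_sum groups the contributions of different descriptors to the same atom before squaring; a toy example with invented values and an invented descriptor-to-atom mapping:

    import torch
    from mlcolvar.core.loss.committor_loss import scatter_sum

    x = torch.tensor([0.5, 1.5, 2.0])     # invented per-descriptor contributions
    indeces = torch.tensor([0, 0, 1])     # invented descriptor-to-atom indices
    out = scatter_sum(x, indeces.long())  # tensor([2., 2.]), summed per atom
    out = out.pow(2)                      # tensor([4., 4.])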

requirements.txt

Lines changed: 1 addition & 2 deletions
@@ -3,5 +3,4 @@ torch
 numpy<2
 pandas
 matplotlib
-kdepy
-torch-scatter
+kdepy

0 commit comments
