Skip to content

Commit eca0300

Browse files
committed
rename shift to shift_pos
1 parent c46fc44 commit eca0300

File tree

8 files changed

+46
-40
lines changed

8 files changed

+46
-40
lines changed

tests/functions_tests/test_triplets.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ def test_calc_triplets():
1313
dtype=torch.long,
1414
device=device,
1515
)
16-
shift = torch.zeros((edge_index.shape[1], 3), dtype=torch.float32, device=device)
17-
shift[:, 0] = torch.tensor(
16+
shift_pos = torch.zeros((edge_index.shape[1], 3), dtype=torch.float32, device=device)
17+
shift_pos[:, 0] = torch.tensor(
1818
[1, 2, 3, 4, 5, 6, -1, -2, -3, -4, -5, -6], dtype=torch.float32, device=device
1919
)
2020
# print("shift", shift.shape)
21-
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(edge_index, shift)
21+
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
22+
edge_index, shift_pos
23+
)
2224
# print("triplet_node_index", triplet_node_index.shape, triplet_node_index)
2325
# print("multiplicity", multiplicity.shape, multiplicity)
2426
# print("triplet_shift", triplet_shift.shape, triplet_shift)
@@ -44,9 +46,9 @@ def test_calc_triplets():
4446
# shift for edge `i->j`, `i->k`, `j->k`.
4547
triplet_shift = torch.stack(
4648
[
47-
-shift[edge_jk[:, 0]],
48-
-shift[edge_jk[:, 1]],
49-
shift[edge_jk[:, 0]] - shift[edge_jk[:, 1]],
49+
-shift_pos[edge_jk[:, 0]],
50+
-shift_pos[edge_jk[:, 1]],
51+
shift_pos[edge_jk[:, 0]] - shift_pos[edge_jk[:, 1]],
5052
],
5153
dim=1,
5254
)

torch_dftd/functions/dftd3.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def edisp(
117117
cnthr: Optional[float] = None,
118118
batch: Optional[Tensor] = None,
119119
batch_edge: Optional[Tensor] = None,
120-
shift: Optional[Tensor] = None,
120+
shift_pos: Optional[Tensor] = None,
121121
pos: Optional[Tensor] = None,
122122
cell: Optional[Tensor] = None,
123123
r2=None,
@@ -146,7 +146,7 @@ def edisp(
146146
cnthr (float or None): cutoff distance for coordination number calculation in **bohr**
147147
batch (Tensor or None): (n_atoms,)
148148
batch_edge (Tensor or None): (n_edges,)
149-
shift (Tensor or None): (n_atoms,) used to calculate 3-body term when abc=True
149+
shift_pos (Tensor or None): (n_atoms,) used to calculate 3-body term when abc=True
150150
pos (Tensor): (n_atoms, 3) position in **bohr**
151151
cell (Tensor): (3, 3) cell size in **bohr**
152152
r2 (Tensor or None):
@@ -267,7 +267,7 @@ def edisp(
267267
edge_index_abc = edge_index[:, within_cutoff]
268268
batch_edge_abc = None if batch_edge is None else batch_edge[within_cutoff]
269269
# c6_abc = c6[within_cutoff]
270-
shift_abc = None if shift is None else shift[within_cutoff]
270+
shift_abc = None if shift_pos is None else shift_pos[within_cutoff]
271271

272272
n_atoms = Z.shape[0]
273273
if not bidirectional:
@@ -290,7 +290,7 @@ def edisp(
290290
batch_triplets: Optional[Tensor]
291291
triplet_node_index, multiplicity, edge_jk, batch_triplets = calc_triplets(
292292
edge_index_abc,
293-
shift=shift_abc,
293+
shift_pos=shift_abc,
294294
dtype=pos.dtype,
295295
batch_edge=batch_edge_abc,
296296
)

torch_dftd/functions/distance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def calc_distances(
88
pos: Tensor,
99
edge_index: Tensor,
1010
cell: Optional[Tensor] = None,
11-
shift: Optional[Tensor] = None,
11+
shift_pos: Optional[Tensor] = None,
1212
eps=1e-20,
1313
) -> Tensor:
1414
"""Distance calculation function.
@@ -17,7 +17,7 @@ def calc_distances(
1717
pos (Tensor): (n_atoms, 3) atom positions.
1818
edge_index (Tensor): (2, n_edges) edge_index for graph.
1919
cell (Tensor): cell size, None for non periodic system.
20-
shift (Tensor): (n_edges, 3) position shift vectors of edges owing to the periodic boundary. It should be length unit.
20+
shift_pos (Tensor): (n_edges, 3) position shift vectors of edges owing to the periodic boundary. It should be length unit.
2121
eps (float): Small float value to avoid NaN in backward when the distance is 0.
2222
2323
Returns:
@@ -30,7 +30,7 @@ def calc_distances(
3030
Ri = pos[idx_i]
3131
Rj = pos[idx_j]
3232
if cell is not None:
33-
Rj += shift
33+
Rj += shift_pos
3434
# eps is to avoid Nan in backward when Dij = 0 with sqrt.
3535
Dij = torch.sqrt(torch.sum((Ri - Rj) ** 2, dim=-1) + eps)
3636
return Dij

torch_dftd/functions/triplets.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77

88
def calc_triplets(
99
edge_index: Tensor,
10-
shift: Optional[Tensor] = None,
10+
shift_pos: Optional[Tensor] = None,
1111
dtype=torch.float32,
1212
batch_edge: Optional[Tensor] = None,
1313
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1414
"""Calculate triplet edge index.
1515
1616
Args:
1717
edge_index (Tensor): (2, n_edges) edge_index for graph. It must be bidirectional edge.
18-
shift (Tensor or None): (n_edges, 3) used to calculate unique atoms when pbc=True.
18+
shift_pos (Tensor or None): (n_edges, 3) used to calculate unique atoms when pbc=True.
1919
dtype: dtype for `multiplicity`
2020
batch_edge (Tensor or None): Specify batch indices for `edge_index`.
2121
@@ -37,10 +37,10 @@ def calc_triplets(
3737
src = src[sort_inds]
3838
dst = dst[sort_inds]
3939

40-
if shift is None:
40+
if shift_pos is None:
4141
edge_indices = torch.arange(src.shape[0], dtype=torch.long, device=edge_index.device)
4242
else:
43-
edge_indices = torch.arange(shift.shape[0], dtype=torch.long, device=edge_index.device)
43+
edge_indices = torch.arange(shift_pos.shape[0], dtype=torch.long, device=edge_index.device)
4444
edge_indices = edge_indices[is_larger][sort_inds]
4545

4646
if batch_edge is None:

torch_dftd/nn/base_dftd_module.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def calc_energy_batch(
1717
edge_index: Tensor,
1818
cell: Optional[Tensor] = None,
1919
pbc: Optional[Tensor] = None,
20-
shift: Optional[Tensor] = None,
20+
shift_pos: Optional[Tensor] = None,
2121
batch: Optional[Tensor] = None,
2222
batch_edge: Optional[Tensor] = None,
2323
damping: str = "zero",
@@ -32,7 +32,7 @@ def calc_energy_batch(
3232
edge_index (Tensor): (2, n_edges) edge index within cutoff
3333
cell (Tensor): (n_atoms, 3) cell size in angstrom, None for non periodic system.
3434
pbc (Tensor): (bs, 3) pbc condition, None for non periodic system.
35-
shift (Tensor): (n_atoms, 3) shift vector
35+
shift_pos (Tensor): (n_atoms, 3) shift vector (length unit).
3636
batch (Tensor): (n_atoms,) Specify which graph this atom belongs to
3737
batch_edge (Tensor): (n_edges, 3) Specify which graph this edge belongs to
3838
damping (str):
@@ -49,7 +49,7 @@ def calc_energy(
4949
edge_index: Tensor,
5050
cell: Optional[Tensor] = None,
5151
pbc: Optional[Tensor] = None,
52-
shift: Optional[Tensor] = None,
52+
shift_pos: Optional[Tensor] = None,
5353
batch: Optional[Tensor] = None,
5454
batch_edge: Optional[Tensor] = None,
5555
damping: str = "zero",
@@ -64,6 +64,7 @@ def calc_energy(
6464
edge_index (Tensor):
6565
cell (Tensor): cell size in angstrom, None for non periodic system.
6666
pbc (Tensor): pbc condition, None for non periodic system.
67+
shift_pos (Tensor): (n_atoms, 3) shift vector (length unit).
6768
batch (Tensor):
6869
batch_edge (Tensor):
6970
damping (str): damping method. "zero", "bj", "zerom", "bjm"
@@ -73,7 +74,7 @@ def calc_energy(
7374
"""
7475
with torch.no_grad():
7576
E_disp = self.calc_energy_batch(
76-
Z, pos, edge_index, cell, pbc, shift, batch, batch_edge, damping=damping
77+
Z, pos, edge_index, cell, pbc, shift_pos, batch, batch_edge, damping=damping
7778
)
7879
if batch is None:
7980
return [{"energy": E_disp.item()}]
@@ -91,7 +92,7 @@ def calc_energy_and_forces(
9192
edge_index: Tensor,
9293
cell: Optional[Tensor] = None,
9394
pbc: Optional[Tensor] = None,
94-
shift: Optional[Tensor] = None,
95+
shift_pos: Optional[Tensor] = None,
9596
batch: Optional[Tensor] = None,
9697
batch_edge: Optional[Tensor] = None,
9798
damping: str = "zero",
@@ -103,6 +104,7 @@ def calc_energy_and_forces(
103104
pos (Tensor): atom positions in angstrom
104105
cell (Tensor): cell size in angstrom, None for non periodic system.
105106
pbc (Tensor): pbc condition, None for non periodic system.
107+
shift_pos (Tensor): (n_atoms, 3) shift vector (length unit).
106108
damping (str): damping method. "zero", "bj", "zerom", "bjm"
107109
108110
Returns:
@@ -117,11 +119,11 @@ def calc_energy_and_forces(
117119
# We need to explicitly include this dependency to calculate cell gradient
118120
# for stress computation.
119121
# pos is assumed to be inside "cell", so relative position `rel_pos` lies between 0~1.
120-
assert isinstance(shift, Tensor)
121-
shift.requires_grad_(True)
122+
assert isinstance(shift_pos, Tensor)
123+
shift_pos.requires_grad_(True)
122124

123125
E_disp = self.calc_energy_batch(
124-
Z, pos, edge_index, cell, pbc, shift, batch, batch_edge, damping=damping
126+
Z, pos, edge_index, cell, pbc, shift_pos, batch, batch_edge, damping=damping
125127
)
126128

127129
E_disp.sum().backward()
@@ -140,7 +142,7 @@ def calc_energy_and_forces(
140142
if cell is not None:
141143
# stress = torch.mm(cell_grad, cell.T) / cell_volume
142144
# Get stress in Voigt notation (xx, yy, zz, yz, xz, xy)
143-
assert isinstance(shift, Tensor)
145+
assert isinstance(shift_pos, Tensor)
144146
voigt_left = [0, 1, 2, 1, 2, 0]
145147
voigt_right = [0, 1, 2, 2, 0, 1]
146148
if batch is None:
@@ -149,7 +151,8 @@ def calc_energy_and_forces(
149151
(pos[:, voigt_left] * pos.grad[:, voigt_right]).to(torch.float64), dim=0
150152
)
151153
cell_grad += torch.sum(
152-
(shift[:, voigt_left] * shift.grad[:, voigt_right]).to(torch.float64), dim=0
154+
(shift_pos[:, voigt_left] * shift_pos.grad[:, voigt_right]).to(torch.float64),
155+
dim=0,
153156
)
154157
stress = cell_grad.to(cell.dtype) / cell_volume
155158
results_list[0]["stress"] = stress.detach().cpu().numpy()
@@ -166,7 +169,7 @@ def calc_energy_and_forces(
166169
cell_grad.scatter_add_(
167170
0,
168171
batch_edge.view(batch_edge.size()[0], 1).expand(batch_edge.size()[0], 6),
169-
(shift[:, voigt_left] * shift.grad[:, voigt_right]).to(torch.float64),
172+
(shift_pos[:, voigt_left] * shift_pos.grad[:, voigt_right]).to(torch.float64),
170173
)
171174
stress = cell_grad.to(cell.dtype) / cell_volume[:, None]
172175
stress = stress.detach().cpu().numpy()

torch_dftd/nn/dftd2_module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,19 @@ def calc_energy_batch(
4848
edge_index: Tensor,
4949
cell: Optional[Tensor] = None,
5050
pbc: Optional[Tensor] = None,
51-
shift: Optional[Tensor] = None,
51+
shift_pos: Optional[Tensor] = None,
5252
batch: Optional[Tensor] = None,
5353
batch_edge: Optional[Tensor] = None,
5454
damping: str = "zero",
5555
) -> Tensor:
5656
"""Forward computation to calculate atomic wise dispersion energy"""
57-
shift = pos.new_zeros((edge_index.size()[1], 3, 3)) if shift is None else shift
57+
shift_pos = pos.new_zeros((edge_index.size()[1], 3, 3)) if shift_pos is None else shift_pos
5858
pos_bohr = pos / d3_autoang # angstrom -> bohr
5959
if cell is None:
6060
cell_bohr: Optional[Tensor] = None
6161
else:
6262
cell_bohr = cell / d3_autoang # angstrom -> bohr
63-
shift_bohr = shift / d3_autoang # angstrom -> bohr
63+
shift_bohr = shift_pos / d3_autoang # angstrom -> bohr
6464
r = calc_distances(pos_bohr, edge_index, cell_bohr, shift_bohr)
6565

6666
# E_disp (n_graphs,): Energy in eV unit

torch_dftd/nn/dftd3_module.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,19 @@ def calc_energy_batch(
7070
edge_index: Tensor,
7171
cell: Optional[Tensor] = None,
7272
pbc: Optional[Tensor] = None,
73-
shift: Optional[Tensor] = None,
73+
shift_pos: Optional[Tensor] = None,
7474
batch: Optional[Tensor] = None,
7575
batch_edge: Optional[Tensor] = None,
7676
damping: str = "zero",
7777
) -> Tensor:
7878
"""Forward computation to calculate atomic wise dispersion energy"""
79-
shift = pos.new_zeros((edge_index.size()[1], 3, 3)) if shift is None else shift
79+
shift_pos = pos.new_zeros((edge_index.size()[1], 3, 3)) if shift_pos is None else shift_pos
8080
pos_bohr = pos / d3_autoang # angstrom -> bohr
8181
if cell is None:
8282
cell_bohr: Optional[Tensor] = None
8383
else:
8484
cell_bohr = cell / d3_autoang # angstrom -> bohr
85-
shift_bohr = shift / d3_autoang # angstrom -> bohr
85+
shift_bohr = shift_pos / d3_autoang # angstrom -> bohr
8686
r = calc_distances(pos_bohr, edge_index, cell_bohr, shift_bohr)
8787
# E_disp (n_graphs,): Energy in eV unit
8888
E_disp = d3_autoev * edisp(
@@ -98,7 +98,7 @@ def calc_energy_batch(
9898
cnthr=self.cnthr / Bohr,
9999
batch=batch,
100100
batch_edge=batch_edge,
101-
shift=shift_bohr,
101+
shift_pos=shift_bohr,
102102
damping=damping,
103103
cutoff_smoothing=self.cutoff_smoothing,
104104
bidirectional=self.bidirectional,

torch_dftd/torch_dftd3_calculator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@ def _preprocess_atoms(self, atoms: Atoms) -> Dict[str, Optional[Tensor]]:
103103
pbc = torch.tensor(atoms.pbc, device=self.device)
104104
edge_index, S = self._calc_edge_index(pos, cell, pbc)
105105
if cell is None:
106-
shift = S
106+
shift_pos = S
107107
else:
108-
shift = torch.mm(S, cell.detach())
109-
input_dicts = dict(pos=pos, Z=Z, cell=cell, pbc=pbc, edge_index=edge_index, shift=shift)
108+
shift_pos = torch.mm(S, cell.detach())
109+
input_dicts = dict(
110+
pos=pos, Z=Z, cell=cell, pbc=pbc, edge_index=edge_index, shift_pos=shift_pos
111+
)
110112
return input_dicts
111113

112114
def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
@@ -155,7 +157,6 @@ def batch_calculate(self, atoms_list=None, properties=["energy"], system_changes
155157
# Calculator.calculate(self, atoms, properties, system_changes)
156158
input_dicts_list = [self._preprocess_atoms(atoms) for atoms in atoms_list]
157159
# --- Make batch ---
158-
# pos=pos, Z=Z, cell=cell, pbc=pbc, edge_index=edge_index, shift=S
159160
n_nodes_list = [d["Z"].shape[0] for d in input_dicts_list]
160161
shift_index_array = torch.cumsum(torch.tensor([0] + n_nodes_list), dim=0)
161162
cell_batch = torch.stack(
@@ -171,7 +172,7 @@ def batch_calculate(self, atoms_list=None, properties=["energy"], system_changes
171172
pos=torch.cat([d["pos"] for d in input_dicts_list], dim=0), # (n_nodes,)
172173
cell=cell_batch, # (bs, 3, 3)
173174
pbc=torch.stack([d["pbc"] for d in input_dicts_list]), # (bs, 3)
174-
shift=torch.cat([d["shift"] for d in input_dicts_list], dim=0), # (n_nodes,)
175+
shift_pos=torch.cat([d["shift_pos"] for d in input_dicts_list], dim=0), # (n_nodes,)
175176
)
176177

177178
batch_dicts["edge_index"] = torch.cat(

0 commit comments

Comments
 (0)