Skip to content

Commit

Permalink
Reformatting MACE (ORNL#307)
Browse files Browse the repository at this point in the history
* reformatting MACE

* move irreps_string tool

* Comment

* adjusting position accounting for batch
  • Loading branch information
RylieWeaver committed Nov 12, 2024
1 parent 8a66236 commit ab04752
Show file tree
Hide file tree
Showing 4 changed files with 614 additions and 457 deletions.
5 changes: 5 additions & 0 deletions hydragnn/models/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import sys
from hydragnn.utils.distributed import get_device
from hydragnn.utils.print.print_utils import print_master
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths

import inspect

Expand Down Expand Up @@ -136,6 +137,10 @@ def _init_conv(self):
self.feature_layers.append(BatchNorm(self.hidden_dim))

def _embedding(self, data):
if not hasattr(data, "edge_shifts"):
data.edge_shifts = torch.zeros(
(data.edge_index.size(1), 3), device=data.edge_index.device
)
conv_args = {"edge_index": data.edge_index.to(torch.long)}
if self.use_edge_attr:
assert (
Expand Down
Loading

0 comments on commit ab04752

Please sign in to comment.