-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ab8891f
commit 0fafea7
Showing
25 changed files
with
2,442 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
########################################################################################### | ||
# Elementary tools for handling irreducible representations | ||
# Authors: Ilyes Batatia, Gregor Simm | ||
# This program is distributed under the MIT License (see MIT.md) | ||
########################################################################################### | ||
|
||
from typing import List, Tuple | ||
|
||
import torch | ||
from e3nn import o3 | ||
from e3nn.util.jit import compile_mode | ||
|
||
|
||
# Based on mir-group/nequip | ||
def tp_out_irreps_with_instructions( | ||
irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps | ||
) -> Tuple[o3.Irreps, List]: | ||
trainable = True | ||
|
||
# Collect possible irreps and their instructions | ||
irreps_out_list: List[Tuple[int, o3.Irreps]] = [] | ||
instructions = [] | ||
for i, (mul, ir_in) in enumerate(irreps1): | ||
for j, (_, ir_edge) in enumerate(irreps2): | ||
for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 | ||
if ir_out in target_irreps: | ||
k = len(irreps_out_list) # instruction index | ||
irreps_out_list.append((mul, ir_out)) | ||
instructions.append((i, j, k, "uvu", trainable)) | ||
|
||
# We sort the output irreps of the tensor product so that we can simplify them | ||
# when they are provided to the second o3.Linear | ||
irreps_out = o3.Irreps(irreps_out_list) | ||
irreps_out, permut, _ = irreps_out.sort() | ||
|
||
# Permute the output indexes of the instructions to match the sorted irreps: | ||
instructions = [ | ||
(i_in1, i_in2, permut[i_out], mode, train) | ||
for i_in1, i_in2, i_out, mode, train in instructions | ||
] | ||
|
||
instructions = sorted(instructions, key=lambda x: x[2]) | ||
|
||
return irreps_out, instructions | ||
|
||
|
||
def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: | ||
# Assuming simplified irreps | ||
irreps_mid = [] | ||
for _, ir_in in irreps: | ||
found = False | ||
|
||
for mul, ir_out in target_irreps: | ||
if ir_in == ir_out: | ||
irreps_mid.append((mul, ir_out)) | ||
found = True | ||
break | ||
|
||
if not found: | ||
raise RuntimeError(f"{ir_in} not in {target_irreps}") | ||
|
||
return o3.Irreps(irreps_mid) | ||
|
||
|
||
@compile_mode("script") | ||
class reshape_irreps(torch.nn.Module): | ||
def __init__(self, irreps: o3.Irreps) -> None: | ||
super().__init__() | ||
self.irreps = o3.Irreps(irreps) | ||
self.dims = [] | ||
self.muls = [] | ||
for mul, ir in self.irreps: | ||
d = ir.dim | ||
self.dims.append(d) | ||
self.muls.append(mul) | ||
|
||
def forward(self, tensor: torch.Tensor) -> torch.Tensor: | ||
ix = 0 | ||
out = [] | ||
batch, _ = tensor.shape | ||
for mul, d in zip(self.muls, self.dims): | ||
field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] | ||
ix += mul * d | ||
field = field.reshape(batch, mul, d) | ||
out.append(field) | ||
return torch.cat(out, dim=-1) | ||
|
||
|
||
def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): | ||
out = [] | ||
for i in range(num_layers - 1): | ||
out.append( | ||
x[ | ||
:, | ||
i | ||
* (l_max + 1) ** 2 | ||
* num_features : (i * (l_max + 1) ** 2 + 1) | ||
* num_features, | ||
] | ||
) | ||
out.append(x[:, -num_features:]) | ||
return torch.cat(out, dim=-1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
########################################################################################### | ||
# __init__ file for Modules | ||
# Authors: Ilyes Batatia, Gregor Simm | ||
# This program is distributed under the MIT License (see MIT.md) | ||
########################################################################################### | ||
# Taken From: | ||
# GitHub: https://github.com/ACEsuit/mace | ||
# ArXiV: https://arxiv.org/pdf/2206.07697 | ||
# Date: August 27, 2024 | 12:37 (EST) | ||
########################################################################################### | ||
|
||
from typing import Callable, Dict, Optional, Type | ||
|
||
import torch | ||
|
||
from .blocks import ( | ||
AtomicEnergiesBlock, | ||
EquivariantProductBasisBlock, | ||
InteractionBlock, | ||
LinearNodeEmbeddingBlock, | ||
LinearReadoutBlock, | ||
NonLinearReadoutBlock, | ||
RadialEmbeddingBlock, | ||
RealAgnosticAttResidualInteractionBlock, | ||
ScaleShiftBlock, | ||
) | ||
|
||
from .radial import BesselBasis, GaussianBasis, PolynomialCutoff | ||
from .symmetric_contraction import SymmetricContraction | ||
|
||
interaction_classes: Dict[str, Type[InteractionBlock]] = { | ||
"RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, | ||
} | ||
|
||
__all__ = [ | ||
"AtomicEnergiesBlock", | ||
"RadialEmbeddingBlock", | ||
"LinearNodeEmbeddingBlock", | ||
"LinearReadoutBlock", | ||
"EquivariantProductBasisBlock", | ||
"ScaleShiftBlock", | ||
"LinearDipoleReadoutBlock", | ||
"NonLinearDipoleReadoutBlock", | ||
"InteractionBlock", | ||
"NonLinearReadoutBlock", | ||
"PolynomialCutoff", | ||
"BesselBasis", | ||
"GaussianBasis", | ||
"SymmetricContraction", | ||
"interaction_classes", | ||
] |
Oops, something went wrong.