MACE #287

Merged: 54 commits, Oct 15, 2024

Changes from all commits (54 commits)
01b4ba5
First putting in MACE
RylieWeaver Aug 28, 2024
8ab9eb2
MACE Rebased
RylieWeaver Sep 11, 2024
d1fc744
test first MACE push
RylieWeaver Sep 25, 2024
eb46b76
formatting and clean-up
RylieWeaver Sep 25, 2024
e5ba82f
revise library downloads
RylieWeaver Sep 25, 2024
1c34bed
formatting and typo
RylieWeaver Sep 25, 2024
258bf78
installation change
RylieWeaver Sep 25, 2024
2396291
change versioning of torchmetrics for github platform. GitHub cannot …
RylieWeaver Sep 25, 2024
de269a1
GitHub only has access to some torchmetrics version
RylieWeaver Sep 25, 2024
3de2006
testing for issue with index url and torch-ema/torchmetrics
RylieWeaver Sep 25, 2024
a76fd4f
fix installs to account for index-url and formatting
RylieWeaver Sep 25, 2024
e9ed32c
formatting
RylieWeaver Sep 25, 2024
a9d06a2
try/except import to remove dependency
RylieWeaver Sep 25, 2024
5a2da6f
formatting
RylieWeaver Sep 25, 2024
4127d08
formatting
RylieWeaver Sep 25, 2024
f6c32cf
Add MACE to test
RylieWeaver Sep 25, 2024
ebae4d0
testing
RylieWeaver Sep 25, 2024
1039592
revert separate attempt to get test_forces in
RylieWeaver Sep 25, 2024
06c539a
revert separate attempt to get test_forces in
RylieWeaver Sep 25, 2024
d037fd3
commenting things that aren't needed for MACE in HYDRA (draft 1)
RylieWeaver Sep 26, 2024
61730a6
Making tests run faster and adjusting requirements
RylieWeaver Sep 26, 2024
d7652e3
need to install new requirements file
RylieWeaver Sep 26, 2024
e7a5d2f
formatting
RylieWeaver Sep 26, 2024
076aafa
formatting
RylieWeaver Sep 26, 2024
a12ce41
debugging for GitHub test
RylieWeaver Sep 26, 2024
f32d040
formatting
RylieWeaver Sep 26, 2024
3116416
Update config to avoid key error on input_dim
RylieWeaver Sep 26, 2024
dbc6ec9
Removing Unnecessary MACE files (draft 1)
RylieWeaver Sep 26, 2024
5bb5cb6
Commenting MACE utils in torch geometric (draft 2)
RylieWeaver Sep 27, 2024
f5148b8
formatting
RylieWeaver Sep 27, 2024
7f8129d
delete comments
RylieWeaver Sep 27, 2024
b2a9598
distributed, parsing, and checkpointing utils taken out (draft 3)
RylieWeaver Sep 27, 2024
6c89dd6
mace_utils comments (draft 4)
RylieWeaver Sep 27, 2024
af49a73
taking more mace utils out
RylieWeaver Sep 27, 2024
d4bef5e
rebase MACE and make the tests run a little faster
RylieWeaver Sep 27, 2024
a6ae103
torch scatter update
RylieWeaver Sep 30, 2024
b1e146d
Move around utils
RylieWeaver Sep 30, 2024
5ad74f3
clean up imports and move files
RylieWeaver Sep 30, 2024
e6df82f
formatting
RylieWeaver Sep 30, 2024
32fe8e5
formatting
RylieWeaver Sep 30, 2024
17f283b
Add source information
RylieWeaver Sep 30, 2024
92ddccd
Add checking and processing for node attributes
RylieWeaver Oct 1, 2024
173ae67
MACE natively only handles atomic numbers as node_attributes. Add warn…
RylieWeaver Oct 1, 2024
6d53ff6
adjust requirements.txt installation and use hidden_dim for sizing mo…
RylieWeaver Oct 2, 2024
6034c0c
Add comments in compile file
RylieWeaver Oct 3, 2024
8b9d52b
tests for different radial transforms and exposing those options in MACE
RylieWeaver Oct 4, 2024
6f8db90
rebase and remove test comments
RylieWeaver Oct 8, 2024
81fbe11
Reverse tests changes fully
RylieWeaver Oct 8, 2024
6c21612
Missed reversed change
RylieWeaver Oct 8, 2024
0feb249
fix errors
RylieWeaver Oct 8, 2024
e910ef6
Fix edge attr usage
RylieWeaver Oct 8, 2024
35ee383
Fix Typo merge conflict
RylieWeaver Oct 8, 2024
a120e5e
fix bug from merge resolve
RylieWeaver Oct 8, 2024
cd978d3
merge in main fork changes
RylieWeaver Oct 10, 2024
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -40,7 +40,7 @@ jobs:
        run: |
          python -m pip install --upgrade pip
          python -m pip install --upgrade -r requirements.txt -r requirements-dev.txt
-         python -m pip install --upgrade -r requirements-torch.txt --index-url https://download.pytorch.org/whl/cpu
+         python -m pip install --upgrade -r requirements-torch.txt --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
          python -m pip install --upgrade -r requirements-pyg.txt --find-links https://data.pyg.org/whl/torch-2.0.1+cpu.html
          python -m pip install --upgrade -r requirements-deepspeed.txt
      - name: Format black
2 changes: 2 additions & 0 deletions examples/LennardJones/LJ.json
@@ -30,6 +30,8 @@
"num_before_skip": 1,
"num_after_skip": 1,
"envelope_exponent": 5,
"max_ell": 1,
"node_max_ell": 1,
"num_radial": 5,
"num_spherical": 2,
"hidden_dim": 20,
741 changes: 741 additions & 0 deletions hydragnn/models/MACEStack.py

Large diffs are not rendered by default.
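Since the MACEStack.py diff is not rendered here, the following is a brief, hedged illustration (not code from this PR) of how the max_ell and node_max_ell options added to LJ.json typically map onto e3nn irreps: max_ell bounds the spherical-harmonic order used for edge features, node_max_ell bounds the order kept on node features, and the multiplicity is assumed to come from hidden_dim.

from e3nn import o3

# Illustrative only; the actual construction inside hydragnn/models/MACEStack.py may differ.
hidden_dim, max_ell, node_max_ell = 20, 1, 1

# Edge (spherical-harmonic) irreps up to max_ell: 1x0e+1x1o for max_ell = 1.
sh_irreps = o3.Irreps.spherical_harmonics(max_ell)

# Hypothetical node/hidden irreps with uniform multiplicity hidden_dim up to node_max_ell.
hidden_irreps = o3.Irreps(
    [(hidden_dim, (l, (-1) ** l)) for l in range(node_max_ell + 1)]
)

print(sh_irreps)      # 1x0e+1x1o
print(hidden_irreps)  # 20x0e+20x1o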

47 changes: 47 additions & 0 deletions hydragnn/models/create.py
@@ -12,6 +12,7 @@
import os
import torch
from torch_geometric.data import Data
from typing import List, Union

from hydragnn.models.GINStack import GINStack
from hydragnn.models.PNAStack import PNAStack
@@ -25,6 +26,7 @@
from hydragnn.models.EGCLStack import EGCLStack
from hydragnn.models.PNAEqStack import PNAEqStack
from hydragnn.models.PAINNStack import PAINNStack
from hydragnn.models.MACEStack import MACEStack

from hydragnn.utils.distributed import get_device
from hydragnn.utils.profiling_and_tracing.time_utils import Timer
@@ -55,6 +57,8 @@ def create_model_config(
        config["Architecture"]["num_before_skip"],
        config["Architecture"]["num_after_skip"],
        config["Architecture"]["num_radial"],
        config["Architecture"]["radial_type"],
        config["Architecture"]["distance_transform"],
        config["Architecture"]["basis_emb_size"],
        config["Architecture"]["int_emb_size"],
        config["Architecture"]["out_emb_size"],
@@ -64,6 +68,10 @@
        config["Architecture"]["num_filters"],
        config["Architecture"]["radius"],
        config["Architecture"]["equivariance"],
        config["Architecture"]["correlation"],
        config["Architecture"]["max_ell"],
        config["Architecture"]["node_max_ell"],
        config["Architecture"]["avg_num_neighbors"],
        config["Training"]["conv_checkpointing"],
        verbosity,
        use_gpu,
@@ -91,6 +99,8 @@ def create_model(
    num_before_skip: int = None,
    num_after_skip: int = None,
    num_radial: int = None,
    radial_type: str = None,
    distance_transform: str = None,
    basis_emb_size: int = None,
    int_emb_size: int = None,
    out_emb_size: int = None,
@@ -100,6 +110,10 @@
    num_filters: int = None,
    radius: float = None,
    equivariance: bool = False,
    correlation: Union[int, List[int]] = None,
    max_ell: int = None,
    node_max_ell: int = None,
    avg_num_neighbors: int = None,
    conv_checkpointing: bool = False,
    verbosity: int = 0,
    use_gpu: bool = True,
@@ -371,6 +385,39 @@ def create_model(
            num_conv_layers=num_conv_layers,
            num_nodes=num_nodes,
        )

    elif model_type == "MACE":
        assert radius is not None, "MACE requires radius input."
        assert num_radial is not None, "MACE requires num_radial input."
        assert max_ell is not None, "MACE requires max_ell input."
        assert node_max_ell is not None, "MACE requires node_max_ell input."
        assert max_ell >= 1, "MACE requires max_ell >= 1."
        assert node_max_ell >= 1, "MACE requires node_max_ell >= 1."
        model = MACEStack(
            radius,
            radial_type,
            distance_transform,
            num_radial,
            edge_dim,
            max_ell,
            node_max_ell,
            avg_num_neighbors,
            envelope_exponent,
            correlation,
            input_dim,
            hidden_dim,
            output_dim,
            output_type,
            output_heads,
            activation_function,
            loss_function_type,
            equivariance,
            loss_weights=task_weights,
            freeze_conv=freeze_conv,
            initial_bias=initial_bias,
            num_conv_layers=num_conv_layers,
            num_nodes=num_nodes,
        )
    else:
        raise ValueError("Unknown model_type: {0}".format(model_type))

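For orientation, the MACE branch above depends on only a handful of Architecture keys. A minimal, hypothetical fragment (illustrative values, not taken from this PR) that would satisfy the assertions in create_model and the lookups in create_model_config might look like this:

# Hypothetical values for illustration only; the keys mirror those read in create.py above.
mace_architecture = {
    "model_type": "MACE",
    "radius": 5.0,              # required: create_model asserts radius is not None
    "num_radial": 5,            # required: create_model asserts num_radial is not None
    "max_ell": 1,               # required and must be >= 1
    "node_max_ell": 1,          # required and must be >= 1
    "radial_type": "bessel",    # optional; update_config defaults it to None
    "distance_transform": None, # optional; update_config defaults it to None
    "correlation": 2,           # Union[int, List[int]]: a single order or one per layer
    "avg_num_neighbors": None,  # normally filled by update_config from the training data
    "envelope_exponent": 5,
    "hidden_dim": 20,
}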
31 changes: 27 additions & 4 deletions hydragnn/utils/input_config_parsing/config_utils.py
@@ -14,6 +14,7 @@
    check_if_graph_size_variable,
    gather_deg,
)
from hydragnn.utils.model.model import calculate_avg_deg
from hydragnn.utils.distributed import get_comm_size_and_rank
from copy import deepcopy
import json
@@ -56,8 +57,22 @@ def update_config(config, train_loader, val_loader, test_loader):
    else:
        config["NeuralNetwork"]["Architecture"]["pna_deg"] = None

    if config["NeuralNetwork"]["Architecture"]["model_type"] == "MACE":
        if hasattr(train_loader.dataset, "avg_num_neighbors"):
            ## Use avg neighbours used in the dataset.
            avg_num_neighbors = torch.tensor(train_loader.dataset.avg_num_neighbors)
        else:
            avg_num_neighbors = float(calculate_avg_deg(train_loader.dataset))
        config["NeuralNetwork"]["Architecture"]["avg_num_neighbors"] = avg_num_neighbors
    else:
        config["NeuralNetwork"]["Architecture"]["avg_num_neighbors"] = None

    if "radius" not in config["NeuralNetwork"]["Architecture"]:
        config["NeuralNetwork"]["Architecture"]["radius"] = None
    if "radial_type" not in config["NeuralNetwork"]["Architecture"]:
        config["NeuralNetwork"]["Architecture"]["radial_type"] = None
    if "distance_transform" not in config["NeuralNetwork"]["Architecture"]:
        config["NeuralNetwork"]["Architecture"]["distance_transform"] = None
    if "num_gaussians" not in config["NeuralNetwork"]["Architecture"]:
        config["NeuralNetwork"]["Architecture"]["num_gaussians"] = None
    if "num_filters" not in config["NeuralNetwork"]["Architecture"]:
@@ -78,6 +93,14 @@ def update_config(config, train_loader, val_loader, test_loader):
config["NeuralNetwork"]["Architecture"]["num_radial"] = None
if "num_spherical" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_spherical"] = None
if "radial_type" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["radial_type"] = None
if "correlation" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["correlation"] = None
if "max_ell" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["max_ell"] = None
if "node_max_ell" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["node_max_ell"] = None

config["NeuralNetwork"]["Architecture"] = update_config_edge_dim(
config["NeuralNetwork"]["Architecture"]
@@ -113,23 +136,23 @@ def update_config(config, train_loader, val_loader, test_loader):


def update_config_equivariance(config):
-    equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN"]
+    equivariant_models = ["EGNN", "SchNet", "PNAEq", "PAINN", "MACE"]
    if "equivariance" in config and config["equivariance"]:
        assert (
            config["model_type"] in equivariant_models
-        ), "E(3) equivariance can only be ensured for EGNN and SchNet."
+        ), "E(3) equivariance can only be ensured for EGNN, SchNet, and MACE."
    elif "equivariance" not in config:
        config["equivariance"] = False
    return config


def update_config_edge_dim(config):
    config["edge_dim"] = None
-    edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN", "DimeNet"]
+    edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN", "DimeNet", "MACE"]
    if "edge_features" in config and config["edge_features"]:
        assert (
            config["model_type"] in edge_models
-        ), "Edge features can only be used with DimeNet EGNN, SchNet, PNA, PNAPlus, and CGCNN."
+        ), "Edge features can only be used with DimeNet, MACE, EGNN, SchNet, PNA, PNAPlus, and CGCNN."
        config["edge_dim"] = len(config["edge_features"])
    elif config["model_type"] == "CGCNN":
        # CG always needs an integer edge_dim
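The avg_num_neighbors branch above falls back to calculate_avg_deg(train_loader.dataset) when the dataset does not already expose avg_num_neighbors. As a rough, hedged sketch (not the HydraGNN implementation), such an average degree can be estimated from a PyTorch Geometric style dataset like this:

# Sketch only: mean number of directed edges per node over a dataset of torch_geometric Data objects.
def average_num_neighbors(dataset) -> float:
    total_edges, total_nodes = 0, 0
    for data in dataset:
        total_edges += data.edge_index.size(1)  # count of directed edges in this graph
        total_nodes += data.num_nodes
    return total_edges / max(total_nodes, 1)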
102 changes: 102 additions & 0 deletions hydragnn/utils/model/irreps_tools.py
@@ -0,0 +1,102 @@
###########################################################################################
# Elementary tools for handling irreducible representations
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

from typing import List, Tuple

import torch
from e3nn import o3
from e3nn.util.jit import compile_mode


# Based on mir-group/nequip
def tp_out_irreps_with_instructions(
    irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
) -> Tuple[o3.Irreps, List]:
    trainable = True

    # Collect possible irreps and their instructions
    irreps_out_list: List[Tuple[int, o3.Irreps]] = []
    instructions = []
    for i, (mul, ir_in) in enumerate(irreps1):
        for j, (_, ir_edge) in enumerate(irreps2):
            for ir_out in ir_in * ir_edge:  # | l1 - l2 | <= l <= l1 + l2
                if ir_out in target_irreps:
                    k = len(irreps_out_list)  # instruction index
                    irreps_out_list.append((mul, ir_out))
                    instructions.append((i, j, k, "uvu", trainable))

    # We sort the output irreps of the tensor product so that we can simplify them
    # when they are provided to the second o3.Linear
    irreps_out = o3.Irreps(irreps_out_list)
    irreps_out, permut, _ = irreps_out.sort()

    # Permute the output indexes of the instructions to match the sorted irreps:
    instructions = [
        (i_in1, i_in2, permut[i_out], mode, train)
        for i_in1, i_in2, i_out, mode, train in instructions
    ]

    instructions = sorted(instructions, key=lambda x: x[2])

    return irreps_out, instructions


def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps:
    # Assuming simplified irreps
    irreps_mid = []
    for _, ir_in in irreps:
        found = False

        for mul, ir_out in target_irreps:
            if ir_in == ir_out:
                irreps_mid.append((mul, ir_out))
                found = True
                break

        if not found:
            raise RuntimeError(f"{ir_in} not in {target_irreps}")

    return o3.Irreps(irreps_mid)


@compile_mode("script")
class reshape_irreps(torch.nn.Module):
    def __init__(self, irreps: o3.Irreps) -> None:
        super().__init__()
        self.irreps = o3.Irreps(irreps)
        self.dims = []
        self.muls = []
        for mul, ir in self.irreps:
            d = ir.dim
            self.dims.append(d)
            self.muls.append(mul)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        ix = 0
        out = []
        batch, _ = tensor.shape
        for mul, d in zip(self.muls, self.dims):
            field = tensor[:, ix : ix + mul * d]  # [batch, sample, mul * repr]
            ix += mul * d
            field = field.reshape(batch, mul, d)
            out.append(field)
        return torch.cat(out, dim=-1)


def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int):
    out = []
    for i in range(num_layers - 1):
        out.append(
            x[
                :,
                i
                * (l_max + 1) ** 2
                * num_features : (i * (l_max + 1) ** 2 + 1)
                * num_features,
            ]
        )
    out.append(x[:, -num_features:])
    return torch.cat(out, dim=-1)
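As a short, hedged usage sketch of the helpers above (the specific irreps are illustrative and not taken from MACEStack.py): tp_out_irreps_with_instructions produces an unsimplified output irreps plus "uvu" instructions that plug directly into an e3nn TensorProduct, reshape_irreps unflattens features with uniform multiplicity, and extract_invariant pulls the l=0 channels out of stacked per-layer features.

import torch
from e3nn import o3

# Node features with multiplicity 8 up to l=1, edge spherical harmonics up to l=2.
irreps_node = o3.Irreps("8x0e + 8x1o")
irreps_edge = o3.Irreps.spherical_harmonics(2)  # 1x0e+1x1o+1x2e
target_irreps = o3.Irreps("8x0e + 8x1o")

irreps_out, instructions = tp_out_irreps_with_instructions(
    irreps_node, irreps_edge, target_irreps
)
conv_tp = o3.TensorProduct(
    irreps_node,
    irreps_edge,
    irreps_out,
    instructions,
    shared_weights=False,
    internal_weights=False,
)

# reshape_irreps assumes uniform multiplicity: [batch, irreps.dim] -> [batch, mul, sum of ir.dim].
x = torch.randn(5, target_irreps.dim)           # 5 nodes, 8*1 + 8*3 = 32 features
print(reshape_irreps(target_irreps)(x).shape)   # torch.Size([5, 8, 4])

# extract_invariant keeps the first num_features (l=0) columns of each of the first
# num_layers - 1 blocks of (l_max + 1) ** 2 * num_features features, plus the final block.
num_layers, num_features, l_max = 2, 8, 1
y = torch.randn(5, (num_layers - 1) * (l_max + 1) ** 2 * num_features + num_features)
print(extract_invariant(y, num_layers, num_features, l_max).shape)  # torch.Size([5, 16])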
51 changes: 51 additions & 0 deletions hydragnn/utils/model/mace_utils/modules/__init__.py
@@ -0,0 +1,51 @@
###########################################################################################
# __init__ file for Modules
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
# Taken From:
# GitHub: https://github.com/ACEsuit/mace
# ArXiV: https://arxiv.org/pdf/2206.07697
# Date: August 27, 2024 | 12:37 (EST)
###########################################################################################

from typing import Callable, Dict, Optional, Type

import torch

from .blocks import (
    AtomicEnergiesBlock,
    EquivariantProductBasisBlock,
    InteractionBlock,
    LinearNodeEmbeddingBlock,
    LinearReadoutBlock,
    NonLinearReadoutBlock,
    RadialEmbeddingBlock,
    RealAgnosticAttResidualInteractionBlock,
    ScaleShiftBlock,
)

from .radial import BesselBasis, GaussianBasis, PolynomialCutoff
from .symmetric_contraction import SymmetricContraction

interaction_classes: Dict[str, Type[InteractionBlock]] = {
    "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock,
}

__all__ = [
    "AtomicEnergiesBlock",
    "RadialEmbeddingBlock",
    "LinearNodeEmbeddingBlock",
    "LinearReadoutBlock",
    "EquivariantProductBasisBlock",
    "ScaleShiftBlock",
    "LinearDipoleReadoutBlock",
    "NonLinearDipoleReadoutBlock",
    "InteractionBlock",
    "NonLinearReadoutBlock",
    "PolynomialCutoff",
    "BesselBasis",
    "GaussianBasis",
    "SymmetricContraction",
    "interaction_classes",
]
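A tiny, hedged example of how the interaction_classes registry above can be used to resolve an interaction block by its configured name (how MACEStack actually selects and constructs the class is assumed, not verified here):

# Look up the interaction block class by name; constructor arguments are left to the caller.
interaction_cls = interaction_classes["RealAgnosticAttResidualInteractionBlock"]
assert issubclass(interaction_cls, InteractionBlock)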